diff --git a/command.go b/command.go index 1fc32558b..fd07d82a6 100644 --- a/command.go +++ b/command.go @@ -5057,3 +5057,112 @@ func parseClientInfo(txt string) (info *ClientInfo, err error) { return info, nil } + +// ------------------------------------------- + +type ACLLogEntry struct { + Count int64 + Reason string + Context string + Object string + Username string + AgeSeconds float64 + ClientInfo *ClientInfo + EntryID int64 + TimestampCreated int64 + TimestampLastUpdated int64 +} + +type ACLLogCmd struct { + baseCmd + + val []*ACLLogEntry +} + +var _ Cmder = (*ACLLogCmd)(nil) + +func NewACLLogCmd(ctx context.Context, args ...interface{}) *ACLLogCmd { + return &ACLLogCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *ACLLogCmd) SetVal(val []*ACLLogEntry) { + cmd.val = val +} + +func (cmd *ACLLogCmd) Val() []*ACLLogEntry { + return cmd.val +} + +func (cmd *ACLLogCmd) Result() ([]*ACLLogEntry, error) { + return cmd.Val(), cmd.Err() +} + +func (cmd *ACLLogCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *ACLLogCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + cmd.val = make([]*ACLLogEntry, n) + for i := 0; i < n; i++ { + cmd.val[i] = &ACLLogEntry{} + entry := cmd.val[i] + respLen, err := rd.ReadMapLen() + if err != nil { + return err + } + for j := 0; j < respLen; j++ { + key, err := rd.ReadString() + if err != nil { + return err + } + + switch key { + case "count": + entry.Count, err = rd.ReadInt() + case "reason": + entry.Reason, err = rd.ReadString() + case "context": + entry.Context, err = rd.ReadString() + case "object": + entry.Object, err = rd.ReadString() + case "username": + entry.Username, err = rd.ReadString() + case "age-seconds": + entry.AgeSeconds, err = rd.ReadFloat() + case "client-info": + txt, err := rd.ReadString() + if err != nil { + return err + } + entry.ClientInfo, err = parseClientInfo(strings.TrimSpace(txt)) + if err != nil { + return err + } + case "entry-id": + entry.EntryID, err = rd.ReadInt() + case "timestamp-created": + entry.TimestampCreated, err = rd.ReadInt() + case "timestamp-last-updated": + entry.TimestampLastUpdated, err = rd.ReadInt() + default: + return fmt.Errorf("redis: unexpected key %q in ACL LOG reply", key) + } + + if err != nil { + return err + } + } + } + + return nil +} diff --git a/commands.go b/commands.go index 1a33c1265..34f4d2c22 100644 --- a/commands.go +++ b/commands.go @@ -500,6 +500,8 @@ type Cmdable interface { GeoHash(ctx context.Context, key string, members ...string) *StringSliceCmd ACLDryRun(ctx context.Context, username string, command ...interface{}) *StringCmd + ACLLog(ctx context.Context, count int64) *ACLLogCmd + ACLLogReset(ctx context.Context) *StatusCmd ModuleLoadex(ctx context.Context, conf *ModuleLoadexConfig) *StringCmd } @@ -3946,3 +3948,20 @@ func (c cmdable) ModuleLoadex(ctx context.Context, conf *ModuleLoadexConfig) *St _ = c(ctx, cmd) return cmd } + +func (c cmdable) ACLLog(ctx context.Context, count int64) *ACLLogCmd { + args := make([]interface{}, 0, 3) + args = append(args, "acl", "log") + if count > 0 { + args = append(args, count) + } + cmd := NewACLLogCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) ACLLogReset(ctx context.Context) *StatusCmd { + cmd := NewStatusCmd(ctx, "acl", "log", "reset") + _ = c(ctx, cmd) + return cmd +} diff --git a/commands_test.go b/commands_test.go index a2abe1b0f..914a31494 100644 --- a/commands_test.go +++ b/commands_test.go @@ -1985,6 +1985,55 @@ var _ = Describe("Commands", func() { Expect(args).To(Equal(expectedArgs)) }) + + It("should ACL LOG", func() { + + err := client.Do(ctx, "acl", "setuser", "test", ">test", "on", "allkeys", "+get").Err() + Expect(err).NotTo(HaveOccurred()) + + clientAcl := redis.NewClient(redisOptions()) + clientAcl.Options().Username = "test" + clientAcl.Options().Password = "test" + clientAcl.Options().DB = 0 + _ = clientAcl.Set(ctx, "mystring", "foo", 0).Err() + _ = clientAcl.HSet(ctx, "myhash", "foo", "bar").Err() + _ = clientAcl.SAdd(ctx, "myset", "foo", "bar").Err() + + logEntries, err := client.ACLLog(ctx, 10).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(logEntries)).To(Equal(3)) + + for _, entry := range logEntries { + Expect(entry.Count).To(BeNumerically("==", 1)) + Expect(entry.Reason).To(Equal("command")) + Expect(entry.Context).To(Equal("toplevel")) + Expect(entry.Object).NotTo(BeEmpty()) + Expect(entry.Username).To(Equal("test")) + Expect(entry.AgeSeconds).To(BeNumerically(">=", 0)) + Expect(entry.ClientInfo).NotTo(BeNil()) + Expect(entry.EntryID).To(BeNumerically(">=", 0)) + Expect(entry.TimestampCreated).To(BeNumerically(">=", 0)) + Expect(entry.TimestampLastUpdated).To(BeNumerically(">=", 0)) + } + + limitedLogEntries, err := client.ACLLog(ctx, 2).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(limitedLogEntries)).To(Equal(2)) + + }) + + It("should ACL LOG RESET", func() { + // Call ACL LOG RESET + resetCmd := client.ACLLogReset(ctx) + Expect(resetCmd.Err()).NotTo(HaveOccurred()) + Expect(resetCmd.Val()).To(Equal("OK")) + + // Verify that the log is empty after the reset + logEntries, err := client.ACLLog(ctx, 10).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(logEntries)).To(Equal(0)) + }) + }) Describe("hashes", func() {