diff --git a/conn.go b/conn.go index 26d7e01f..74039fb5 100644 --- a/conn.go +++ b/conn.go @@ -1060,16 +1060,16 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, return "", err } - if err := validatePath(path, createMode.sequential); err != nil { + if err := validatePath(path, createMode.isSequential); err != nil { return "", err } if createMode.isTTL { - return "", ErrInvalidFlags + return "", fmt.Errorf("Create with TTL flag disallowed :%w", ErrInvalidFlags) } res := &createResponse{} - _, err = c.request(opCreate, &CreateRequest{path, data, acl, createMode.toFlag()}, res, nil) + _, err = c.request(opCreate, &CreateRequest{path, data, acl, createMode.flag}, res, nil) if err == ErrConnectionClosed { return "", err } @@ -1086,16 +1086,16 @@ func (c *Conn) CreateContainer(path string, data []byte, flag int32, acl []ACL) return "", err } - if err := validatePath(path, createMode.sequential); err != nil { + if err := validatePath(path, createMode.isSequential); err != nil { return "", err } if !createMode.isContainer { - return "", ErrInvalidFlags + return "", fmt.Errorf("CreateContainer requires container flag :%w", ErrInvalidFlags) } res := &createResponse{} - _, err = c.request(opCreateContainer, &CreateRequest{path, data, acl, createMode.toFlag()}, res, nil) + _, err = c.request(opCreateContainer, &CreateRequest{path, data, acl, createMode.flag}, res, nil) return res.Path, err } @@ -1106,16 +1106,16 @@ func (c *Conn) CreateTTL(path string, data []byte, flag int32, acl []ACL, ttl ti return "", err } - if err := validatePath(path, createMode.sequential); err != nil { + if err := validatePath(path, createMode.isSequential); err != nil { return "", err } if !createMode.isTTL { - return "", ErrInvalidFlags + return "", fmt.Errorf("CreateTTL requires TTL flag :%w", ErrInvalidFlags) } res := &createResponse{} - _, err = c.request(opCreateTTL, &CreateTTLRequest{path, data, acl, createMode.toFlag(), ttl.Milliseconds()}, res, nil) + _, err = c.request(opCreateTTL, &CreateTTLRequest{path, data, acl, createMode.flag, ttl.Milliseconds()}, res, nil) return res.Path, err } diff --git a/create_mode.go b/create_mode.go index 8e1eb0d9..0705ed61 100644 --- a/create_mode.go +++ b/create_mode.go @@ -2,6 +2,7 @@ package zk import "fmt" +// TODO: (v2) enum type for CreateMode API. const ( FlagPersistent = 0 FlagEphemeral = 1 @@ -13,15 +14,11 @@ const ( ) type createMode struct { - flag int32 - ephemeral bool - sequential bool - isContainer bool - isTTL bool -} - -func (cm *createMode) toFlag() int32 { - return cm.flag + flag int32 + isEphemeral bool + isSequential bool + isContainer bool + isTTL bool } // parsing a flag integer into the CreateMode needed to call the correct diff --git a/create_mode_test.go b/create_mode_test.go index ff4dc158..0d2d08c6 100644 --- a/create_mode_test.go +++ b/create_mode_test.go @@ -23,9 +23,9 @@ func TestParseCreateMode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cm, err := parseCreateMode(tt.flag) requireNoErrorf(t, err) - if cm.toFlag() != tt.wantIntValue { + if cm.flag != tt.wantIntValue { // change detector test for enum values. - t.Fatalf("createmode value of flag; want: %v, got: %v", cm.toFlag(), tt.wantIntValue) + t.Fatalf("createmode value of flag; want: %v, got: %v", cm.flag, tt.wantIntValue) } }) } diff --git a/zk_test.go b/zk_test.go index e9c03254..f6339631 100644 --- a/zk_test.go +++ b/zk_test.go @@ -3,6 +3,7 @@ package zk import ( "context" "encoding/hex" + "errors" "fmt" "io" "math/rand" @@ -152,60 +153,66 @@ func TestIntegration_CreateTTL(t *testing.T) { specifiedPath string giveDuration time.Duration wantErr string + wantErrIs error }{ { - name: "valid create ttl", - createFlags: FlagTTL, - giveDuration: time.Minute, + name: "valid create ttl", + specifiedPath: "/test-valid-create-ttl", + createFlags: FlagTTL, + giveDuration: time.Minute, }, { - name: "valid change detector", - createFlags: 5, - giveDuration: time.Minute, + name: "valid change detector", + specifiedPath: "/test-valid-change", + createFlags: 5, + giveDuration: time.Minute, }, { name: "invalid path", createFlags: FlagTTL, specifiedPath: "not/valid", wantErr: "zk: invalid path", + wantErrIs: ErrInvalidPath, }, { - name: "invalid container with ttl", - createFlags: FlagContainer, - wantErr: "zk: invalid flags specified", + name: "invalid container with ttl", + specifiedPath: "/test-invalid-flags", + createFlags: FlagContainer, + wantErr: "zk: invalid flags specified", + wantErrIs: ErrInvalidFlags, }, { - name: "invalid flag for create mode", - createFlags: 999, - giveDuration: time.Minute, - wantErr: "invalid flag value: [999]", + name: "invalid flag for create mode", + specifiedPath: "/test-invalid-mode", + createFlags: 999, + giveDuration: time.Minute, + wantErr: "invalid flag value: [999]", }, } - const testPath = "/ttl_znode_tests" - // create sub node to create per test in avoiding using the root path. - _, err = zk.Create(testPath, nil /* data */, FlagPersistent, WorldACL(PermAll)) - requireNoErrorf(t, err) - - for idx, tt := range tests { + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - path := filepath.Join(testPath, fmt.Sprint(idx)) - if tt.specifiedPath != "" { - path = tt.specifiedPath + if tt.specifiedPath == "" { + t.Fatalf("path for test required: %v", tt.name) } - _, err := zk.CreateTTL(path, []byte{12}, tt.createFlags, WorldACL(PermAll), tt.giveDuration) + _, err := zk.CreateTTL(tt.specifiedPath, []byte{12}, tt.createFlags, WorldACL(PermAll), tt.giveDuration) if tt.wantErr == "" { - requireNoErrorf(t, err, fmt.Sprintf("error not expected: path; %q; flags %v", path, tt.createFlags)) + requireNoErrorf(t, err, + fmt.Sprintf("error not expected: path; %q; flags %v", tt.specifiedPath, tt.createFlags)) return } - // want an error + if tt.wantErrIs != nil { + if !errors.Is(err, tt.wantErrIs) { + t.Errorf("error expected Is: %q", tt.wantErr) + } + } if err == nil { - t.Fatalf("did not get expected error: %q", tt.wantErr) + t.Errorf("did not get expected error: %q", tt.wantErr) } if !strings.Contains(err.Error(), tt.wantErr) { - t.Fatalf("wanted error not found: %v; got: %v", tt.wantErr, err.Error()) + t.Errorf("wanted error not found: %v; got: %v", tt.wantErr, err.Error()) } }) }