Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHANGED] Reject channels with different case (Foo vs foo) #1274

Merged
merged 1 commit into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/clustering.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ func (r *raftFSM) lookupOrCreateChannel(name string, id uint64) (*channel, error
return nil, err
}
delete(cs.channels, name)
delete(s.channels.channelsLC, strings.ToLower(name))
}
// Channel does exist or has been deleted. Create now with given ID.
return cs.createChannelLocked(s, name, id)
Expand Down
32 changes: 26 additions & 6 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,18 @@ func (state State) String() string {

type channelStore struct {
sync.RWMutex
channels map[string]*channel
store stores.Store
stan *StanServer
channels map[string]*channel
channelsLC map[string]*channel
store stores.Store
stan *StanServer
}

func newChannelStore(srv *StanServer, s stores.Store) *channelStore {
cs := &channelStore{
channels: make(map[string]*channel),
store: s,
stan: srv,
channels: make(map[string]*channel),
channelsLC: make(map[string]*channel),
store: s,
stan: srv,
}
return cs
}
Expand Down Expand Up @@ -313,6 +315,13 @@ func (cs *channelStore) createChannel(s *StanServer, name string) (*channel, err
return c, err
}

func (cs *channelStore) checkCase(name string) error {
if c := cs.channelsLC[strings.ToLower(name)]; c != nil {
return fmt.Errorf("rejecting channel %q because channel %q alreay exists (different cases not allowed)", name, c.name)
}
return nil
}

func (cs *channelStore) createChannelLocked(s *StanServer, name string, id uint64) (retChan *channel, retErr error) {
defer func() {
if retErr != nil {
Expand All @@ -329,6 +338,8 @@ func (cs *channelStore) createChannelLocked(s *StanServer, name string, id uint6
return nil, ErrChanDelInProgress
}
return c, nil
} else if err := cs.checkCase(name); err != nil {
return nil, err
}
if s.isClustered {
if s.isLeader() && id == 0 {
Expand Down Expand Up @@ -370,6 +381,7 @@ func (cs *channelStore) create(s *StanServer, name string, sc *stores.Channel) (
}
c.nextSequence = lastSequence + 1
cs.channels[name] = c
cs.channelsLC[strings.ToLower(name)] = c
cl := cs.store.GetChannelLimits(name)
if cl.MaxInactivity > 0 {
c.activity = &channelActivity{maxInactivity: cl.MaxInactivity}
Expand Down Expand Up @@ -900,6 +912,9 @@ func (s *StanServer) lookupOrCreateChannel(name string) (*channel, error) {
}
cs.RUnlock()
return c, nil
} else if err := cs.checkCase(name); err != nil {
cs.RUnlock()
return nil, err
}
cs.RUnlock()
return cs.createChannel(s, name)
Expand All @@ -914,6 +929,9 @@ func (s *StanServer) lookupOrCreateChannelPreventDelete(name string) (*channel,
cs.Unlock()
return nil, false, ErrChanDelInProgress
}
} else if err := cs.checkCase(name); err != nil {
cs.Unlock()
return nil, false, err
} else {
var err error
c, err = cs.createChannelLocked(s, name, 0)
Expand Down Expand Up @@ -3127,6 +3145,7 @@ func (s *StanServer) processDeleteChannel(channel string) {
return
}
delete(s.channels.channels, channel)
delete(s.channels.channelsLC, strings.ToLower(channel))
s.log.Noticef("Channel %q has been deleted", channel)
}

Expand Down Expand Up @@ -5287,6 +5306,7 @@ func (s *StanServer) processSubscriptionRequest(m *nats.Msg) {
}
}
if err != nil {
s.log.Errorf("Unable to create subscription on %q: %v", sr.Subject, err)
s.channels.turnOffPreventDelete(c)
s.channels.maybeStartChannelDeleteTimer(sr.Subject, c)
s.sendSubscriptionResponseErr(m.Reply, err)
Expand Down
146 changes: 146 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1703,3 +1703,149 @@ func TestInternalSubsLimits(t *testing.T) {
})
}
}

func TestChannelNameRejectedIfAlreadyExistsWithDifferentCase(t *testing.T) {
for _, tinfo := range []struct {
name string
st string
restart bool
}{
{"memory", stores.TypeMemory, false},
{"file", stores.TypeFile, true},
{"sql", stores.TypeSQL, true},
{"clustered", stores.TypeFile, true},
} {
t.Run(tinfo.name, func(t *testing.T) {
var o *Options
if tinfo.st == stores.TypeSQL {
if !doSQL {
t.SkipNow()
}
}
// Force persistent store to be the tinfo.st for this test.
orgps := persistentStoreType
persistentStoreType = tinfo.st
defer func() { persistentStoreType = orgps }()

if tinfo.st == stores.TypeSQL || tinfo.st == stores.TypeFile {
o = getTestDefaultOptsForPersistentStore()
} else if tinfo.st == stores.TypeMemory {
o = GetDefaultOptions()
}

cleanupDatastore(t)
defer cleanupDatastore(t)
cleanupRaftLog(t)
defer cleanupRaftLog(t)

var servers []*StanServer
if tinfo.name == "clustered" {
ns := natsdTest.RunDefaultServer()
defer ns.Shutdown()

o1 := getTestDefaultOptsForClustering("a", true)
servers = append(servers, runServerWithOpts(t, o1, nil))
o2 := getTestDefaultOptsForClustering("b", false)
servers = append(servers, runServerWithOpts(t, o2, nil))
o3 := getTestDefaultOptsForClustering("c", false)
servers = append(servers, runServerWithOpts(t, o3, nil))
} else {
servers = append(servers, runServerWithOpts(t, o, nil))
}
for _, s := range servers {
defer s.Shutdown()
}

sc := NewDefaultConnection(t)
defer sc.Close()

sendOK := func(channel, content string) {
t.Helper()
if err := sc.Publish(channel, []byte(content)); err != nil {
t.Fatalf("Error on send: %v", err)
}
}
sendFail := func(channel, content string) {
t.Helper()
err := sc.Publish(channel, []byte(content))
if err == nil || !strings.Contains(err.Error(), "exists") {
t.Fatalf("Expected error that channel already exists, got: %v", err)
}
}
sendOK("Foo", "1")
sendOK("Foo", "2")
sendOK("Foo", "3")
sendOK("Foo", "4")
// Change channel name case
sendFail("foo", "1")
sendFail("foo", "2")
// Back to "Foo"
sendOK("Foo", "5")
sendOK("Foo", "6")

recvOK := func(channel string) {
t.Helper()

ch := make(chan *stan.Msg, 6)
sub, err := sc.Subscribe(channel, func(m *stan.Msg) {
ch <- m
}, stan.DeliverAllAvailable())
if err != nil {
t.Fatalf("Error on subscribe: %v", err)
}
defer sub.Unsubscribe()

// We want to get all 6 messages
for i := 0; i < 6; i++ {
select {
case m := <-ch:
if v, err := strconv.ParseInt(string(m.Data), 10, 64); err != nil || int(v) != i+1 {
t.Fatalf("Invalid message %v: %s", i+1, m.Data)
}
case <-time.After(time.Second):
t.Fatalf("Failed receiving message %v", i+1)
}
}
}
recvFail := func(channel string) {
t.Helper()

_, err := sc.Subscribe(channel, func(m *stan.Msg) {}, stan.DeliverAllAvailable())
if err == nil || !strings.Contains(err.Error(), "exists") {
t.Fatalf("Expected error that channel already exists, got %v", err)
}
}
recvOK("Foo")
recvFail("foo")
recvFail("FoO")

if !tinfo.restart {
return
}

sc.Close()

for i, s := range servers {
s.Shutdown()
s.mu.RLock()
opts := s.opts
s.mu.RUnlock()
s = runServerWithOpts(t, opts, nil)
defer s.Shutdown()
servers[i] = s
}

if tinfo.name == "clustered" {
getLeader(t, 10*time.Second, servers...)
}

sc = NewDefaultConnection(t)
defer sc.Close()

// Try to receive again, but change the order...
recvFail("foo")
recvOK("Foo")
recvFail("FoO")
})
}
}