diff --git a/go.mod b/go.mod index e0d594b8c0..ef365f58c1 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dghubble/trie v0.0.0-20230228185955-dca8fa4fd7f8 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/golang/mock v1.6.0 // indirect diff --git a/go.sum b/go.sum index 2a9e838e3b..770fb75659 100644 --- a/go.sum +++ b/go.sum @@ -86,6 +86,8 @@ github.com/cristalhq/jwt/v4 v4.0.2/go.mod h1:HnYraSNKDRag1DZP92rYHyrjyQHnVEHPNqe github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dghubble/trie v0.0.0-20230228185955-dca8fa4fd7f8 h1:hyXHmCxADfwpski6P8WFoZg5ssJKWwrk34wS4wTar9g= +github.com/dghubble/trie v0.0.0-20230228185955-dca8fa4fd7f8/go.mod h1:sOmnzfBNH7H92ow2292dDFWNsVQuh/izuD7otCYb1ak= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= diff --git a/internal/router/config.go b/internal/router/config.go new file mode 100644 index 0000000000..8195d205d6 --- /dev/null +++ b/internal/router/config.go @@ -0,0 +1,8 @@ +package router + +type Config struct { + Routes []struct { + Name string `mapstructure:"name" json:"name"` + Addresses []string `mapstructure:"addresses" json:"addresses"` + } `mapstructure:"routes" json:"routes"` +} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 0000000000..def73543ac --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,60 @@ +package router + +import ( + "github.com/centrifugal/centrifuge" + + "github.com/dghubble/trie" +) + +type Router struct { + node *centrifuge.Node + exactRoutes map[string]any + prefixRoutes *trie.RuneTrie +} + +func New(n *centrifuge.Node) *Router { + return &Router{ + node: n, + exactRoutes: make(map[string]any), + prefixRoutes: trie.NewRuneTrie(), + } +} + +func (r *Router) AddExact(channel string, value any) { + r.exactRoutes[channel] = value +} + +func (r *Router) AddPrefix(channelPrefix string, value any) { + _ = r.prefixRoutes.Put(channelPrefix, value) +} + +func (r *Router) Find(channel string) any { + if value, ok := r.exactRoutes[channel]; ok { + if r.node.LogEnabled(centrifuge.LogLevelTrace) { + r.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelTrace, "exact route for channel", map[string]any{"channel": channel})) + } + return value + } + + var value any + _ = r.prefixRoutes.WalkPath(channel, func(key string, val any) error { + if val != nil { + if r.node.LogEnabled(centrifuge.LogLevelTrace) { + r.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelTrace, "prefix route for channel", map[string]any{"channel": channel})) + } + value = val + } + return nil + }) + + if value == nil { + if value, ok := r.exactRoutes["__default"]; ok { + if r.node.LogEnabled(centrifuge.LogLevelTrace) { + r.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelTrace, "default route for channel", map[string]any{"channel": channel})) + } + return value + } + } + + return value +} diff --git a/internal/tntengine/broker.go b/internal/tntengine/broker.go index 40f8c7e0bd..629ffd9fd4 100644 --- a/internal/tntengine/broker.go +++ b/internal/tntengine/broker.go @@ -37,6 +37,11 @@ type Broker struct { config BrokerConfig shards []*Shard nodeChannel string + router Router +} + +type Router interface { + Find(channel string) any } var _ centrifuge.Broker = (*Broker)(nil) @@ -60,6 +65,9 @@ type BrokerConfig struct { // Shards is a list of Tarantool instances to shard data by channel. Shards []*Shard + + // Router is a mapper of channels to Tarantool instances subset. + Router Router } // NewBroker initializes Tarantool Broker. @@ -76,6 +84,7 @@ func NewBroker(n *centrifuge.Node, config BrokerConfig) (*Broker, error) { config: config, sharding: len(config.Shards) > 1, nodeChannel: nodeChannel(n.ID()), + router: config.Router, } return e, nil } @@ -152,7 +161,7 @@ func (m *pubResponse) DecodeMsgpack(d *msgpack.Decoder) error { // Publish - see centrifuge.Broker interface description. func (b *Broker) Publish(ch string, data []byte, opts centrifuge.PublishOptions) (centrifuge.StreamPosition, error) { - s := consistentShard(ch, b.shards) + s := getShard(ch, b.getShards(ch)) protoPub := &protocol.Publication{ Data: data, @@ -187,7 +196,7 @@ func (b *Broker) Publish(ch string, data []byte, opts centrifuge.PublishOptions) // PublishJoin - see centrifuge.Broker interface description. func (b *Broker) PublishJoin(ch string, info *centrifuge.ClientInfo) error { - s := consistentShard(ch, b.shards) + s := getShard(ch, b.getShards(ch)) pr := pubRequest{ MsgType: "j", Channel: ch, @@ -199,7 +208,7 @@ func (b *Broker) PublishJoin(ch string, info *centrifuge.ClientInfo) error { // PublishLeave - see centrifuge.Broker interface description. func (b *Broker) PublishLeave(ch string, info *centrifuge.ClientInfo) error { - s := consistentShard(ch, b.shards) + s := getShard(ch, b.getShards(ch)) pr := pubRequest{ MsgType: "l", Channel: ch, @@ -255,7 +264,29 @@ func nodeChannel(nodeID string) string { } // Subscribe - see centrifuge.Broker interface description. -func (b *Broker) Subscribe(ch string) error { +func (b *Broker) Subscribe(ch string) (err error) { + var subscribed []*Shard + defer func() { + if err != nil { + for _, s := range subscribed { + if err := b.unsubscribe(ch, s); err != nil { + b.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "failed to unsubscribe node on channel", map[string]interface{}{"channel": ch, "err": err, "shard": s.GetAddresses()})) + } + } + } + }() + + for _, s := range b.getShards(ch) { + if err := b.subscribe(ch, s); err != nil { + b.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "failed to subscribe node on channel", map[string]interface{}{"channel": ch, "err": err, "shard": s.GetAddresses()})) + return err + } + subscribed = append(subscribed, s) + } + return nil +} + +func (b *Broker) subscribe(ch string, s *Shard) error { if strings.HasPrefix(ch, internalChannelPrefix) { return centrifuge.ErrorBadRequest } @@ -263,17 +294,22 @@ func (b *Broker) Subscribe(ch string) error { b.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "subscribe node on channel", map[string]interface{}{"channel": ch})) } r := newSubRequest([]string{ch}, true) - s := b.shards[consistentIndex(ch, len(b.shards))] return b.sendSubscribe(s, r) } // Unsubscribe - see centrifuge.Broker interface description. func (b *Broker) Unsubscribe(ch string) error { + for _, s := range b.getShards(ch) { + b.unsubscribe(ch, s) + } + return nil +} + +func (b *Broker) unsubscribe(ch string, s *Shard) error { if b.node.LogEnabled(centrifuge.LogLevelDebug) { b.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelDebug, "unsubscribe node from channel", map[string]interface{}{"channel": ch})) } r := newSubRequest([]string{ch}, false) - s := b.shards[consistentIndex(ch, len(b.shards))] return b.sendSubscribe(s, r) } @@ -410,7 +446,7 @@ func (b *Broker) History(ch string, filter centrifuge.HistoryFilter) ([]*centrif limit = filter.Limit } historyMetaTTLSeconds := int(b.config.HistoryMetaTTL.Seconds()) - s := consistentShard(ch, b.shards) + s := getShard(ch, b.getShards(ch)) req := historyRequest{ Channel: ch, Offset: offset, @@ -435,7 +471,7 @@ type removeHistoryRequest struct { // RemoveHistory - see centrifuge.Broker interface description. func (b *Broker) RemoveHistory(ch string) error { - s := consistentShard(ch, b.shards) + s := getShard(ch, b.getShards(ch)) _, err := s.Exec(tarantool.Call("centrifuge.remove_history", removeHistoryRequest{Channel: ch})) return err } @@ -449,11 +485,19 @@ const ( tarantoolSubscribeBatchLimit = 512 ) -func (b *Broker) getShard(channel string) *Shard { - if !b.sharding { - return b.shards[0] +func (b *Broker) getShards(channel string) []*Shard { + if b.router != nil { + if value := b.router.Find(channel); value != nil { + return value.([]*Shard) + } + b.node.Log(centrifuge.NewLogEntry(centrifuge.LogLevelError, "Tarantool routing failure: no shard found", map[string]any{"channel": channel})) } - return b.shards[consistentIndex(channel, len(b.shards))] + + return b.shards +} + +func getShard(channel string, shards []*Shard) *Shard { + return consistentShard(channel, shards) } type pollRequest struct { @@ -653,7 +697,7 @@ func (b *Broker) runPubSub(s *Shard, eventHandler centrifuge.BrokerEventHandler) channels := b.node.Hub().Channels() for i := 0; i < len(channels); i++ { - if b.getShard(channels[i]) == s { + if getShard(channels[i], b.getShards(channels[i])) == s { chIDs = append(chIDs, channels[i]) } } diff --git a/internal/tntengine/shard.go b/internal/tntengine/shard.go index fda3972e40..f7a424dc24 100644 --- a/internal/tntengine/shard.go +++ b/internal/tntengine/shard.go @@ -77,6 +77,10 @@ func (s *Shard) ExecTyped(request *tarantool.Request, result interface{}) error return conn.ExecTyped(request, result) } +func (s *Shard) GetAddresses() []string { + return s.config.Addresses +} + func (s *Shard) pubSubConn() (*tarantool.Connection, func(), error) { conn, err := s.mc.NewLeaderConn(tarantool.Opts{ ConnectTimeout: defaultConnectTimeout, diff --git a/main.go b/main.go index e87679526c..5e5c068a3c 100644 --- a/main.go +++ b/main.go @@ -49,6 +49,7 @@ import ( "github.com/centrifugal/centrifugo/v4/internal/notify" "github.com/centrifugal/centrifugo/v4/internal/origin" "github.com/centrifugal/centrifugo/v4/internal/proxy" + "github.com/centrifugal/centrifugo/v4/internal/router" "github.com/centrifugal/centrifugo/v4/internal/rule" "github.com/centrifugal/centrifugo/v4/internal/survey" "github.com/centrifugal/centrifugo/v4/internal/tntengine" @@ -1864,6 +1865,33 @@ func rpcNamespacesFromConfig(v *viper.Viper) []rule.RpcNamespace { return ns } +func getRouterConfig(v *viper.Viper) *router.Config { + var cfg router.Config + if !v.IsSet("router") { + return &cfg + } + var err error + switch val := v.Get("router").(type) { + case string: + err = json.Unmarshal([]byte(val), &cfg) + case interface{}: + decoderCfg := tools.DecoderConfig(&cfg) + decoder, newErr := mapstructure.NewDecoder(decoderCfg) + if newErr != nil { + log.Fatal().Msg(newErr.Error()) + return &cfg + } + err = decoder.Decode(v.Get("router")) + default: + err = fmt.Errorf("unknown router type: %T", val) + } + if err != nil { + log.Error().Err(err).Msg("malformed router") + os.Exit(1) + } + return &cfg +} + func getPingPongConfig() centrifuge.PingPongConfig { pingInterval := GetDuration("ping_interval") pongTimeout := GetDuration("pong_timeout") @@ -2270,13 +2298,60 @@ func getTarantoolShards() ([]*tntengine.Shard, string, error) { return tarantoolShards, mode, nil } +func tarantoolMapRoutesToShards(n *centrifuge.Node, routerConfig *router.Config, shards []*tntengine.Shard) (*router.Router, error) { + findShard := func(addr string) (*tntengine.Shard, error) { + for _, s := range shards { + for _, shardAddr := range s.GetAddresses() { + if addr == shardAddr { + return s, nil + } + } + } + return nil, fmt.Errorf("%s was not found in tarantool_address", addr) + } + + if len(routerConfig.Routes) == 0 { + return nil, nil + } + + r := router.New(n) + for _, route := range routerConfig.Routes { + var routeShards []*tntengine.Shard + for _, addr := range route.Addresses { + s, err := findShard(addr) + if err != nil { + return nil, err + } + routeShards = append(routeShards, s) + } + + ch := route.Name + last := len(ch)-1 + if len(ch) > 0 && ch[last] == '*' { + r.AddPrefix(ch[0:last], routeShards) + log.Debug().Msgf("added prefix route %s for channel %s", tools.GetLogAddresses(route.Addresses), ch) + continue + } + r.AddExact(ch, routeShards) + log.Debug().Msgf("added exact route %s for channel %s", tools.GetLogAddresses(route.Addresses), ch) + } + return r, nil +} + func tarantoolEngine(n *centrifuge.Node) (centrifuge.Broker, centrifuge.PresenceManager, string, error) { tarantoolShards, mode, err := getTarantoolShards() if err != nil { return nil, nil, "", err } + routerConfig := getRouterConfig(viper.GetViper()) + channelRouter, err := tarantoolMapRoutesToShards(n, routerConfig, tarantoolShards) + if err != nil { + return nil, nil, "", err + } + broker, err := tntengine.NewBroker(n, tntengine.BrokerConfig{ Shards: tarantoolShards, + Router: channelRouter, HistoryMetaTTL: GetDuration("history_meta_ttl", true), UseJSON: viper.GetBool("tarantool_experimental_use_json_in_broker"), })