diff --git a/.gitignore b/.gitignore index 5725df9790..0eae9cf740 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,8 @@ data/* install/cloud/**/*.json install/cloud/**/*.tfvars +*.pprof + ### Go ### # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o diff --git a/CHANGELOG.md b/CHANGELOG.md index a22f5705d2..d189c94029 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,13 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr - Authoritative match modules now allow a `match_join` callback that triggers when users have completed their join process. - New stream API function to upsert a user presence. - Extended validation of Google tokens to account for different token payloads. +- Authoritative match labels can now be updated using the dispatcher's `match_label_update` function. ### Changed - Presence list in match join responses no longer contains the user's own presence. - Presence list in channel join responses no longer contains the user's own presence. +- Socket read/write buffer sizes are now set based on the `socket.max_message_size_bytes` config value. +- Console GRPC port now set relative to `console.port` config value. ## [2.0.1] - 2018-06-15 ### Added diff --git a/data/modules/match.lua b/data/modules/match.lua index 829e85bc78..1e82580432 100644 --- a/data/modules/match.lua +++ b/data/modules/match.lua @@ -69,6 +69,8 @@ Dispatcher exposes useful functions to the match. Format: -- a presence to tag on the message as the 'sender', or nil match_kick = function(presences) -- a list of presences to remove from the match + match_label_update = function(label) + -- a new label to set for the match } Tick is the current match tick number, starts at 0 and increments after every match_loop call. Does not increment with @@ -119,6 +121,8 @@ Dispatcher exposes useful functions to the match. Format: -- a presence to tag on the message as the 'sender', or nil match_kick = function(presences) -- a list of presences to remove from the match + match_label_update = function(label) + -- a new label to set for the match } Tick is the current match tick number, starts at 0 and increments after every match_loop call. Does not increment with @@ -169,6 +173,8 @@ Dispatcher exposes useful functions to the match. Format: -- a presence to tag on the message as the 'sender', or nil match_kick = function(presences) -- a list of presences to remove from the match + match_label_update = function(label) + -- a new label to set for the match } Tick is the current match tick number, starts at 0 and increments after every match_loop call. Does not increment with @@ -219,6 +225,8 @@ Dispatcher exposes useful functions to the match. Format: -- a presence to tag on the message as the 'sender', or nil match_kick = function(presences) -- a list of presences to remove from the match + match_label_update = function(label) + -- a new label to set for the match } Tick is the current match tick number, starts at 0 and increments after every match_loop call. Does not increment with diff --git a/server/console.go b/server/console.go index b7c8e908f7..c71362014f 100644 --- a/server/console.go +++ b/server/console.go @@ -59,9 +59,9 @@ func StartConsoleServer(logger *zap.Logger, startupLogger *zap.Logger, config Co } console.RegisterConsoleServer(grpcServer, s) - startupLogger.Info("Starting Console server for gRPC requests", zap.Int("port", config.GetSocket().Port-2)) + startupLogger.Info("Starting Console server for gRPC requests", zap.Int("port", config.GetConsole().Port-3)) go func() { - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", config.GetSocket().Port-2)) + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", config.GetConsole().Port-3)) if err != nil { startupLogger.Fatal("Console server listener failed to start", zap.Error(err)) } @@ -73,7 +73,7 @@ func StartConsoleServer(logger *zap.Logger, startupLogger *zap.Logger, config Co ctx := context.Background() grpcGateway := runtime.NewServeMux() - dialAddr := fmt.Sprintf("127.0.0.1:%d", config.GetSocket().Port-2) + dialAddr := fmt.Sprintf("127.0.0.1:%d", config.GetConsole().Port-3) dialOpts := []grpc.DialOption{ //TODO (mo, zyro): Do we need to pass the statsHandler here as well? grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(int(config.GetSocket().MaxMessageSizeBytes))), diff --git a/server/match_handler.go b/server/match_handler.go index 703f655ac4..467b709ffa 100644 --- a/server/match_handler.go +++ b/server/match_handler.go @@ -25,6 +25,7 @@ import ( "github.com/pkg/errors" "github.com/satori/go.uuid" "github.com/yuin/gopher-lua" + "go.uber.org/atomic" "go.uber.org/zap" ) @@ -39,7 +40,6 @@ type MatchDataMessage struct { } type MatchHandler struct { - sync.Mutex logger *zap.Logger matchRegistry MatchRegistry tracker Tracker @@ -67,10 +67,10 @@ type MatchHandler struct { ticker *time.Ticker callCh chan func(*MatchHandler) stopCh chan struct{} - stopped bool + stopped *atomic.Bool - // Immutable configuration set by match init. - Label string + // Configuration set by match init. + Label *atomic.String Rate int // Match state. @@ -225,18 +225,19 @@ func NewMatchHandler(logger *zap.Logger, db *sql.DB, config Config, socialClient // Ticker below. callCh: make(chan func(mh *MatchHandler), config.GetMatch().CallQueueSize), stopCh: make(chan struct{}), - stopped: false, + stopped: atomic.NewBool(false), - Label: labelStr, + Label: atomic.NewString(labelStr), Rate: rateInt, state: state, } // Set up the dispatcher that exposes control functions to the match loop. - mh.dispatcher = vm.SetFuncs(vm.CreateTable(0, 2), map[string]lua.LGFunction{ - "broadcast_message": mh.broadcastMessage, - "match_kick": mh.matchKick, + mh.dispatcher = vm.SetFuncs(vm.CreateTable(0, 3), map[string]lua.LGFunction{ + "broadcast_message": mh.broadcastMessage, + "match_kick": mh.matchKick, + "match_label_update": mh.matchLabelUpdate, }) // Set up the ticker that governs the match loop. @@ -274,13 +275,9 @@ func (mh *MatchHandler) Stop() { // Used when the match is closed externally. func (mh *MatchHandler) Close() { - mh.Lock() - if mh.stopped { - mh.Unlock() + if !mh.stopped.CAS(false, true) { return } - mh.stopped = true - mh.Unlock() close(mh.stopCh) mh.ticker.Stop() } @@ -309,12 +306,9 @@ func (mh *MatchHandler) QueueData(m *MatchDataMessage) { } func loop(mh *MatchHandler) { - mh.Lock() - if mh.stopped { - mh.Unlock() + if mh.stopped.Load() { return } - mh.Unlock() // Drain the input queue into a Lua table. size := len(mh.inputCh) @@ -379,13 +373,10 @@ func loop(mh *MatchHandler) { func JoinAttempt(resultCh chan *MatchJoinResult, userID, sessionID uuid.UUID, username, node string) func(mh *MatchHandler) { return func(mh *MatchHandler) { - mh.Lock() - if mh.stopped { - mh.Unlock() + if mh.stopped.Load() { resultCh <- &MatchJoinResult{Allow: false} return } - mh.Unlock() presence := mh.vm.CreateTable(0, 4) presence.RawSetString("user_id", lua.LString(userID.String())) @@ -473,7 +464,7 @@ func JoinAttempt(resultCh chan *MatchJoinResult, userID, sessionID uuid.UUID, us mh.vm.Pop(1) mh.state = state - resultCh <- &MatchJoinResult{Allow: allow, Reason: reason, Label: mh.Label} + resultCh <- &MatchJoinResult{Allow: allow, Reason: reason, Label: mh.Label.Load()} } } @@ -483,12 +474,9 @@ func Join(joins []*MatchPresence) func(mh *MatchHandler) { return } - mh.Lock() - if mh.stopped { - mh.Unlock() + if mh.stopped.Load() { return } - mh.Unlock() presences := mh.vm.CreateTable(len(joins), 0) for i, p := range joins { @@ -539,12 +527,9 @@ func Join(joins []*MatchPresence) func(mh *MatchHandler) { func Leave(leaves []*MatchPresence) func(mh *MatchHandler) { return func(mh *MatchHandler) { - mh.Lock() - if mh.stopped { - mh.Unlock() + if mh.stopped.Load() { return } - mh.Unlock() presences := mh.vm.CreateTable(len(leaves), 0) for i, p := range leaves { @@ -814,3 +799,12 @@ func (mh *MatchHandler) matchKick(l *lua.LState) int { mh.matchRegistry.Kick(mh.Stream, presences) return 0 } + +func (mh *MatchHandler) matchLabelUpdate(l *lua.LState) int { + input := l.OptString(1, "") + + mh.Label.Store(input) + // This must be executed from inside a match call so safe to update here. + mh.ctx.RawSetString(__CTX_MATCH_LABEL, lua.LString(input)) + return 0 +} diff --git a/server/match_registry.go b/server/match_registry.go index 78bbcf5bb1..05915a86c9 100644 --- a/server/match_registry.go +++ b/server/match_registry.go @@ -201,13 +201,17 @@ func (r *LocalMatchRegistry) ListMatches(limit int, authoritative *wrappers.Bool } mh := r.GetMatch(stream.Subject) - if mh == nil || (label != nil && label.Value != mh.Label) { + if mh == nil { + continue + } + mhLabel := mh.Label.Load() + if label != nil && label.Value != mhLabel { continue } results = append(results, &api.Match{ MatchId: mh.IDStr, Authoritative: true, - Label: &wrappers.StringValue{Value: mh.Label}, + Label: &wrappers.StringValue{Value: mhLabel}, Size: size, }) if len(results) == limit { @@ -233,7 +237,8 @@ func (r *LocalMatchRegistry) ListMatches(limit int, authoritative *wrappers.Bool // Already checked and discarded this match for failing a filter, skip it. continue } - if label != nil && label.Value != mh.Label { + mhLabel := mh.Label.Load() + if label != nil && label.Value != mhLabel { // Label mismatch. continue } @@ -249,7 +254,7 @@ func (r *LocalMatchRegistry) ListMatches(limit int, authoritative *wrappers.Bool results = append(results, &api.Match{ MatchId: mh.IDStr, Authoritative: true, - Label: &wrappers.StringValue{Value: mh.Label}, + Label: &wrappers.StringValue{Value: mhLabel}, Size: size, }) if len(results) == limit { diff --git a/server/socket_ws.go b/server/socket_ws.go index d73479cfa0..e3f1ce5694 100644 --- a/server/socket_ws.go +++ b/server/socket_ws.go @@ -30,8 +30,8 @@ var SocketWsStatsCtx = context.Background() func NewSocketWsAcceptor(logger *zap.Logger, config Config, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, pipeline *Pipeline) func(http.ResponseWriter, *http.Request) { upgrader := &websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, + ReadBufferSize: int(config.GetSocket().MaxMessageSizeBytes), + WriteBufferSize: int(config.GetSocket().MaxMessageSizeBytes), CheckOrigin: func(r *http.Request) bool { return true }, } diff --git a/server/tracker.go b/server/tracker.go index 3a0d54cccd..e0c7264b8f 100644 --- a/server/tracker.go +++ b/server/tracker.go @@ -100,6 +100,8 @@ type Tracker interface { CountByStreamModeFilter(modes map[uint8]*uint8) map[*PresenceStream]int32 // Check if a single presence on the current node exists. GetLocalBySessionIDStreamUserID(sessionID uuid.UUID, stream PresenceStream, userID uuid.UUID) *PresenceMeta + // Check if a single presence on any node exists. + GetBySessionIDStreamUserID(node string, sessionID uuid.UUID, stream PresenceStream, userID uuid.UUID) *PresenceMeta // List presences by stream, optionally include hidden ones. ListByStream(stream PresenceStream, includeHidden bool) []*Presence @@ -529,6 +531,23 @@ func (t *LocalTracker) GetLocalBySessionIDStreamUserID(sessionID uuid.UUID, stre return &meta } +func (t *LocalTracker) GetBySessionIDStreamUserID(node string, sessionID uuid.UUID, stream PresenceStream, userID uuid.UUID) *PresenceMeta { + pc := presenceCompact{ID: PresenceID{Node: node, SessionID: sessionID}, Stream: stream, UserID: userID} + t.RLock() + bySession, anyTracked := t.presencesBySession[sessionID] + if !anyTracked { + // Nothing tracked for the session. + t.RUnlock() + return nil + } + meta, found := bySession[pc] + t.RUnlock() + if !found { + return nil + } + return &meta +} + func (t *LocalTracker) ListByStream(stream PresenceStream, includeHidden bool) []*Presence { t.RLock() byStream, anyTracked := t.presencesByStream[stream.Mode][stream] diff --git a/social/social.go b/social/social.go index fff5ba3f78..57b8bdd822 100644 --- a/social/social.go +++ b/social/social.go @@ -81,7 +81,7 @@ type GoogleProfile struct { Exp int64 `json:"exp"` // Fields available only if the user granted the "profile" and "email" OAuth scopes. Email string `json:"email"` - EmailVerified string `json:"email_verified"` + EmailVerified bool `json:"email_verified"` Name string `json:"name"` Picture string `json:"picture"` GivenName string `json:"given_name"` @@ -327,8 +327,17 @@ func (c *Client) CheckGoogleToken(idToken string) (*GoogleProfile, error) { } } if v, ok := claims["email_verified"]; ok { - if profile.EmailVerified, ok = v.(string); !ok { - return nil, errors.New("google id token email verified field invalid") + switch v.(type) { + case bool: + profile.EmailVerified = v.(bool) + case string: + if vb, err := strconv.ParseBool(v.(string)); err != nil { + return nil, errors.New("google id token email_verified field invalid") + } else { + profile.EmailVerified = vb + } + default: + return nil, errors.New("google id token email_verified field unknown") } } if v, ok := claims["name"]; ok {