Skip to content

Commit

Permalink
Allow matches to update their labels.
Browse files Browse the repository at this point in the history
  • Loading branch information
zyro committed Jul 9, 2018
1 parent d5541b9 commit aae6914
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 43 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ data/*
install/cloud/**/*.json
install/cloud/**/*.tfvars

*.pprof

### Go ###
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ The format is based on [keep a changelog](http://keepachangelog.com) and this pr
- Authoritative match modules now allow a `match_join` callback that triggers when users have completed their join process.
- New stream API function to upsert a user presence.
- Extended validation of Google tokens to account for different token payloads.
- Authoritative match labels can now be updated using the dispatcher's `match_label_update` function.

### Changed
- Presence list in match join responses no longer contains the user's own presence.
- Presence list in channel join responses no longer contains the user's own presence.
- Socket read/write buffer sizes are now set based on the `socket.max_message_size_bytes` config value.
- Console GRPC port now set relative to `console.port` config value.

## [2.0.1] - 2018-06-15
### Added
Expand Down
8 changes: 8 additions & 0 deletions data/modules/match.lua
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ Dispatcher exposes useful functions to the match. Format:
-- a presence to tag on the message as the 'sender', or nil
match_kick = function(presences)
-- a list of presences to remove from the match
match_label_update = function(label)
-- a new label to set for the match
}
Tick is the current match tick number, starts at 0 and increments after every match_loop call. Does not increment with
Expand Down Expand Up @@ -119,6 +121,8 @@ Dispatcher exposes useful functions to the match. Format:
-- a presence to tag on the message as the 'sender', or nil
match_kick = function(presences)
-- a list of presences to remove from the match
match_label_update = function(label)
-- a new label to set for the match
}
Tick is the current match tick number, starts at 0 and increments after every match_loop call. Does not increment with
Expand Down Expand Up @@ -169,6 +173,8 @@ Dispatcher exposes useful functions to the match. Format:
-- a presence to tag on the message as the 'sender', or nil
match_kick = function(presences)
-- a list of presences to remove from the match
match_label_update = function(label)
-- a new label to set for the match
}
Tick is the current match tick number, starts at 0 and increments after every match_loop call. Does not increment with
Expand Down Expand Up @@ -219,6 +225,8 @@ Dispatcher exposes useful functions to the match. Format:
-- a presence to tag on the message as the 'sender', or nil
match_kick = function(presences)
-- a list of presences to remove from the match
match_label_update = function(label)
-- a new label to set for the match
}
Tick is the current match tick number, starts at 0 and increments after every match_loop call. Does not increment with
Expand Down
6 changes: 3 additions & 3 deletions server/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ func StartConsoleServer(logger *zap.Logger, startupLogger *zap.Logger, config Co
}

console.RegisterConsoleServer(grpcServer, s)
startupLogger.Info("Starting Console server for gRPC requests", zap.Int("port", config.GetSocket().Port-2))
startupLogger.Info("Starting Console server for gRPC requests", zap.Int("port", config.GetConsole().Port-3))
go func() {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", config.GetSocket().Port-2))
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", config.GetConsole().Port-3))
if err != nil {
startupLogger.Fatal("Console server listener failed to start", zap.Error(err))
}
Expand All @@ -73,7 +73,7 @@ func StartConsoleServer(logger *zap.Logger, startupLogger *zap.Logger, config Co

ctx := context.Background()
grpcGateway := runtime.NewServeMux()
dialAddr := fmt.Sprintf("127.0.0.1:%d", config.GetSocket().Port-2)
dialAddr := fmt.Sprintf("127.0.0.1:%d", config.GetConsole().Port-3)
dialOpts := []grpc.DialOption{
//TODO (mo, zyro): Do we need to pass the statsHandler here as well?
grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(int(config.GetSocket().MaxMessageSizeBytes))),
Expand Down
56 changes: 25 additions & 31 deletions server/match_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/pkg/errors"
"github.com/satori/go.uuid"
"github.com/yuin/gopher-lua"
"go.uber.org/atomic"
"go.uber.org/zap"
)

Expand All @@ -39,7 +40,6 @@ type MatchDataMessage struct {
}

type MatchHandler struct {
sync.Mutex
logger *zap.Logger
matchRegistry MatchRegistry
tracker Tracker
Expand Down Expand Up @@ -67,10 +67,10 @@ type MatchHandler struct {
ticker *time.Ticker
callCh chan func(*MatchHandler)
stopCh chan struct{}
stopped bool
stopped *atomic.Bool

// Immutable configuration set by match init.
Label string
// Configuration set by match init.
Label *atomic.String
Rate int

// Match state.
Expand Down Expand Up @@ -225,18 +225,19 @@ func NewMatchHandler(logger *zap.Logger, db *sql.DB, config Config, socialClient
// Ticker below.
callCh: make(chan func(mh *MatchHandler), config.GetMatch().CallQueueSize),
stopCh: make(chan struct{}),
stopped: false,
stopped: atomic.NewBool(false),

Label: labelStr,
Label: atomic.NewString(labelStr),
Rate: rateInt,

state: state,
}

// Set up the dispatcher that exposes control functions to the match loop.
mh.dispatcher = vm.SetFuncs(vm.CreateTable(0, 2), map[string]lua.LGFunction{
"broadcast_message": mh.broadcastMessage,
"match_kick": mh.matchKick,
mh.dispatcher = vm.SetFuncs(vm.CreateTable(0, 3), map[string]lua.LGFunction{
"broadcast_message": mh.broadcastMessage,
"match_kick": mh.matchKick,
"match_label_update": mh.matchLabelUpdate,
})

// Set up the ticker that governs the match loop.
Expand Down Expand Up @@ -274,13 +275,9 @@ func (mh *MatchHandler) Stop() {

// Used when the match is closed externally.
func (mh *MatchHandler) Close() {
mh.Lock()
if mh.stopped {
mh.Unlock()
if !mh.stopped.CAS(false, true) {
return
}
mh.stopped = true
mh.Unlock()
close(mh.stopCh)
mh.ticker.Stop()
}
Expand Down Expand Up @@ -309,12 +306,9 @@ func (mh *MatchHandler) QueueData(m *MatchDataMessage) {
}

func loop(mh *MatchHandler) {
mh.Lock()
if mh.stopped {
mh.Unlock()
if mh.stopped.Load() {
return
}
mh.Unlock()

// Drain the input queue into a Lua table.
size := len(mh.inputCh)
Expand Down Expand Up @@ -379,13 +373,10 @@ func loop(mh *MatchHandler) {

func JoinAttempt(resultCh chan *MatchJoinResult, userID, sessionID uuid.UUID, username, node string) func(mh *MatchHandler) {
return func(mh *MatchHandler) {
mh.Lock()
if mh.stopped {
mh.Unlock()
if mh.stopped.Load() {
resultCh <- &MatchJoinResult{Allow: false}
return
}
mh.Unlock()

presence := mh.vm.CreateTable(0, 4)
presence.RawSetString("user_id", lua.LString(userID.String()))
Expand Down Expand Up @@ -473,7 +464,7 @@ func JoinAttempt(resultCh chan *MatchJoinResult, userID, sessionID uuid.UUID, us
mh.vm.Pop(1)

mh.state = state
resultCh <- &MatchJoinResult{Allow: allow, Reason: reason, Label: mh.Label}
resultCh <- &MatchJoinResult{Allow: allow, Reason: reason, Label: mh.Label.Load()}
}
}

Expand All @@ -483,12 +474,9 @@ func Join(joins []*MatchPresence) func(mh *MatchHandler) {
return
}

mh.Lock()
if mh.stopped {
mh.Unlock()
if mh.stopped.Load() {
return
}
mh.Unlock()

presences := mh.vm.CreateTable(len(joins), 0)
for i, p := range joins {
Expand Down Expand Up @@ -539,12 +527,9 @@ func Join(joins []*MatchPresence) func(mh *MatchHandler) {

func Leave(leaves []*MatchPresence) func(mh *MatchHandler) {
return func(mh *MatchHandler) {
mh.Lock()
if mh.stopped {
mh.Unlock()
if mh.stopped.Load() {
return
}
mh.Unlock()

presences := mh.vm.CreateTable(len(leaves), 0)
for i, p := range leaves {
Expand Down Expand Up @@ -814,3 +799,12 @@ func (mh *MatchHandler) matchKick(l *lua.LState) int {
mh.matchRegistry.Kick(mh.Stream, presences)
return 0
}

func (mh *MatchHandler) matchLabelUpdate(l *lua.LState) int {
input := l.OptString(1, "")

mh.Label.Store(input)
// This must be executed from inside a match call so safe to update here.
mh.ctx.RawSetString(__CTX_MATCH_LABEL, lua.LString(input))
return 0
}
13 changes: 9 additions & 4 deletions server/match_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,17 @@ func (r *LocalMatchRegistry) ListMatches(limit int, authoritative *wrappers.Bool
}

mh := r.GetMatch(stream.Subject)
if mh == nil || (label != nil && label.Value != mh.Label) {
if mh == nil {
continue
}
mhLabel := mh.Label.Load()
if label != nil && label.Value != mhLabel {
continue
}
results = append(results, &api.Match{
MatchId: mh.IDStr,
Authoritative: true,
Label: &wrappers.StringValue{Value: mh.Label},
Label: &wrappers.StringValue{Value: mhLabel},
Size: size,
})
if len(results) == limit {
Expand All @@ -233,7 +237,8 @@ func (r *LocalMatchRegistry) ListMatches(limit int, authoritative *wrappers.Bool
// Already checked and discarded this match for failing a filter, skip it.
continue
}
if label != nil && label.Value != mh.Label {
mhLabel := mh.Label.Load()
if label != nil && label.Value != mhLabel {
// Label mismatch.
continue
}
Expand All @@ -249,7 +254,7 @@ func (r *LocalMatchRegistry) ListMatches(limit int, authoritative *wrappers.Bool
results = append(results, &api.Match{
MatchId: mh.IDStr,
Authoritative: true,
Label: &wrappers.StringValue{Value: mh.Label},
Label: &wrappers.StringValue{Value: mhLabel},
Size: size,
})
if len(results) == limit {
Expand Down
4 changes: 2 additions & 2 deletions server/socket_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ var SocketWsStatsCtx = context.Background()

func NewSocketWsAcceptor(logger *zap.Logger, config Config, sessionRegistry *SessionRegistry, matchmaker Matchmaker, tracker Tracker, jsonpbMarshaler *jsonpb.Marshaler, jsonpbUnmarshaler *jsonpb.Unmarshaler, pipeline *Pipeline) func(http.ResponseWriter, *http.Request) {
upgrader := &websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
ReadBufferSize: int(config.GetSocket().MaxMessageSizeBytes),
WriteBufferSize: int(config.GetSocket().MaxMessageSizeBytes),
CheckOrigin: func(r *http.Request) bool { return true },
}

Expand Down
19 changes: 19 additions & 0 deletions server/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ type Tracker interface {
CountByStreamModeFilter(modes map[uint8]*uint8) map[*PresenceStream]int32
// Check if a single presence on the current node exists.
GetLocalBySessionIDStreamUserID(sessionID uuid.UUID, stream PresenceStream, userID uuid.UUID) *PresenceMeta
// Check if a single presence on any node exists.
GetBySessionIDStreamUserID(node string, sessionID uuid.UUID, stream PresenceStream, userID uuid.UUID) *PresenceMeta
// List presences by stream, optionally include hidden ones.
ListByStream(stream PresenceStream, includeHidden bool) []*Presence

Expand Down Expand Up @@ -529,6 +531,23 @@ func (t *LocalTracker) GetLocalBySessionIDStreamUserID(sessionID uuid.UUID, stre
return &meta
}

func (t *LocalTracker) GetBySessionIDStreamUserID(node string, sessionID uuid.UUID, stream PresenceStream, userID uuid.UUID) *PresenceMeta {
pc := presenceCompact{ID: PresenceID{Node: node, SessionID: sessionID}, Stream: stream, UserID: userID}
t.RLock()
bySession, anyTracked := t.presencesBySession[sessionID]
if !anyTracked {
// Nothing tracked for the session.
t.RUnlock()
return nil
}
meta, found := bySession[pc]
t.RUnlock()
if !found {
return nil
}
return &meta
}

func (t *LocalTracker) ListByStream(stream PresenceStream, includeHidden bool) []*Presence {
t.RLock()
byStream, anyTracked := t.presencesByStream[stream.Mode][stream]
Expand Down
15 changes: 12 additions & 3 deletions social/social.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ type GoogleProfile struct {
Exp int64 `json:"exp"`
// Fields available only if the user granted the "profile" and "email" OAuth scopes.
Email string `json:"email"`
EmailVerified string `json:"email_verified"`
EmailVerified bool `json:"email_verified"`
Name string `json:"name"`
Picture string `json:"picture"`
GivenName string `json:"given_name"`
Expand Down Expand Up @@ -327,8 +327,17 @@ func (c *Client) CheckGoogleToken(idToken string) (*GoogleProfile, error) {
}
}
if v, ok := claims["email_verified"]; ok {
if profile.EmailVerified, ok = v.(string); !ok {
return nil, errors.New("google id token email verified field invalid")
switch v.(type) {
case bool:
profile.EmailVerified = v.(bool)
case string:
if vb, err := strconv.ParseBool(v.(string)); err != nil {
return nil, errors.New("google id token email_verified field invalid")
} else {
profile.EmailVerified = vb
}
default:
return nil, errors.New("google id token email_verified field unknown")
}
}
if v, ok := claims["name"]; ok {
Expand Down

0 comments on commit aae6914

Please sign in to comment.