Skip to content

Commit

Permalink
Add helpers for device list tracking tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Quah committed Sep 23, 2022
1 parent 4eb69d6 commit 6b26c8e
Showing 1 changed file with 270 additions and 0 deletions.
270 changes: 270 additions & 0 deletions tests/federation_room_join_partial_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"
"time"

"github.com/gorilla/mux"
"github.com/tidwall/gjson"

"github.com/matrix-org/gomatrix"
Expand Down Expand Up @@ -1489,6 +1490,275 @@ func TestPartialStateJoin(t *testing.T) {
t.Errorf("SendKnock: non-HTTPError: %v", err)
}
})

// test that:
// * remote device lists are correctly cached or not cached
// * local users are told about potential device list changes in `/sync`'s
// `device_lists.changed/left`
// * local users are told about potential device list changes in `/keys/changes`.
t.Run("Device list tracking", func(t *testing.T) {
// setupDeviceListCachingTest sets up a complement homeserver.
// A room is created on the complement server, containing only local users.
// Returns a channel for device list requests arriving at the complement homeserver, which
// can be used with `mustQueryKeysWithFederationRequest` and
// `mustQueryKeysWithoutFederationRequest`.
setupDeviceListCachingTest := func(
t *testing.T, deployment *docker.Deployment, aliceLocalpart string,
) (
alice *client.CSAPI, server *federation.Server, userDevicesQueryChannel chan string,
room *federation.ServerRoom, sendDeviceListUpdate func(string), cleanup func(),
) {
alice = deployment.RegisterUser(t, "hs1", aliceLocalpart, "secret", false)

userDevicesQueryChannel = make(chan string, 1)

makeRespUserDeviceKeys := func(
userID string, deviceID string,
) gomatrixserverlib.RespUserDeviceKeys {
return gomatrixserverlib.RespUserDeviceKeys{
UserID: userID,
DeviceID: deviceID,
Algorithms: []string{
"m.megolm.v1.aes-sha2",
},
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
"ed25519:JLAFKJWSCS": []byte("lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"),
},
Signatures: map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
userID: {
"ed25519:JLAFKJWSCS": []byte("dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"),
},
},
}
}

lastDeviceStreamID := int64(2)
server = createTestServer(t, deployment,
federation.HandleEventAuthRequests(),
func(server *federation.Server) {
server.Mux().HandleFunc("/_matrix/federation/v1/user/devices/{userID}",
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
t.Logf("Incoming %s %s", req.Method, req.URL.Path)

vars := mux.Vars(req)
userID := vars["userID"]
deviceID := fmt.Sprintf("%s_device", userID)

userDevicesQueryChannel <- userID

// Make up a device list for the user.
responseBytes, _ := json.Marshal(gomatrixserverlib.RespUserDevices{
UserID: userID,
StreamID: lastDeviceStreamID,
Devices: []gomatrixserverlib.RespUserDevice{
{
DeviceID: deviceID,
DisplayName: fmt.Sprintf("%s's device", userID),
Keys: makeRespUserDeviceKeys(userID, deviceID),
},
},
})
w.WriteHeader(200)
w.Write(responseBytes)
}),
).Methods("GET")
},
func(server *federation.Server) {
server.Mux().HandleFunc("/_matrix/federation/v1/user/keys/query",
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
t.Logf("Incoming %s %s", req.Method, req.URL.Path)

body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatalf("unable to read /user/keys/query request body: %s", err)
}

var queryKeysRequest struct {
DeviceKeys map[string][]string `json:"device_keys"`
}
if err := json.Unmarshal(body, &queryKeysRequest); err != nil {
t.Fatalf("unable to unmarshall /user/keys/query request body: %s", err)
}

// Make up keys for every device requested.
deviceKeys := make(map[string]map[string]gomatrixserverlib.DeviceKeys)
for userID := range queryKeysRequest.DeviceKeys {
userDevicesQueryChannel <- userID

deviceID := fmt.Sprintf("%s_device", userID)
deviceKeys[userID] = map[string]gomatrixserverlib.DeviceKeys{
deviceID: {
RespUserDeviceKeys: makeRespUserDeviceKeys(userID, deviceID),
},
}
}

responseBytes, _ := json.Marshal(gomatrixserverlib.RespQueryKeys{
DeviceKeys: deviceKeys,
})
w.WriteHeader(200)
w.Write(responseBytes)
}),
).Methods("POST")
},
)

cancel := server.Listen()

room = createTestRoom(t, server, alice.GetDefaultRoomVersion(t))

sendDeviceListUpdate = func(localpart string) {
t.Helper()

userID := server.UserID(localpart)
deviceID := fmt.Sprintf("%s_device", userID)

// Advance the stream ID by 2 each time, so that the homeserver under test thinks it
// has missed an update and is forced to make a federation request to request the
// updated device list.
lastDeviceStreamID += 2

keys, _ := json.Marshal(makeRespUserDeviceKeys(userID, deviceID))
deviceListUpdate, _ := json.Marshal(gomatrixserverlib.DeviceListUpdateEvent{
UserID: userID,
DeviceID: deviceID,
DeviceDisplayName: fmt.Sprintf("%s's device", userID),
StreamID: lastDeviceStreamID,
PrevID: []int64{lastDeviceStreamID - 1},
Deleted: false,
Keys: keys,
})
server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{}, []gomatrixserverlib.EDU{
{
Type: "m.device_list_update",
Origin: server.ServerName(),
Destination: "hs1",
Content: deviceListUpdate,
},
})
}

cleanup = func() {
cancel()
close(userDevicesQueryChannel)
}
return
}

// mustQueryKeys makes a /keys/query request to the homeserver under test.
mustQueryKeys := func(t *testing.T, user *client.CSAPI, userID string) {
t.Helper()

user.MustDoFunc(t, "POST", []string{"_matrix", "client", "v3", "keys", "query"},
client.WithJSONBody(t, map[string]interface{}{
"device_keys": map[string]interface{}{
userID: []string{},
},
}),
)
}

// mustQueryKeysWithFederationRequest makes a /keys/query request to the homeserver under
// test and checks that the complement homeserver has received a device list request since
// the previous call to `mustQueryKeysWithFederationRequest` or
// `mustQueryKeysWithoutFederationRequest`.
// Accepts the channel for device list requests returned by `setupDeviceListCachingTest`.
mustQueryKeysWithFederationRequest := func(
t *testing.T, user *client.CSAPI, userDevicesQueryChannel chan string, userID string,
) {
t.Helper()

mustQueryKeys(t, user, userID)

if len(userDevicesQueryChannel) == 0 {
t.Fatalf("%s's device list was cached when it should not be.", userID)
}

// Empty the channel.
for len(userDevicesQueryChannel) > 0 {
<-userDevicesQueryChannel
}
}

// mustQueryKeysWithoutFederationRequest makes a /keys/query request to the homeserver under
// test and checks that the complement homeserver has not received a device list request
// since the previous call to `mustQueryKeysWithFederationRequest` or
// `mustQueryKeysWithoutFederationRequest`.
// Accepts the channel for device list requests returned by `setupDeviceListCachingTest`.
mustQueryKeysWithoutFederationRequest := func(
t *testing.T, user *client.CSAPI, userDevicesQueryChannel chan string, userID string,
) {
t.Helper()

mustQueryKeys(t, user, userID)

if len(userDevicesQueryChannel) > 0 {
t.Fatalf("%s's device list was not cached when it should have been.", userID)
}

// Empty the channel.
for len(userDevicesQueryChannel) > 0 {
<-userDevicesQueryChannel
}
}

// syncDeviceListsHas checks that `device_lists.changed` or `device_lists.left` contains a
// given user ID.
syncDeviceListsHas := func(section string, expectedUserID string) client.SyncCheckOpt {
jsonPath := fmt.Sprintf("device_lists.%s", section)
return func(clientUserID string, topLevelSyncJSON gjson.Result) error {
usersWithChangedDeviceListsArray := topLevelSyncJSON.Get(jsonPath).Array()
for _, userID := range usersWithChangedDeviceListsArray {
if userID.Str == expectedUserID {
return nil
}
}
return fmt.Errorf(
"syncDeviceListsHas: %s not found in %s",
expectedUserID,
jsonPath,
)
}
}

// mustSyncUntilDeviceListsHas syncs until `device_lists.changed` or `device_lists.left`
// contains a given user ID.
// Also tests that /keys/changes returns the same information.
mustSyncUntilDeviceListsHas := func(
t *testing.T, user *client.CSAPI, syncToken string, section string,
expectedUserID string,
) string {
t.Helper()

nextSyncToken := user.MustSyncUntil(
t,
client.SyncReq{
Since: syncToken,
Filter: buildLazyLoadingSyncFilter(nil),
},
syncDeviceListsHas(section, expectedUserID),
)

res := user.MustDoFunc(t, "GET", []string{"_matrix", "client", "v3", "keys", "changes"},
client.WithQueries(url.Values{
"from": []string{syncToken},
"to": []string{nextSyncToken},
}),
)
must.MatchResponse(t, res, match.HTTPResponse{
StatusCode: 200,
JSON: []match.JSON{
match.JSONCheckOffAllowUnwanted(
section,
[]interface{}{expectedUserID},
func(r gjson.Result) interface{} { return r.Str },
nil,
),
},
})
return nextSyncToken
}
})
}

// test reception of an event over federation during a resync
Expand Down

0 comments on commit 6b26c8e

Please sign in to comment.