diff --git a/tests/federation_room_join_partial_state_test.go b/tests/federation_room_join_partial_state_test.go index 0be7192a..56930190 100644 --- a/tests/federation_room_join_partial_state_test.go +++ b/tests/federation_room_join_partial_state_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/gorilla/mux" "github.com/tidwall/gjson" "github.com/matrix-org/gomatrix" @@ -1489,6 +1490,275 @@ func TestPartialStateJoin(t *testing.T) { t.Errorf("SendKnock: non-HTTPError: %v", err) } }) + + // test that: + // * remote device lists are correctly cached or not cached + // * local users are told about potential device list changes in `/sync`'s + // `device_lists.changed/left` + // * local users are told about potential device list changes in `/keys/changes`. + t.Run("Device list tracking", func(t *testing.T) { + // setupDeviceListCachingTest sets up a complement homeserver. + // A room is created on the complement server, containing only local users. + // Returns a channel for device list requests arriving at the complement homeserver, which + // can be used with `mustQueryKeysWithFederationRequest` and + // `mustQueryKeysWithoutFederationRequest`. + setupDeviceListCachingTest := func( + t *testing.T, deployment *docker.Deployment, aliceLocalpart string, + ) ( + alice *client.CSAPI, server *federation.Server, userDevicesQueryChannel chan string, + room *federation.ServerRoom, sendDeviceListUpdate func(string), cleanup func(), + ) { + alice = deployment.RegisterUser(t, "hs1", aliceLocalpart, "secret", false) + + userDevicesQueryChannel = make(chan string, 1) + + makeRespUserDeviceKeys := func( + userID string, deviceID string, + ) gomatrixserverlib.RespUserDeviceKeys { + return gomatrixserverlib.RespUserDeviceKeys{ + UserID: userID, + DeviceID: deviceID, + Algorithms: []string{ + "m.megolm.v1.aes-sha2", + }, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + "ed25519:JLAFKJWSCS": []byte("lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"), + }, + Signatures: map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + userID: { + "ed25519:JLAFKJWSCS": []byte("dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"), + }, + }, + } + } + + lastDeviceStreamID := int64(2) + server = createTestServer(t, deployment, + federation.HandleEventAuthRequests(), + func(server *federation.Server) { + server.Mux().HandleFunc("/_matrix/federation/v1/user/devices/{userID}", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + t.Logf("Incoming %s %s", req.Method, req.URL.Path) + + vars := mux.Vars(req) + userID := vars["userID"] + deviceID := fmt.Sprintf("%s_device", userID) + + userDevicesQueryChannel <- userID + + // Make up a device list for the user. + responseBytes, _ := json.Marshal(gomatrixserverlib.RespUserDevices{ + UserID: userID, + StreamID: lastDeviceStreamID, + Devices: []gomatrixserverlib.RespUserDevice{ + { + DeviceID: deviceID, + DisplayName: fmt.Sprintf("%s's device", userID), + Keys: makeRespUserDeviceKeys(userID, deviceID), + }, + }, + }) + w.WriteHeader(200) + w.Write(responseBytes) + }), + ).Methods("GET") + }, + func(server *federation.Server) { + server.Mux().HandleFunc("/_matrix/federation/v1/user/keys/query", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + t.Logf("Incoming %s %s", req.Method, req.URL.Path) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatalf("unable to read /user/keys/query request body: %s", err) + } + + var queryKeysRequest struct { + DeviceKeys map[string][]string `json:"device_keys"` + } + if err := json.Unmarshal(body, &queryKeysRequest); err != nil { + t.Fatalf("unable to unmarshall /user/keys/query request body: %s", err) + } + + // Make up keys for every device requested. + deviceKeys := make(map[string]map[string]gomatrixserverlib.DeviceKeys) + for userID := range queryKeysRequest.DeviceKeys { + userDevicesQueryChannel <- userID + + deviceID := fmt.Sprintf("%s_device", userID) + deviceKeys[userID] = map[string]gomatrixserverlib.DeviceKeys{ + deviceID: { + RespUserDeviceKeys: makeRespUserDeviceKeys(userID, deviceID), + }, + } + } + + responseBytes, _ := json.Marshal(gomatrixserverlib.RespQueryKeys{ + DeviceKeys: deviceKeys, + }) + w.WriteHeader(200) + w.Write(responseBytes) + }), + ).Methods("POST") + }, + ) + + cancel := server.Listen() + + room = createTestRoom(t, server, alice.GetDefaultRoomVersion(t)) + + sendDeviceListUpdate = func(localpart string) { + t.Helper() + + userID := server.UserID(localpart) + deviceID := fmt.Sprintf("%s_device", userID) + + // Advance the stream ID by 2 each time, so that the homeserver under test thinks it + // has missed an update and is forced to make a federation request to request the + // updated device list. + lastDeviceStreamID += 2 + + keys, _ := json.Marshal(makeRespUserDeviceKeys(userID, deviceID)) + deviceListUpdate, _ := json.Marshal(gomatrixserverlib.DeviceListUpdateEvent{ + UserID: userID, + DeviceID: deviceID, + DeviceDisplayName: fmt.Sprintf("%s's device", userID), + StreamID: lastDeviceStreamID, + PrevID: []int64{lastDeviceStreamID - 1}, + Deleted: false, + Keys: keys, + }) + server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{}, []gomatrixserverlib.EDU{ + { + Type: "m.device_list_update", + Origin: server.ServerName(), + Destination: "hs1", + Content: deviceListUpdate, + }, + }) + } + + cleanup = func() { + cancel() + close(userDevicesQueryChannel) + } + return + } + + // mustQueryKeys makes a /keys/query request to the homeserver under test. + mustQueryKeys := func(t *testing.T, user *client.CSAPI, userID string) { + t.Helper() + + user.MustDoFunc(t, "POST", []string{"_matrix", "client", "v3", "keys", "query"}, + client.WithJSONBody(t, map[string]interface{}{ + "device_keys": map[string]interface{}{ + userID: []string{}, + }, + }), + ) + } + + // mustQueryKeysWithFederationRequest makes a /keys/query request to the homeserver under + // test and checks that the complement homeserver has received a device list request since + // the previous call to `mustQueryKeysWithFederationRequest` or + // `mustQueryKeysWithoutFederationRequest`. + // Accepts the channel for device list requests returned by `setupDeviceListCachingTest`. + mustQueryKeysWithFederationRequest := func( + t *testing.T, user *client.CSAPI, userDevicesQueryChannel chan string, userID string, + ) { + t.Helper() + + mustQueryKeys(t, user, userID) + + if len(userDevicesQueryChannel) == 0 { + t.Fatalf("%s's device list was cached when it should not be.", userID) + } + + // Empty the channel. + for len(userDevicesQueryChannel) > 0 { + <-userDevicesQueryChannel + } + } + + // mustQueryKeysWithoutFederationRequest makes a /keys/query request to the homeserver under + // test and checks that the complement homeserver has not received a device list request + // since the previous call to `mustQueryKeysWithFederationRequest` or + // `mustQueryKeysWithoutFederationRequest`. + // Accepts the channel for device list requests returned by `setupDeviceListCachingTest`. + mustQueryKeysWithoutFederationRequest := func( + t *testing.T, user *client.CSAPI, userDevicesQueryChannel chan string, userID string, + ) { + t.Helper() + + mustQueryKeys(t, user, userID) + + if len(userDevicesQueryChannel) > 0 { + t.Fatalf("%s's device list was not cached when it should have been.", userID) + } + + // Empty the channel. + for len(userDevicesQueryChannel) > 0 { + <-userDevicesQueryChannel + } + } + + // syncDeviceListsHas checks that `device_lists.changed` or `device_lists.left` contains a + // given user ID. + syncDeviceListsHas := func(section string, expectedUserID string) client.SyncCheckOpt { + jsonPath := fmt.Sprintf("device_lists.%s", section) + return func(clientUserID string, topLevelSyncJSON gjson.Result) error { + usersWithChangedDeviceListsArray := topLevelSyncJSON.Get(jsonPath).Array() + for _, userID := range usersWithChangedDeviceListsArray { + if userID.Str == expectedUserID { + return nil + } + } + return fmt.Errorf( + "syncDeviceListsHas: %s not found in %s", + expectedUserID, + jsonPath, + ) + } + } + + // mustSyncUntilDeviceListsHas syncs until `device_lists.changed` or `device_lists.left` + // contains a given user ID. + // Also tests that /keys/changes returns the same information. + mustSyncUntilDeviceListsHas := func( + t *testing.T, user *client.CSAPI, syncToken string, section string, + expectedUserID string, + ) string { + t.Helper() + + nextSyncToken := user.MustSyncUntil( + t, + client.SyncReq{ + Since: syncToken, + Filter: buildLazyLoadingSyncFilter(nil), + }, + syncDeviceListsHas(section, expectedUserID), + ) + + res := user.MustDoFunc(t, "GET", []string{"_matrix", "client", "v3", "keys", "changes"}, + client.WithQueries(url.Values{ + "from": []string{syncToken}, + "to": []string{nextSyncToken}, + }), + ) + must.MatchResponse(t, res, match.HTTPResponse{ + StatusCode: 200, + JSON: []match.JSON{ + match.JSONCheckOffAllowUnwanted( + section, + []interface{}{expectedUserID}, + func(r gjson.Result) interface{} { return r.Str }, + nil, + ), + }, + }) + return nextSyncToken + } + }) } // test reception of an event over federation during a resync