diff --git a/tests/federation_room_join_partial_state_test.go b/tests/federation_room_join_partial_state_test.go index feab40fe..cc5b78c8 100644 --- a/tests/federation_room_join_partial_state_test.go +++ b/tests/federation_room_join_partial_state_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/gorilla/mux" "github.com/tidwall/gjson" "github.com/matrix-org/gomatrix" @@ -1962,6 +1963,449 @@ func TestPartialStateJoin(t *testing.T) { ) }) }) + + // test that: + // * remote device lists are correctly cached or not cached + // * local users are told about potential device list changes in `/sync`'s + // `device_lists.changed/left` + // * local users are told about potential device list changes in `/keys/changes`. + t.Run("Device list tracking", func(t *testing.T) { + // setupDeviceListCachingTest sets up a complement homeserver. + // A room is created on the complement server, containing only local users. + // Returns a channel for device list requests arriving at the complement homeserver, which + // can be used with `mustQueryKeysWithFederationRequest` and + // `mustQueryKeysWithoutFederationRequest`. + setupDeviceListCachingTest := func( + t *testing.T, deployment *docker.Deployment, aliceLocalpart string, + ) ( + alice *client.CSAPI, server *federation.Server, userDevicesQueryChannel chan string, + room *federation.ServerRoom, sendDeviceListUpdate func(string), cleanup func(), + ) { + alice = deployment.RegisterUser(t, "hs1", aliceLocalpart, "secret", false) + + userDevicesQueryChannel = make(chan string, 1) + + makeRespUserDeviceKeys := func( + userID string, deviceID string, + ) gomatrixserverlib.RespUserDeviceKeys { + return gomatrixserverlib.RespUserDeviceKeys{ + UserID: userID, + DeviceID: deviceID, + Algorithms: []string{ + "m.megolm.v1.aes-sha2", + }, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + "ed25519:JLAFKJWSCS": []byte("lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"), + }, + Signatures: map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + userID: { + "ed25519:JLAFKJWSCS": []byte("dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"), + }, + }, + } + } + + lastDeviceStreamID := int64(2) + server = createTestServer(t, deployment, + federation.HandleEventAuthRequests(), + func(server *federation.Server) { + server.Mux().HandleFunc("/_matrix/federation/v1/user/devices/{userID}", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + t.Logf("Incoming %s %s", req.Method, req.URL.Path) + + vars := mux.Vars(req) + userID := vars["userID"] + deviceID := fmt.Sprintf("%s_device", userID) + + userDevicesQueryChannel <- userID + + // Make up a device list for the user. + responseBytes, _ := json.Marshal(gomatrixserverlib.RespUserDevices{ + UserID: userID, + StreamID: lastDeviceStreamID, + Devices: []gomatrixserverlib.RespUserDevice{ + { + DeviceID: deviceID, + DisplayName: fmt.Sprintf("%s's device", userID), + Keys: makeRespUserDeviceKeys(userID, deviceID), + }, + }, + }) + w.WriteHeader(200) + w.Write(responseBytes) + }), + ).Methods("GET") + }, + func(server *federation.Server) { + server.Mux().HandleFunc("/_matrix/federation/v1/user/keys/query", + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + t.Logf("Incoming %s %s", req.Method, req.URL.Path) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatalf("unable to read /user/keys/query request body: %s", err) + } + + var queryKeysRequest struct { + DeviceKeys map[string][]string `json:"device_keys"` + } + if err := json.Unmarshal(body, &queryKeysRequest); err != nil { + t.Fatalf("unable to unmarshall /user/keys/query request body: %s", err) + } + + // Make up keys for every device requested. + deviceKeys := make(map[string]map[string]gomatrixserverlib.DeviceKeys) + for userID := range queryKeysRequest.DeviceKeys { + userDevicesQueryChannel <- userID + + deviceID := fmt.Sprintf("%s_device", userID) + deviceKeys[userID] = map[string]gomatrixserverlib.DeviceKeys{ + deviceID: { + RespUserDeviceKeys: makeRespUserDeviceKeys(userID, deviceID), + }, + } + } + + responseBytes, _ := json.Marshal(gomatrixserverlib.RespQueryKeys{ + DeviceKeys: deviceKeys, + }) + w.WriteHeader(200) + w.Write(responseBytes) + }), + ).Methods("POST") + }, + ) + + cancel := server.Listen() + + room = createTestRoom(t, server, alice.GetDefaultRoomVersion(t)) + + sendDeviceListUpdate = func(localpart string) { + t.Helper() + + userID := server.UserID(localpart) + deviceID := fmt.Sprintf("%s_device", userID) + + // Advance the stream ID by 2 each time, so that the homeserver under test thinks it + // has missed an update and is forced to make a federation request to request the + // updated device list. + lastDeviceStreamID += 2 + + keys, _ := json.Marshal(makeRespUserDeviceKeys(userID, deviceID)) + deviceListUpdate, _ := json.Marshal(gomatrixserverlib.DeviceListUpdateEvent{ + UserID: userID, + DeviceID: deviceID, + DeviceDisplayName: fmt.Sprintf("%s's device", userID), + StreamID: lastDeviceStreamID, + PrevID: []int64{lastDeviceStreamID - 1}, + Deleted: false, + Keys: keys, + }) + server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{}, []gomatrixserverlib.EDU{ + { + Type: "m.device_list_update", + Origin: server.ServerName(), + Destination: "hs1", + Content: deviceListUpdate, + }, + }) + } + + cleanup = func() { + cancel() + close(userDevicesQueryChannel) + } + return + } + + // mustQueryKeys makes a /keys/query request to the homeserver under test. + mustQueryKeys := func(t *testing.T, user *client.CSAPI, userID string) { + t.Helper() + + user.MustDoFunc(t, "POST", []string{"_matrix", "client", "v3", "keys", "query"}, + client.WithJSONBody(t, map[string]interface{}{ + "device_keys": map[string]interface{}{ + userID: []string{}, + }, + }), + ) + } + + // mustQueryKeysWithFederationRequest makes a /keys/query request to the homeserver under + // test and checks that the complement homeserver has received a device list request since + // the previous call to `mustQueryKeysWithFederationRequest` or + // `mustQueryKeysWithoutFederationRequest`. + // Accepts the channel for device list requests returned by `setupDeviceListCachingTest`. + mustQueryKeysWithFederationRequest := func( + t *testing.T, user *client.CSAPI, userDevicesQueryChannel chan string, userID string, + ) { + t.Helper() + + mustQueryKeys(t, user, userID) + + if len(userDevicesQueryChannel) == 0 { + t.Fatalf("%s's device list was cached when it should not be.", userID) + } + + // Empty the channel. + for len(userDevicesQueryChannel) > 0 { + <-userDevicesQueryChannel + } + } + + // mustQueryKeysWithoutFederationRequest makes a /keys/query request to the homeserver under + // test and checks that the complement homeserver has not received a device list request + // since the previous call to `mustQueryKeysWithFederationRequest` or + // `mustQueryKeysWithoutFederationRequest`. + // Accepts the channel for device list requests returned by `setupDeviceListCachingTest`. + mustQueryKeysWithoutFederationRequest := func( + t *testing.T, user *client.CSAPI, userDevicesQueryChannel chan string, userID string, + ) { + t.Helper() + + mustQueryKeys(t, user, userID) + + if len(userDevicesQueryChannel) > 0 { + t.Fatalf("%s's device list was not cached when it should have been.", userID) + } + + // Empty the channel. + for len(userDevicesQueryChannel) > 0 { + <-userDevicesQueryChannel + } + } + + // syncDeviceListsHas checks that `device_lists.changed` or `device_lists.left` contains a + // given user ID. + syncDeviceListsHas := func(section string, expectedUserID string) client.SyncCheckOpt { + jsonPath := fmt.Sprintf("device_lists.%s", section) + return func(clientUserID string, topLevelSyncJSON gjson.Result) error { + usersWithChangedDeviceListsArray := topLevelSyncJSON.Get(jsonPath).Array() + for _, userID := range usersWithChangedDeviceListsArray { + if userID.Str == expectedUserID { + return nil + } + } + return fmt.Errorf( + "syncDeviceListsHas: %s not found in %s", + expectedUserID, + jsonPath, + ) + } + } + + // mustSyncUntilDeviceListsHas syncs until `device_lists.changed` or `device_lists.left` + // contains a given user ID. + // Also tests that /keys/changes returns the same information. + mustSyncUntilDeviceListsHas := func( + t *testing.T, user *client.CSAPI, syncToken string, section string, + expectedUserID string, + ) string { + t.Helper() + + nextSyncToken := user.MustSyncUntil( + t, + client.SyncReq{ + Since: syncToken, + Filter: buildLazyLoadingSyncFilter(nil), + }, + syncDeviceListsHas(section, expectedUserID), + ) + + res := user.MustDoFunc(t, "GET", []string{"_matrix", "client", "v3", "keys", "changes"}, + client.WithQueries(url.Values{ + "from": []string{syncToken}, + "to": []string{nextSyncToken}, + }), + ) + must.MatchResponse(t, res, match.HTTPResponse{ + StatusCode: 200, + JSON: []match.JSON{ + match.JSONCheckOffAllowUnwanted( + section, + []interface{}{expectedUserID}, + func(r gjson.Result) interface{} { return r.Str }, + nil, + ), + }, + }) + return nextSyncToken + } + + // tests device list tracking for pre-existing members in a room with partial state. + // Tests that: + // * device lists are not cached for pre-existing members. + // * device list updates received while the room has partial state are sent to clients once + // fully joined. + t.Run("Device list tracking for pre-existing members in partial state room", func(t *testing.T) { + alice, server, userDevicesChannel, room, sendDeviceListUpdate, cleanup := setupDeviceListCachingTest(t, deployment, "t30alice") + defer cleanup() + + // The room starts with @charlie and @derek in it. + + // @t30alice:hs1 joins the room. + psjResult := beginPartialStateJoin(t, server, room, alice) + defer psjResult.Destroy() + + // @charlie and @derek's device list ought to not be cached. + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("charlie")) + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("derek")) + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("charlie")) + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("derek")) + + // @charlie sends a message. + // Depending on the homeserver implementation, @t30alice:hs1 may be told that @charlie's devices are being tracked. + event := psjResult.CreateMessageEvent(t, "charlie", nil) + psjResult.Server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{event.JSON()}, nil) + syncToken := awaitEventViaSync(t, alice, psjResult.ServerRoom.RoomID, event.EventID(), "") + + // @charlie updates their device list. + // Depending on the homeserver implementation, @t30alice:hs1 may or may not see the update, + // independent of what they were told about the tracking of @charlie's device list earlier. + sendDeviceListUpdate("charlie") + + // Before completing the partial state join, try to wait for the homeserver to finish processing the device list update. + event = psjResult.CreateMessageEvent(t, "charlie", nil) + psjResult.Server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{event.JSON()}, nil) + awaitEventViaSync(t, alice, psjResult.ServerRoom.RoomID, event.EventID(), syncToken) + + // Finish the partial state join. + psjResult.FinishStateRequest() + awaitPartialStateJoinCompletion(t, room, alice) + + // @charlie's device list update ought to have arrived by now. + mustSyncUntilDeviceListsHas(t, alice, syncToken, "changed", server.UserID("charlie")) + + // Cache @charlie and @derek's device lists. + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("charlie")) + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("derek")) + + // @charlie and @derek's device lists ought to be cached now. + mustQueryKeysWithoutFederationRequest(t, alice, userDevicesChannel, server.UserID("charlie")) + mustQueryKeysWithoutFederationRequest(t, alice, userDevicesChannel, server.UserID("derek")) + }) + + // test device list tracking when a pre-existing member in a room with partial state joins + // another shared room and starts being tracked for real. + t.Run("Device list tracking when pre-existing members in partial state room join another shared room", func(t *testing.T) { + alice, server, _, room, sendDeviceListUpdate, cleanup := setupDeviceListCachingTest(t, deployment, "t31alice") + defer cleanup() + + // The room starts with @charlie and @derek in it. + + // @t31alice:hs1 joins the room. + psjResult := beginPartialStateJoin(t, server, room, alice) + defer psjResult.Destroy() + + // @charlie sends a message. + // Depending on the homeserver implementation, @t31alice:hs1 may be told that @charlie's devices are being tracked. + event := psjResult.CreateMessageEvent(t, "charlie", nil) + psjResult.Server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{event.JSON()}, nil) + syncToken := awaitEventViaSync(t, alice, psjResult.ServerRoom.RoomID, event.EventID(), "") + + // @charlie updates their device list. + // Depending on the homeserver implementation, @t31alice:hs1 may or may not see the update, + // independent of what they were told about the tracking of @charlie's device list earlier. + sendDeviceListUpdate("charlie") + + // @alice:hs1 creates a public room. + otherRoomID := alice.CreateRoom(t, map[string]interface{}{"preset": "public_chat"}) + + // @charlie joins the room. + // Now @charlie's device list is definitely being tracked. + server.MustJoinRoom(t, deployment, "hs1", otherRoomID, server.UserID("charlie")) + alice.MustSyncUntil(t, + client.SyncReq{ + Since: syncToken, + Filter: buildLazyLoadingSyncFilter(nil), + }, + client.SyncJoinedTo(server.UserID("charlie"), otherRoomID), + ) + + // Depending on the homeserver implementation, @t31alice:hs1 must have been told that either: + // * charlie updated their device list, or + // * charlie's device list is being tracked now, for real. + mustSyncUntilDeviceListsHas(t, alice, syncToken, "changed", server.UserID("charlie")) + }) + + // test device list tracking for users that join after the local homeserver. + // It is expected that device list tracking works as normal for such users. + t.Run("Device list tracked for new members in partial state room", func(t *testing.T) { + alice, server, userDevicesChannel, room, sendDeviceListUpdate, cleanup := setupDeviceListCachingTest(t, deployment, "t32alice") + defer cleanup() + + // The room starts with @charlie and @derek in it. + + // @t32alice:hs1 joins the room. + psjResult := beginPartialStateJoin(t, server, room, alice) + defer psjResult.Destroy() + + syncToken := getSyncToken(t, alice) + + // @elsie joins the room. + joinEvent := createJoinEvent(t, server, room, server.UserID("elsie")) + room.AddEvent(joinEvent) + server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{joinEvent.JSON()}, nil) + awaitEventViaSync(t, alice, room.RoomID, joinEvent.EventID(), syncToken) + + // @elsie's device list ought to be cached. + syncToken = mustSyncUntilDeviceListsHas(t, alice, syncToken, "changed", server.UserID("elsie")) + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("elsie")) + mustQueryKeysWithoutFederationRequest(t, alice, userDevicesChannel, server.UserID("elsie")) + + // @elsie updates their device list. + // @t32alice:hs1 ought to be notified. + sendDeviceListUpdate("elsie") + mustSyncUntilDeviceListsHas(t, alice, syncToken, "changed", server.UserID("elsie")) + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("elsie")) + mustQueryKeysWithoutFederationRequest(t, alice, userDevicesChannel, server.UserID("elsie")) + + // Finish the partial state join. + psjResult.FinishStateRequest() + awaitPartialStateJoinCompletion(t, room, alice) + + // @elsie's device list ought to still be cached. + mustQueryKeysWithoutFederationRequest(t, alice, userDevicesChannel, server.UserID("elsie")) + }) + + // test that device lists stop being tracked when a user leaves before the partial state + // join completes. + // Similar to the previous test, except @elsie leaves before the partial state join + // completes. + t.Run("Device list no longer tracked when new member leaves partial state room", func(t *testing.T) { + alice, server, userDevicesChannel, room, _, cleanup := setupDeviceListCachingTest(t, deployment, "t33alice") + defer cleanup() + + // The room starts with @charlie and @derek in it. + + // @t33alice:hs1 joins the room. + psjResult := beginPartialStateJoin(t, server, room, alice) + defer psjResult.Destroy() + + syncToken := getSyncToken(t, alice) + + // @elsie joins the room. + joinEvent := createJoinEvent(t, server, room, server.UserID("elsie")) + room.AddEvent(joinEvent) + server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{joinEvent.JSON()}, nil) + awaitEventViaSync(t, alice, room.RoomID, joinEvent.EventID(), syncToken) + + // @elsie's device list ought to be cached. + syncToken = mustSyncUntilDeviceListsHas(t, alice, syncToken, "changed", server.UserID("elsie")) + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("elsie")) + mustQueryKeysWithoutFederationRequest(t, alice, userDevicesChannel, server.UserID("elsie")) + + // @elsie leaves the room. + leaveEvent := createLeaveEvent(t, server, room, server.UserID("elsie")) + room.AddEvent(leaveEvent) + server.MustSendTransaction(t, deployment, "hs1", []json.RawMessage{leaveEvent.JSON()}, nil) + awaitEventViaSync(t, alice, room.RoomID, leaveEvent.EventID(), syncToken) + + // @elsie's device list ought to no longer be cached. + mustSyncUntilDeviceListsHas(t, alice, syncToken, "left", server.UserID("elsie")) + mustQueryKeysWithFederationRequest(t, alice, userDevicesChannel, server.UserID("elsie")) + }) + }) } // test reception of an event over federation during a resync