Skip to content
This repository has been archived by the owner on Nov 14, 2024. It is now read-only.

Fix RewritesState bug #1557

Merged
merged 10 commits into from
Oct 22, 2020
6 changes: 6 additions & 0 deletions federationsender/consumers/roomserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
case api.OutputTypeNewRoomEvent:
ev := &output.NewRoomEvent.Event

if output.NewRoomEvent.RewritesState {
if err := s.db.PurgeRoomState(context.TODO(), ev.RoomID()); err != nil {
return fmt.Errorf("s.db.PurgeRoom: %w", err)
}
}

if err := s.processMessage(*output.NewRoomEvent); err != nil {
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
Expand Down
1 change: 1 addition & 0 deletions federationsender/storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Database interface {
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
PurgeRoomState(ctx context.Context, roomID string) error

StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

Expand Down
15 changes: 15 additions & 0 deletions federationsender/storage/postgres/joined_hosts_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ const insertJoinedHostsSQL = "" +
const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = ANY($1)"

const deleteJoinedHostsForRoomSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"

const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1"
Expand All @@ -67,6 +70,7 @@ type joinedHostsStatements struct {
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
deleteJoinedHostsForRoomStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt
Expand All @@ -86,6 +90,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil {
return
}
Expand Down Expand Up @@ -117,6 +124,14 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
return err
}

// DeleteJoinedHostsForRoom executes the prepared deleteJoinedHostsForRoomSQL
// statement, which deletes every federationsender_joined_hosts row whose
// room_id equals roomID. The statement is passed through sqlutil.TxStmt
// together with txn before it is executed. Any error from the execution is
// returned unchanged.
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
	ctx context.Context, txn *sql.Tx, roomID string,
) error {
	deleteStmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
	if _, execErr := deleteStmt.ExecContext(ctx, roomID); execErr != nil {
		return execErr
	}
	return nil
}

func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
Expand Down
14 changes: 14 additions & 0 deletions federationsender/storage/shared/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,20 @@ func (d *Database) StoreJSON(
}, nil
}

// PurgeRoomState removes all joined-host rows stored for roomID. The delete
// runs inside the function handed to d.Writer.Do, so it shares that writer's
// transaction handling.
//
// It returns a wrapped error if the underlying delete fails, and nil
// otherwise.
func (d *Database) PurgeRoomState(
	ctx context.Context, roomID string,
) error {
	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		// This function does not inspect any event: it unconditionally drops
		// the joined hosts recorded for the room. Callers are expected to
		// invoke it only when the room state is about to be rewritten in
		// full, so that the joined hosts get rebuilt from the new state.
		if err := d.FederationSenderJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
			// Name the method that was actually called so the error chain
			// points at the right place.
			return fmt.Errorf("d.FederationSenderJoinedHosts.DeleteJoinedHostsForRoom: %w", err)
		}
		return nil
	})
}

func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
Expand Down
25 changes: 20 additions & 5 deletions federationsender/storage/sqlite3/joined_hosts_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ const insertJoinedHostsSQL = "" +
const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"

const deleteJoinedHostsForRoomSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"

const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1"
Expand All @@ -64,11 +67,12 @@ const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"

type joinedHostsStatements struct {
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
deleteJoinedHostsForRoomStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
}

Expand All @@ -86,6 +90,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error)
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return
}
Expand Down Expand Up @@ -118,6 +125,14 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
return nil
}

// DeleteJoinedHostsForRoom executes the prepared deleteJoinedHostsForRoomSQL
// statement, which deletes every federationsender_joined_hosts row whose
// room_id equals roomID. The statement is passed through sqlutil.TxStmt
// together with txn before it is executed. Any error from the execution is
// returned unchanged.
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
	ctx context.Context, txn *sql.Tx, roomID string,
) error {
	deleteStmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
	if _, execErr := deleteStmt.ExecContext(ctx, roomID); execErr != nil {
		return execErr
	}
	return nil
}

func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
Expand Down
1 change: 1 addition & 0 deletions federationsender/storage/tables/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type FederationSenderQueueJSON interface {
type FederationSenderJoinedHosts interface {
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error
DeleteJoinedHostsForRoom(ctx context.Context, txn *sql.Tx, roomID string) error
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
Expand Down
12 changes: 3 additions & 9 deletions roomserver/internal/input/input_latest_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ type latestEventsUpdater struct {

func (u *latestEventsUpdater) doUpdateLatestEvents() error {
u.lastEventIDSent = u.updater.LastEventIDSent()
u.oldStateNID = u.updater.CurrentStateSnapshotNID()

// If we are doing a regular event update then we will get the
// previous latest events to use as a part of the calculation. If
Expand All @@ -125,7 +124,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// then start with an empty set - none of the forward extremities
// that we knew about before matter anymore.
oldLatest := []types.StateAtEventAndReference{}
if !u.stateAtEvent.Overwrite {
if !u.rewritesState {
u.oldStateNID = u.updater.CurrentStateSnapshotNID()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whilst this makes sense, why did you change it so that it's only set if we don't rewrite state? I think updater.CurrentStateSnapshotNID() just returns an int - the New... function did the DB query, so it's not even an optimisation. Is it maybe meant to act as a guard against accidentally fiddling with it and calculating deltas against things we shouldn't?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is that, if we are rewriting state, we want the state snapshot NID to be zero because that means that the deltas produced later on will be against an empty snapshot (effectively a complete rewrite, where AddsStateEventIDs will contain the entire new state).

If we aren't rewriting state, we want to know what the current state is so that we produce a real delta.

oldLatest = u.updater.LatestEvents()
}

Expand Down Expand Up @@ -153,7 +153,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// Now that we know what the latest events are, it's time to get the
// latest state.
var updates []api.OutputEvent
if extremitiesChanged {
if extremitiesChanged || u.rewritesState {
if err = u.latestState(); err != nil {
return fmt.Errorf("u.latestState: %w", err)
}
Expand Down Expand Up @@ -324,7 +324,6 @@ func (u *latestEventsUpdater) calculateLatest(
}

func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {

latestEventIDs := make([]string, len(u.latest))
for i := range u.latest {
latestEventIDs[i] = u.latest[i].EventID
Expand Down Expand Up @@ -365,11 +364,6 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
}
}
// State is rewritten if the input room event HasState and we actually produced a delta on state events.
// Without this check, /get_missing_events which produce events with associated (but not complete) state
// will incorrectly purge the room and set it to no state. TODO: This is likely flakey, as if /gme produced
// a state conflict res which just so happens to include 2+ events we might purge the room state downstream.
ore.RewritesState = len(ore.AddsStateEventIDs) > 1

return &api.OutputEvent{
Type: api.OutputTypeNewRoomEvent,
Expand Down
2 changes: 1 addition & 1 deletion roomserver/roomserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func TestOutputRewritesState(t *testing.T) {
if len(producer.producedMessages) != 1 {
t.Fatalf("Rewritten events got output, want only 1 got %d", len(producer.producedMessages))
}
outputEvent := producer.producedMessages[0]
outputEvent := producer.producedMessages[len(producer.producedMessages)-1]
if !outputEvent.NewRoomEvent.RewritesState {
t.Errorf("RewritesState flag not set on output event")
}
Expand Down
2 changes: 1 addition & 1 deletion syncapi/consumers/roomserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}

if msg.RewritesState {
if err = s.db.PurgeRoom(ctx, ev.RoomID()); err != nil {
if err = s.db.PurgeRoomState(ctx, ev.RoomID()); err != nil {
return fmt.Errorf("s.db.PurgeRoom: %w", err)
}
}
Expand Down
4 changes: 2 additions & 2 deletions syncapi/storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ type Database interface {
// Returns an error if there was a problem inserting this event.
WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent,
addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error)
// PurgeRoom completely purges room state from the sync API. This is done when
// PurgeRoomState completely purges room state from the sync API. This is done when
// receiving an output event that completely resets the state.
PurgeRoom(ctx context.Context, roomID string) error
PurgeRoomState(ctx context.Context, roomID string) error
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
Expand Down
11 changes: 1 addition & 10 deletions syncapi/storage/shared/syncserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
return nil
}

func (d *Database) PurgeRoom(
func (d *Database) PurgeRoomState(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
Expand All @@ -286,15 +286,6 @@ func (d *Database) PurgeRoom(
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
}
if err := d.OutputEvents.DeleteEventsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.Events.DeleteEventsForRoom: %w", err)
}
if err := d.Topology.DeleteTopologyForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.Topology.DeleteTopologyForRoom: %w", err)
}
if err := d.BackwardExtremities.DeleteBackwardExtremitiesForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.BackwardExtremities.DeleteBackwardExtremitiesForRoom: %w", err)
}
return nil
})
}
Expand Down