diff --git a/changelog.d/13822.misc b/changelog.d/13822.misc new file mode 100644 index 000000000000..dbc77cbcfabe --- /dev/null +++ b/changelog.d/13822.misc @@ -0,0 +1 @@ +Support providing an index predicate clause when doing upserts. diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index cf1eabc4376f..bf5e7ee7be7e 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -533,6 +533,7 @@ def register_background_index_update( index_name: name of index to add table: table to add index to columns: columns/expressions to include in index + where_clause: A WHERE clause to specify a partial unique index. unique: true to make a UNIQUE index psql_only: true to only create this index on psql databases (useful for virtual sqlite tables) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e881bff7fb48..921cd4dc5ee0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1191,6 +1191,7 @@ def simple_upsert_txn( keyvalues: Dict[str, Any], values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, + where_clause: Optional[str] = None, lock: bool = True, ) -> bool: """ @@ -1203,6 +1204,7 @@ def simple_upsert_txn( keyvalues: The unique key tables and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + where_clause: An index predicate to apply to the upsert. lock: True to lock the table when doing the upsert. Unused when performing a native upsert. Returns: @@ -1213,7 +1215,12 @@ def simple_upsert_txn( if table not in self._unsafe_to_upsert_tables: return self.simple_upsert_txn_native_upsert( - txn, table, keyvalues, values, insertion_values=insertion_values + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + where_clause=where_clause, ) else: return self.simple_upsert_txn_emulated( @@ -1222,6 +1229,7 @@ def simple_upsert_txn( keyvalues, values, insertion_values=insertion_values, + where_clause=where_clause, lock=lock, ) @@ -1232,6 +1240,7 @@ def simple_upsert_txn_emulated( keyvalues: Dict[str, Any], values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, + where_clause: Optional[str] = None, lock: bool = True, ) -> bool: """ @@ -1240,6 +1249,7 @@ def simple_upsert_txn_emulated( keyvalues: The unique key tables and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + where_clause: An index predicate to apply to the upsert. lock: True to lock the table when doing the upsert. Returns: Returns True if a row was inserted or updated (i.e. if `values` is @@ -1259,14 +1269,17 @@ def _getwhere(key: str) -> str: else: return "%s = ?" % (key,) + # Generate a where clause of each keyvalue and optionally the provided + # index predicate. + where = [_getwhere(k) for k in keyvalues] + if where_clause: + where.append(where_clause) + if not values: # If `values` is empty, then all of the values we care about are in # the unique key, so there is nothing to UPDATE. We can just do a # SELECT instead to see if it exists. - sql = "SELECT 1 FROM %s WHERE %s" % ( - table, - " AND ".join(_getwhere(k) for k in keyvalues), - ) + sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where)) sqlargs = list(keyvalues.values()) txn.execute(sql, sqlargs) if txn.fetchall(): @@ -1277,7 +1290,7 @@ def _getwhere(key: str) -> str: sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues), + " AND ".join(where), ) sqlargs = list(values.values()) + list(keyvalues.values()) @@ -1307,6 +1320,7 @@ def simple_upsert_txn_native_upsert( keyvalues: Dict[str, Any], values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, + where_clause: Optional[str] = None, ) -> bool: """ Use the native UPSERT functionality in PostgreSQL. @@ -1316,6 +1330,7 @@ def simple_upsert_txn_native_upsert( keyvalues: The unique key tables and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + where_clause: An index predicate to apply to the upsert. Returns: Returns True if a row was inserted or updated (i.e. if `values` is @@ -1331,11 +1346,12 @@ def simple_upsert_txn_native_upsert( allvalues.update(values) latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % ( table, ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues), ", ".join(k for k in keyvalues), + f"WHERE {where_clause}" if where_clause else "", latter, ) txn.execute(sql, list(allvalues.values()))