From 8dbf1a05c6b694df26ea7a9fc0223b833441caea Mon Sep 17 00:00:00 2001 From: Michal Charemza Date: Tue, 19 Mar 2024 18:07:10 +0000 Subject: [PATCH] perf: GRANT all SELECT privileges in a single query This reduces the number of queries run when copying existing SELECT privileges, which happens at the end of the first batch in the cases when the ingest is not directly into the live table. The existing test that privileges are preserved is extended to make sure the `','.join(...` behaviour is correct --- pg_bulk_ingest.py | 9 +++++--- test_pg_bulk_ingest.py | 51 ++++++++++++++++++++++-------------------- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/pg_bulk_ingest.py b/pg_bulk_ingest.py index c2f93b7..0d7a378 100644 --- a/pg_bulk_ingest.py +++ b/pg_bulk_ingest.py @@ -434,11 +434,14 @@ def escape_string(text): ''').format(schema=sql.Literal(target_table.schema), table=sql.Literal(target_table.name)) .as_string(conn.connection.driver_connection) )).fetchall() - for grantee in grantees: - conn.execute(sa.text(sql.SQL('GRANT SELECT ON {schema_table} TO {user}') + if grantees: + conn.execute(sa.text(sql.SQL('GRANT SELECT ON {schema_table} TO {users}') .format( schema_table=sql.Identifier(ingest_table.schema, ingest_table.name), - user=sql.Identifier(grantee[0]), + users=sql.SQL(',').join( + sql.Identifier(grantee[0]) + for grantee in grantees + ), ) .as_string(conn.connection.driver_connection)) ) diff --git a/test_pg_bulk_ingest.py b/test_pg_bulk_ingest.py index 52820f2..d6b3fde 100644 --- a/test_pg_bulk_ingest.py +++ b/test_pg_bulk_ingest.py @@ -1089,28 +1089,30 @@ def batches_1(high_watermark): with engine.connect() as conn: ingest(conn, metadata_1, batches_1) - user_id = uuid.uuid4().hex[:16] - with engine.connect() as conn: - conn.execute(sa.text(sql.SQL(''' - CREATE USER {user_id} WITH PASSWORD 'password'; - ''').format(user_id=sql.Identifier(user_id)).as_string(conn.connection.driver_connection))) - conn.execute(sa.text(sql.SQL(''' - GRANT CONNECT ON DATABASE postgres TO {user_id}; - ''').format(user_id=sql.Identifier(user_id)).as_string(conn.connection.driver_connection))) - conn.execute(sa.text(sql.SQL(''' - GRANT USAGE ON SCHEMA my_schema TO {user_id}; - ''').format(user_id=sql.Identifier(user_id)).as_string(conn.connection.driver_connection))) - conn.commit() - conn.execute(sa.text(sql.SQL(''' - GRANT SELECT ON my_schema.{table} TO {user_id}; - ''').format(table=sql.Identifier(my_table_1.name), user_id=sql.Identifier(user_id)).as_string(conn.connection.driver_connection))) - conn.commit() - - user_engine = sa.create_engine(f'{engine_type}://{user_id}:password@127.0.0.1:5432/postgres', **engine_future) - with user_engine.connect() as conn: - results = conn.execute(sa.select(my_table_1).order_by('id')).fetchall() - - assert results == [(1, 'a', 'b')] + user_ids = [uuid.uuid4().hex[:16], uuid.uuid4().hex[:16]] + with engine.connect() as conn: + for user_id in user_ids: + conn.execute(sa.text(sql.SQL(''' + CREATE USER {user_id} WITH PASSWORD 'password'; + ''').format(user_id=sql.Identifier(user_id)).as_string(conn.connection.driver_connection))) + conn.execute(sa.text(sql.SQL(''' + GRANT CONNECT ON DATABASE postgres TO {user_id}; + ''').format(user_id=sql.Identifier(user_id)).as_string(conn.connection.driver_connection))) + conn.execute(sa.text(sql.SQL(''' + GRANT USAGE ON SCHEMA my_schema TO {user_id}; + ''').format(user_id=sql.Identifier(user_id)).as_string(conn.connection.driver_connection))) + conn.commit() + conn.execute(sa.text(sql.SQL(''' + GRANT SELECT ON my_schema.{table} TO {user_id}; + ''').format(table=sql.Identifier(my_table_1.name), user_id=sql.Identifier(user_id)).as_string(conn.connection.driver_connection))) + conn.commit() + + for user_id in user_ids: + user_engine = sa.create_engine(f'{engine_type}://{user_id}:password@127.0.0.1:5432/postgres', **engine_future) + with user_engine.connect() as conn: + results = conn.execute(sa.select(my_table_1).order_by('id')).fetchall() + + assert results == [(1, 'a', 'b')] metadata_2 = sa.MetaData() my_table_2 = sa.Table( @@ -1129,10 +1131,11 @@ def batches_2(high_watermark): with engine.connect() as conn: ingest(conn, metadata_2, batches_2) - with user_engine.connect() as conn: + for user_id in user_ids: + user_engine = sa.create_engine(f'{engine_type}://{user_id}:password@127.0.0.1:5432/postgres', **engine_future) results = conn.execute(sa.select(my_table_2).order_by('id')).fetchall() - assert results == [(1, 'a', None, 'b')] + assert results == [(1, 'a', None, 'b')] def test_migrate_add_column_not_at_end_no_data():