diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bde0ee1a..33bca82b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -48,6 +48,8 @@ jobs: include: # Test with older Trino versions for backward compatibility - { python: "3.10", trino: "351" } # first Trino version + # Test with Trino version that requires result set to be fully exhausted + - { python: "3.10", trino: "395" } env: TRINO_VERSION: "${{ matrix.trino }}" steps: diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index c5d0dda0..99d3ae83 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -153,8 +153,8 @@ def test_execute_many_without_params(trino_connection): cur = trino_connection.cursor() cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)") cur.fetchall() - cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", []) with pytest.raises(TrinoUserError) as e: + cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", []) cur.fetchall() assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value) @@ -883,13 +883,12 @@ def test_transaction_autocommit(trino_connection_in_autocommit): with trino_connection_in_autocommit as connection: connection.start_transaction() cur = connection.cursor() - cur.execute( - """ - CREATE TABLE memory.default.nation - AS SELECT * from tpch.tiny.nation - """) - with pytest.raises(TrinoUserError) as transaction_error: + cur.execute( + """ + CREATE TABLE memory.default.nation + AS SELECT * from tpch.tiny.nation + """) cur.fetchall() assert "Catalog only supports writes using autocommit: memory" \ in str(transaction_error.value) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9d968e07..26edb9dd 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -37,10 +37,10 @@ def sample_post_response_data(): """ yield { - "nextUri": "coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1", + "nextUri": "https://coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1", "id": "20210817_140827_00000_arvdv", "taskDownloadUris": [], - "infoUri": "http://coordinator:8080/query.html?20210817_140827_00000_arvdv", + "infoUri": "https://coordinator:8080/query.html?20210817_140827_00000_arvdv", "stats": { "scheduled": False, "runningSplits": 0, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 12089305..d7da4983 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -892,9 +892,7 @@ def test_trino_result_response_headers(): 'X-Trino-Fake-2': 'two', }) - result = TrinoResult( - query=mock_trino_query, - ) + result = TrinoResult(query=mock_trino_query, rows=[]) assert result.response_headers == mock_trino_query.response_headers diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 7b1c72c2..6f2cc50c 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -51,7 +51,7 @@ def test_http_session_is_defaulted_when_not_specified(mock_client): @httprettified -def test_token_retrieved_once_per_auth_instance(sample_post_response_data): +def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sample_get_response_data): token = str(uuid.uuid4()) challenge_id = str(uuid.uuid4()) @@ -59,13 +59,20 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data): token_server = f"{TOKEN_RESOURCE}/{challenge_id}" post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + get_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_get_response_data) - # bind post statement + # bind post statement to submit query httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", body=post_statement_callback) + # bind get statement for result retrieval + httpretty.register_uri( + method=httpretty.GET, + uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", + body=get_statement_callback) + # bind get token get_token_callback = GetTokenCallback(token_server, token) httpretty.register_uri( @@ -108,7 +115,8 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data): @httprettified -def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data): +def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data, + sample_get_response_data): token = str(uuid.uuid4()) challenge_id = str(uuid.uuid4()) @@ -116,13 +124,20 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post token_server = f"{TOKEN_RESOURCE}/{challenge_id}" post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + get_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_get_response_data) - # bind post statement + # bind post statement to submit query httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", body=post_statement_callback) + # bind get statement for result retrieval + httpretty.register_uri( + method=httpretty.GET, + uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", + body=get_statement_callback) + # bind get token get_token_callback = GetTokenCallback(token_server, token) httpretty.register_uri( @@ -166,7 +181,7 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post @httprettified -def test_token_retrieved_once_when_multithreaded(sample_post_response_data): +def test_token_retrieved_once_when_multithreaded(sample_post_response_data, sample_get_response_data): token = str(uuid.uuid4()) challenge_id = str(uuid.uuid4()) @@ -174,13 +189,20 @@ def test_token_retrieved_once_when_multithreaded(sample_post_response_data): token_server = f"{TOKEN_RESOURCE}/{challenge_id}" post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data) + get_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_get_response_data) - # bind post statement + # bind post statement to submit query httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", body=post_statement_callback) + # bind get statement for result retrieval + httpretty.register_uri( + method=httpretty.GET, + uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", + body=get_statement_callback) + # bind get token get_token_callback = GetTokenCallback(token_server, token) httpretty.register_uri( diff --git a/trino/client.py b/trino/client.py index 87981e0a..b973ff6b 100644 --- a/trino/client.py +++ b/trino/client.py @@ -592,30 +592,36 @@ class TrinoResult(object): https://docs.python.org/3/library/stdtypes.html#generator-types """ - def __init__(self, query, rows=None): + def __init__(self, query, rows: List[Any]): self._query = query - self._rows = rows or [] + # Initial rows from the first POST request + self._rows = rows self._rownumber = 0 + @property + def rows(self): + return self._rows + + @rows.setter + def rows(self, rows): + self._rows = rows + @property def rownumber(self) -> int: return self._rownumber def __iter__(self): - # Initial fetch from the first POST request - for row in self._rows: - self._rownumber += 1 - yield row - self._rows = None - - # Subsequent fetches from GET requests until next_uri is empty. - while not self._query.finished: - rows = self._query.fetch() - for row in rows: + # A query only transitions to a FINISHED state when the results are fully consumed: + # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi. + while not self._query.finished or self._rows is not None: + next_rows = self._query.fetch() if not self._query.finished else None + for row in self._rows: self._rownumber += 1 logger.debug("row %s", row) yield row + self._rows = next_rows + @property def response_headers(self): return self._query.response_headers @@ -641,7 +647,7 @@ def __init__( self._request = request self._update_type = None self._sql = sql - self._result = TrinoResult(self) + self._result: Optional[TrinoResult] = None self._response_headers = None self._experimental_python_types = experimental_python_types self._row_mapper: Optional[RowMapper] = None @@ -652,7 +658,7 @@ def columns(self): while not self._columns and not self.finished and not self.cancelled: # Columns are not returned immediately after query is submitted. # Continue fetching data until columns information is available and push fetched rows into buffer. - self._result._rows += self.fetch() + self._result.rows += self.fetch() return self._columns @property @@ -697,8 +703,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult: self._finished = True rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows - self._result = TrinoResult(self, rows) + + # Execute should block until at least one row is received + while not self.finished and not self.cancelled and len(self._result.rows) == 0: + self._result.rows += self.fetch() return self._result def _update_state(self, status): @@ -921,8 +930,7 @@ class RowMapper: """ Maps a row of data given a list of mapping functions """ - - def __init__(self, columns=[]): + def __init__(self, columns): self.columns = columns def map(self, rows): diff --git a/trino/dbapi.py b/trino/dbapi.py index 44813168..70fb43bb 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -322,7 +322,7 @@ def _prepare_statement(self, operation, statement_name): operation=operation ) - # Send prepare statement. Copy the _request object to avoid poluting the + # Send prepare statement. Copy the _request object to avoid polluting the # one that is going to be used to execute the actual operation. query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, experimental_python_types=self._experimental_pyton_types) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 7c4409a0..e967cb6b 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -231,7 +231,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st """ ).strip() res = connection.execute(sql.text(query), schema=schema, view=view_name) - return res.scalar() + return res.scalar_one_or_none() def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: if not self.has_table(connection, table_name, schema): @@ -284,7 +284,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str sql.text(query), catalog_name=catalog_name, schema_name=schema_name, table_name=table_name ) - return dict(text=res.scalar()) + return dict(text=res.scalar_one_or_none()) except error.TrinoQueryError as e: if e.error_name in ( error.PERMISSION_DENIED, @@ -326,7 +326,7 @@ def _get_server_version_info(self, connection: Connection) -> Any: query = "SELECT version()" try: res = connection.execute(sql.text(query)) - version = res.scalar() + version = res.scalar_one() return tuple([version]) except exc.ProgrammingError as e: logger.debug(f"Failed to get server version: {e.orig.message}")