Skip to content

Commit

Permalink
fix(trino): parse URL passed to ibis.connect
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth authored and jcrist committed May 22, 2024
1 parent 56e0b38 commit e3ee67b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
18 changes: 16 additions & 2 deletions ibis/backends/trino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import cached_property
from operator import itemgetter
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

import sqlglot as sg
import sqlglot.expressions as sge
Expand All @@ -16,7 +17,7 @@
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends import CanCreateDatabase, CanCreateSchema, CanListCatalog, NoUrl
from ibis.backends import CanCreateDatabase, CanCreateSchema, CanListCatalog
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import C
from ibis.backends.trino.compiler import TrinoCompiler
Expand All @@ -30,12 +31,25 @@
import ibis.expr.operations as ops


class Backend(SQLBackend, CanListCatalog, CanCreateDatabase, CanCreateSchema, NoUrl):
class Backend(SQLBackend, CanListCatalog, CanCreateDatabase, CanCreateSchema):
name = "trino"
compiler = TrinoCompiler()
supports_create_or_replace = False
supports_temporary_tables = False

def _from_url(self, url: str, **kwargs):
url = urlparse(url)
catalog, db = url.path.strip("/").split("/")
self.do_connect(
user=url.username or None,
password=url.password or None,
host=url.hostname or None,
port=url.port or None,
database=catalog,
schema=db,
)
return self

def raw_sql(self, query: str | sg.Expression) -> Any:
"""Execute a raw SQL query."""
with contextlib.suppress(AttributeError):
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/trino/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
import os
import string

import pytest
Expand Down Expand Up @@ -183,3 +184,15 @@ def test_list_tables_schema_warning_refactor(con):

assert con.list_tables(database="tpch.sf1") == tpch_tables
assert con.list_tables(database=("tpch", "sf1")) == tpch_tables


def test_connect_uri():
TRINO_USER = os.getenv("IBIS_TEST_TRINO_USER", os.getenv("TRINO_USER", "user"))
TRINO_PASS = os.getenv("IBIS_TEST_TRINO_PASSWORD", os.getenv("TRINO_PASSWORD", ""))
TRINO_HOST = os.getenv("IBIS_TEST_TRINO_HOST", os.getenv("TRINO_HOST", "localhost"))
TRINO_PORT = int(os.getenv("IBIS_TEST_TRINO_PORT", os.getenv("TRINO_PORT", "8080")))
con = ibis.connect(
f"trino://{TRINO_USER}:{TRINO_PASS}@{TRINO_HOST}:{TRINO_PORT}/memory/default"
)

assert con.list_tables()

0 comments on commit e3ee67b

Please sign in to comment.