Skip to content

Commit

Permalink
feat(datafusion): add extract url fields functions
Browse files (browse the repository at this point in the history)
  • Loading branch information
mesejo authored and cpcloud committed Jul 27, 2023
1 parent 0f49d9f commit 4f5ea98
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 2 deletions.
107 changes: 107 additions & 0 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import operator
from functools import partial, reduce, singledispatch
from urllib.parse import parse_qs, urlsplit

import datafusion as df
import datafusion.functions
Expand Down Expand Up @@ -703,3 +704,109 @@ def string_join(op, **kw):
)

return df.functions.concat_ws(sep, *(translate(arg, **kw) for arg in op.arg))


def extract_url_field_arrow(field_name, arr):
    """Extract one ``urlsplit`` component from every URL in an Arrow array.

    Parameters
    ----------
    field_name
        Name of the attribute read from the ``urlsplit`` result, for example
        ``"scheme"``, ``"netloc"``, ``"path"``, ``"fragment"`` or
        ``"hostname"``.
    arr
        Arrow array of URL strings. Null entries are allowed.

    Returns
    -------
    A string-typed Arrow array with one entry per input row. Null URLs
    yield null, and so do components that ``urlsplit`` reports as ``None``
    (e.g. ``hostname`` for a URL without a network location).
    """

    def _extract(url):
        # Null rows stay null: urlsplit() is only meaningful for real
        # strings, so a missing URL must not be parsed at all.
        if url is None:
            return None
        return getattr(urlsplit(url), field_name, "")

    # pa.string() is what inference produces for str values anyway; stating
    # it explicitly keeps an all-null batch from being inferred as a
    # null-typed array.
    return pa.array(
        [_extract(url) for url in arr.to_pylist()],
        type=pa.string(),
    )


def register_extract_url_field_udf(field_name):
    """Build a DataFusion scalar UDF that pulls ``field_name`` out of URLs.

    The UDF wraps ``extract_url_field_arrow`` with ``field_name`` bound as
    its first argument, takes one string input, returns a string and is
    declared ``"immutable"``. Its name is ``extract_<field_name>_udf``.

    This function only constructs and returns the UDF object; a fresh one is
    created on every call.
    """
    udf_name = f"extract_{field_name}_udf"
    field_extractor = partial(extract_url_field_arrow, field_name)

    return df.udf(
        field_extractor,
        input_types=[PyArrowType.from_ibis(dt.string)],
        return_type=PyArrowType.from_ibis(dt.string),
        volatility="immutable",
        name=udf_name,
    )


@translate.register(ops.ExtractFragment)
def extract_fragment(op, **kw):
    """Translate ``ops.ExtractFragment`` into a UDF call on the URL argument.

    The UDF reads the ``fragment`` attribute of the ``urlsplit`` result.
    """
    fragment_udf = register_extract_url_field_udf("fragment")
    url_expr = translate(op.arg, **kw)
    return fragment_udf(url_expr)


@translate.register(ops.ExtractProtocol)
def extract_protocol(op, **kw):
    """Translate ``ops.ExtractProtocol`` into a UDF call on the URL argument.

    The protocol is taken from the ``scheme`` attribute of the ``urlsplit``
    result.
    """
    scheme_udf = register_extract_url_field_udf("scheme")
    url_expr = translate(op.arg, **kw)
    return scheme_udf(url_expr)


@translate.register(ops.ExtractAuthority)
def extract_authority(op, **kw):
    """Translate ``ops.ExtractAuthority`` into a UDF call on the URL argument.

    The authority is taken from the ``netloc`` attribute of the ``urlsplit``
    result.
    """
    netloc_udf = register_extract_url_field_udf("netloc")
    url_expr = translate(op.arg, **kw)
    return netloc_udf(url_expr)


@translate.register(ops.ExtractPath)
def extract_path(op, **kw):
    """Translate ``ops.ExtractPath`` into a UDF call on the URL argument.

    The UDF reads the ``path`` attribute of the ``urlsplit`` result.
    """
    path_udf = register_extract_url_field_udf("path")
    url_expr = translate(op.arg, **kw)
    return path_udf(url_expr)


@translate.register(ops.ExtractHost)
def extract_host(op, **kw):
    """Translate ``ops.ExtractHost`` into a UDF call on the URL argument.

    The host is taken from the ``hostname`` attribute of the ``urlsplit``
    result.
    """
    hostname_udf = register_extract_url_field_udf("hostname")
    url_expr = translate(op.arg, **kw)
    return hostname_udf(url_expr)


def extract_user_info_arrow(arr):
    """Extract ``username:password`` from every URL in an Arrow array.

    Parameters
    ----------
    arr
        Arrow array of URL strings. Null entries are allowed.

    Returns
    -------
    A string-typed Arrow array with one entry per input row. Null URLs
    yield null. For a non-null URL the result is always
    ``"<username>:<password>"``, with a missing part rendered as an empty
    string.
    """

    def _extract_user_info(url):
        # A missing URL has no user info; keep it null instead of
        # fabricating the ":" that two empty parts would produce.
        if url is None:
            return None

        url_parts = urlsplit(url)
        username = url_parts.username or ""
        password = url_parts.password or ""

        # NOTE(review): a URL without credentials still yields ":" and a
        # URL with only a user yields "user:" -- confirm this is the
        # intended contract before changing it.
        return f"{username}:{password}"

    # Explicit type so an all-null batch is still a string array.
    return pa.array(
        [_extract_user_info(url) for url in arr.to_pylist()],
        type=pa.string(),
    )


@translate.register(ops.ExtractUserInfo)
def extract_user_info(op, **kw):
    """Translate ``ops.ExtractUserInfo`` into a UDF call on the URL argument.

    A scalar UDF named ``extract_user_info_udf`` is built around
    ``extract_user_info_arrow`` (one string input, string output,
    ``"immutable"`` volatility) and applied to the translated argument.
    """
    user_info_udf = df.udf(
        extract_user_info_arrow,
        input_types=[PyArrowType.from_ibis(dt.string)],
        return_type=PyArrowType.from_ibis(dt.string),
        volatility="immutable",
        name="extract_user_info_udf",
    )
    url_expr = translate(op.arg, **kw)
    return user_info_udf(url_expr)


def extract_query_arrow(arr, param_name=None):
    """Extract the query string, or one query parameter, from each URL.

    Parameters
    ----------
    arr
        Arrow array of URL strings. Null entries are allowed.
    param_name
        Optional Arrow array of parameter names. When omitted, the whole
        query string of each URL is returned. When given, it is read row by
        row; a single-element array is applied to every row.

    Returns
    -------
    A string-typed Arrow array with one entry per input row. The entry is
    null when the URL is null, when the parameter name is null, or when the
    parameter does not occur in the query string (``parse_qs`` also drops
    parameters with blank values). If a parameter occurs several times, its
    first value is returned so the result stays a plain string.
    """
    urls = arr.to_pylist()

    if param_name is None:
        whole_queries = [
            None if url is None else urlsplit(url).query for url in urls
        ]
        return pa.array(whole_queries, type=pa.string())

    keys = param_name.to_pylist()
    if len(keys) == 1:
        # A lone key is treated as a constant for the whole batch; this also
        # covers an empty batch, where indexing the key array would fail.
        keys = keys * len(urls)

    def _lookup(url, key):
        if url is None or key is None:
            return None
        # .get() instead of [] so one URL lacking the key cannot abort the
        # entire batch with a KeyError.
        values = parse_qs(urlsplit(url).query).get(key)
        # The UDF is declared to return strings, so a repeated key cannot be
        # returned as a list; take the first occurrence.
        return values[0] if values else None

    return pa.array(
        [_lookup(url, key) for url, key in zip(urls, keys)],
        type=pa.string(),
    )


@translate.register(ops.ExtractQuery)
def extract_query(op, **kw):
    """Translate ``ops.ExtractQuery`` into a call to ``extract_query_udf``.

    The UDF wraps ``extract_query_arrow``. It is declared with one string
    input when ``op.key`` is ``None`` and with two string inputs otherwise;
    in the latter case the translated key is passed as the second argument.
    """
    call_args = [translate(op.arg, **kw)]
    has_key = op.key is not None

    # One string input for the URL, plus one for the key when present.
    n_inputs = 2 if has_key else 1
    query_udf = df.udf(
        extract_query_arrow,
        input_types=[PyArrowType.from_ibis(dt.string) for _ in range(n_inputs)],
        return_type=PyArrowType.from_ibis(dt.string),
        volatility="immutable",
        name="extract_query_udf",
    )

    if has_key:
        call_args.append(translate(op.key, **kw))
    return query_udf(*call_args)
4 changes: 2 additions & 2 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,8 @@ def test_substr_with_null_values(backend, alltypes, df):
id="file",
marks=[
pytest.mark.notimpl(
["pandas", "dask", "sqlite"], raises=com.OperationNotDefinedError
["pandas", "dask", "datafusion", "sqlite"],
raises=com.OperationNotDefinedError,
),
],
),
Expand All @@ -962,7 +963,6 @@ def test_substr_with_null_values(backend, alltypes, df):
@pytest.mark.notimpl(
[
"bigquery",
"datafusion",
"duckdb",
"mssql",
"mysql",
Expand Down

0 comments on commit 4f5ea98

Please sign in to comment.