Skip to content

Commit

Permalink
Fix nested struct append.
Browse files Browse the repository at this point in the history
  • Loading branch information
igorborgest committed Dec 18, 2020
1 parent 0563164 commit cf42d6b
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
34 changes: 31 additions & 3 deletions awswrangler/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import re
from decimal import Decimal
from typing import Any, Callable, Dict, List, Match, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Match, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -189,9 +189,35 @@ def pyarrow2timestream(dtype: pa.DataType) -> str: # pylint: disable=too-many-b
raise exceptions.UnsupportedType(f"Unsupported Amazon Timestream measure type: {dtype}")


def _split_fields(s: str) -> Iterator[str]:
counter: int = 0
last: int = 0
for i, x in enumerate(s):
if x == "<":
counter += 1
elif x == ">":
counter -= 1
elif x == "," and counter == 0:
yield s[last:i]
last = i + 1
yield s[last:]


def _split_struct(s: str) -> List[str]:
    """Split a ``struct<...>`` body into its top-level field declarations."""
    return [*_split_fields(s=s)]


def _split_map(s: str) -> List[str]:
    """Split a ``map<...>`` body into its key and value type strings.

    Raises
    ------
    RuntimeError
        If the body does not contain exactly two top-level fields.
    """
    key_value: List[str] = [*_split_fields(s=s)]
    if len(key_value) == 2:
        return key_value
    raise RuntimeError(f"Invalid map fields: {s}")


def athena2pyarrow(dtype: str) -> pa.DataType: # pylint: disable=too-many-return-statements
"""Athena to PyArrow data types conversion."""
dtype = dtype.lower().replace(" ", "")
print(dtype)
if dtype == "tinyint":
return pa.int8()
if dtype == "smallint":
Expand Down Expand Up @@ -220,9 +246,10 @@ def athena2pyarrow(dtype: str) -> pa.DataType: # pylint: disable=too-many-retur
if dtype.startswith("array") is True:
return pa.list_(value_type=athena2pyarrow(dtype=dtype[6:-1]), list_size=-1)
if dtype.startswith("struct") is True:
return pa.struct([(f.split(":", 1)[0], athena2pyarrow(f.split(":", 1)[1])) for f in dtype[7:-1].split(",")])
return pa.struct([(f.split(":", 1)[0], athena2pyarrow(f.split(":", 1)[1])) for f in _split_struct(dtype[7:-1])])
if dtype.startswith("map") is True:
return pa.map_(athena2pyarrow(dtype[4:-1].split(",", 1)[0]), athena2pyarrow(dtype[4:-1].split(",", 1)[1]))
parts: List[str] = _split_map(s=dtype[4:-1])
return pa.map_(athena2pyarrow(parts[0]), athena2pyarrow(parts[1]))
raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")


Expand Down Expand Up @@ -491,6 +518,7 @@ def pyarrow_schema_from_pandas(
) -> pa.Schema:
"""Extract the related Pyarrow Schema from any Pandas DataFrame."""
casts: Dict[str, str] = {} if dtype is None else dtype
_logger.debug("casts: %s", casts)
ignore: List[str] = [] if ignore_cols is None else ignore_cols
ignore_plus = ignore + list(casts.keys())
columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(
Expand Down
33 changes: 33 additions & 0 deletions tests/test_athena_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

import awswrangler as wr
from awswrangler._data_types import _split_fields

from ._utils import ensure_data_types, get_df, get_df_cast, get_df_list

Expand Down Expand Up @@ -674,3 +675,35 @@ def test_cast_decimal(path, glue_table, glue_database):
assert df2["c1"].iloc[0] == Decimal((0, (1, 0, 0, 1), -1))
assert df2["c2"].iloc[0] == Decimal((0, (1, 0, 0, 1), -1))
assert df2["c3"].iloc[0] == "100.1"


def test_splits():
    """Check that _split_fields splits only on top-level commas, even for nested structs."""
    cases = [
        (
            "a:struct<id:string,name:string>,b:struct<id:string,name:string>",
            ["a:struct<id:string,name:string>", "b:struct<id:string,name:string>"],
        ),
        (
            "a:struct<a:struct<id:string,name:string>,b:struct<id:string,name:string>>,b:struct<a:struct<id:string,name:string>,b:struct<id:string,name:string>>",  # noqa
            [
                "a:struct<a:struct<id:string,name:string>,b:struct<id:string,name:string>>",
                "b:struct<a:struct<id:string,name:string>,b:struct<id:string,name:string>>",
            ],
        ),
        (
            "a:struct<id:string,name:string>,b:struct<id:string,name:string>,c:struct<id:string,name:string>,d:struct<id:string,name:string>",  # noqa
            [
                "a:struct<id:string,name:string>",
                "b:struct<id:string,name:string>",
                "c:struct<id:string,name:string>",
                "d:struct<id:string,name:string>",
            ],
        ),
    ]
    for raw, expected in cases:
        assert list(_split_fields(raw)) == expected


def test_to_parquet_nested_structs(glue_database, glue_table, path):
    """Round-trip a DataFrame with nested struct columns through Parquet + Athena.

    Writes the same frame twice in dataset/append mode and verifies the row
    count grows from 1 to 2 while the nested-struct column survives both writes.
    """
    df = pd.DataFrame(
        {
            "c0": [1],
            "c1": [[{"a": {"id": "0", "name": "foo", "amount": 1}, "b": {"id": "1", "name": "boo", "amount": 2}}]],
        }
    )
    for expected_rows in (1, 2):
        wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table)
        result = wr.athena.read_sql_query(sql=f"SELECT * FROM {glue_table}", database=glue_database)
        assert result.shape == (expected_rows, 2)

0 comments on commit cf42d6b

Please sign in to comment.