Rework DynamoDB parser to expect DynamoDB's format
bblommers committed Oct 14, 2023
1 parent 33e2433 commit 0bd33e9
Showing 12 changed files with 493 additions and 51 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -17,8 +17,7 @@ result = parser.parse("SELECT * FROM s3object")
import json
from py_partiql_parser import DynamoDBStatementParser

original_json = json.dumps({"a1": "b1", "a2": "b2"})
parser = DynamoDBStatementParser(source_data={"table1", original_json})
parser = DynamoDBStatementParser(source_data={"table1": {"a1": {"S": "b1"}, "a2": {"S": "b2"}}})
result = parser.parse("SELECT * from table1 WHERE a1 = ?", parameters=[{"S": "b1"}])
```
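
For reference, the `DynamoDBStatementParser` docstring added in this commit describes the source data as a *list* of DynamoDB-typed items per table. A hedged sketch in that shape (the second item and the expected result are illustrative assumptions):

```python
from py_partiql_parser import DynamoDBStatementParser

# Each table maps to a list of items in DynamoDB's native format, where every
# attribute is wrapped in a type descriptor ({"S": ...}, {"N": ...}, ...).
source_data = {
    "table1": [
        {"a1": {"S": "b1"}, "a2": {"S": "b2"}},
        {"a1": {"S": "x1"}, "a2": {"S": "x2"}},
    ]
}

parser = DynamoDBStatementParser(source_data=source_data)
# '?' placeholders are filled from `parameters`, which are also DynamoDB-typed
result = parser.parse("SELECT * from table1 WHERE a1 = ?", parameters=[{"S": "b1"}])
# Expected (illustrative): [{"a1": {"S": "b1"}, "a2": {"S": "b2"}}]
```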

25 changes: 21 additions & 4 deletions py_partiql_parser/_internal/from_parser.py
@@ -1,9 +1,11 @@
from typing import Dict, Any
from typing import Any, Dict, List

from .clause_tokenizer import ClauseTokenizer
from .json_parser import JsonParser
from .utils import CaseInsensitiveDict

from ..exceptions import ParserException


class FromParser:
def __init__(self):
@@ -92,6 +94,7 @@ def get_source_data(self, documents: Dict[str, str]):
].endswith("]")

source_data = JsonParser().parse(documents[from_query])

if doc_is_list:
return {"_1": source_data}
elif from_alias:
@@ -153,6 +156,20 @@ def _get_nested_source_data(self, documents: Dict[str, str]):


class DynamoDBFromParser(FromParser):
def get_source_data(self, documents: Dict[str, str]):
source_data = documents[list(self.clauses.values())[0].lower()]
return JsonParser().parse(source_data)
def parse(self, from_clause) -> Dict[str, str]:
super().parse(from_clause)

for alias, table_name in list(self.clauses.items()):
if table_name[0].isnumeric():
raise ParserException(
"ValidationException", "Aliasing is not supported"
)

if table_name[0] == '"' and table_name[-1] == '"':
self.clauses[alias] = table_name[1:-1]

return self.clauses

def get_source_data(self, documents: Dict[str, List[Dict[str, Any]]]):
list_of_json_docs = documents[list(self.clauses.values())[0].lower()]
return [CaseInsensitiveDict(doc) for doc in list_of_json_docs]
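
The reworked `DynamoDBFromParser.parse` validates and normalises table names before `get_source_data` looks them up. A minimal sketch of that behaviour, assuming the internal module path shown above and treating the exact shape of the returned clauses dict as an implementation detail:

```python
from py_partiql_parser._internal.from_parser import DynamoDBFromParser
from py_partiql_parser.exceptions import ParserException

# Surrounding double quotes are stripped from table names
clauses = DynamoDBFromParser().parse('"my_table"')
print(clauses)  # the values should now read my_table, without the quotes

# Table names starting with a digit are rejected with a ValidationException
try:
    DynamoDBFromParser().parse("0table")
except ParserException as exc:
    print(exc)
```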
5 changes: 5 additions & 0 deletions py_partiql_parser/_internal/json_parser.py
@@ -45,6 +45,11 @@ def __init__(self) -> None:


class JsonParser:
"""
Input can contain multiple documents, separated by new-line (\n) characters,
so we can't use the built-in JSON parser.
"""

def parse(self, original, tokenizer=None, only_parse_initial=False) -> Any:
if not (original.startswith("{") or original.startswith("[")):
# Doesn't look like JSON - let's return as a variable
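
The docstring hints at why a custom parser is needed: the input may hold several JSON documents back to back. A minimal sketch of that general problem using only the standard library (this is not the library's own implementation):

```python
import json

def parse_documents(text: str):
    """Split a string containing newline-separated JSON documents."""
    decoder = json.JSONDecoder()
    pos, docs = 0, []
    while pos < len(text):
        doc, end = decoder.raw_decode(text, pos)
        docs.append(doc)
        pos = end
        # Skip the whitespace/newlines between documents
        while pos < len(text) and text[pos] in " \t\n":
            pos += 1
    return docs

print(parse_documents('{"a": 1}\n{"a": 2}'))  # [{'a': 1}, {'a': 2}]
```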
17 changes: 16 additions & 1 deletion py_partiql_parser/_internal/parser.py
@@ -33,6 +33,7 @@ def parse(self, query: str, parameters=None) -> List[Dict[str, Any]]:
# FROM
from_parser = self.from_parser()
from_clauses = from_parser.parse(clauses[2])

source_data = from_parser.get_source_data(self.documents)
if is_dict(source_data):
source_data = [source_data] # type: ignore
@@ -64,7 +65,21 @@ def __init__(self, source_data: Dict[str, str]):


class DynamoDBStatementParser(Parser):
def __init__(self, source_data: Dict[str, str]):
def __init__(self, source_data: Dict[str, List[Dict[str, Any]]]):
"""
Source data should map each table name to a list of DynamoDB documents:
{
"table_name": [
{
"hash_key": "..",
"other_item": {"S": ".."},
..
},
..
],
..
}
"""
super().__init__(
source_data,
table_prefix=None,
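
A short sketch of constructing the parser with source data in the documented shape, including nested attributes (the table and attribute names are made up for illustration):

```python
from py_partiql_parser import DynamoDBStatementParser

# Table name -> list of DynamoDB-typed items; nested structures use the
# usual "M" (map) and "L" (list) wrappers
source_data = {
    "users": [
        {
            "user_id": {"S": "u1"},
            "address": {"M": {"city": {"S": "Berlin"}}},
            "tags": {"L": [{"S": "admin"}, {"S": "beta"}]},
        },
    ],
}

parser = DynamoDBStatementParser(source_data=source_data)
```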
35 changes: 34 additions & 1 deletion py_partiql_parser/_internal/utils.py
@@ -52,7 +52,7 @@ def find_nested_data_in_object(
)
)
return result
elif isinstance(json_doc, CaseInsensitiveDict):
elif is_dict(json_doc):
if current_key not in json_doc:
return MissingVariable()
if remaining_keys:
@@ -99,6 +99,39 @@ def find_value_in_document(keys: List[str], json_doc):
return find_value_in_document(keys[1:], json_doc.get(keys[0], {}))


def find_value_in_dynamodb_document(keys: List[str], json_doc):
if not is_dict(json_doc):
return None
key_is_array = re.search(r"(.+)\[(\d+)\]$", keys[0])
if key_is_array:
key_name = key_is_array.group(1)
array_index = int(key_is_array.group(2))
try:
requested_list = json_doc.get(key_name, {})
assert "L" in requested_list
doc_one_layer_down = requested_list["L"][array_index]
if "M" in doc_one_layer_down:
doc_one_layer_down = doc_one_layer_down["M"]
except IndexError:
# Array exists, but does not have enough values
doc_one_layer_down = {}
except AssertionError:
# Requested key is not a list - fail silently just like AWS does
doc_one_layer_down = {}
return find_value_in_dynamodb_document(keys[1:], doc_one_layer_down)
if len(keys) == 1:
if "M" in json_doc:
return json_doc["M"].get(keys[0])
else:
return json_doc.get(keys[0])
nested_doc = json_doc.get(keys[0], {})
if "M" in nested_doc:
return find_value_in_dynamodb_document(keys[1:], nested_doc["M"])
# Key is not a map
# Or does not exist
return None


class QueryMetadata:
def __init__(
self, tables: Dict[str, str], where_clauses: List[Tuple[List[str], str]] = None
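
Based on the logic above, a hedged sketch of how `find_value_in_dynamodb_document` walks a typed item (plain dicts are used here for readability; inside the library the rows are `CaseInsensitiveDict`s, and `is_dict` is assumed to accept both):

```python
from py_partiql_parser._internal.utils import find_value_in_dynamodb_document

item = {
    "name": {"S": "alice"},
    "address": {"M": {"city": {"S": "Berlin"}}},
}

find_value_in_dynamodb_document(["name"], item)             # {"S": "alice"}
find_value_in_dynamodb_document(["address", "city"], item)  # {"S": "Berlin"}
find_value_in_dynamodb_document(["missing"], item)          # None
find_value_in_dynamodb_document(["name", "nested"], item)   # None - "name" is not a map
```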
17 changes: 15 additions & 2 deletions py_partiql_parser/_internal/where_parser.py
@@ -2,6 +2,12 @@

from .clause_tokenizer import ClauseTokenizer
from .utils import find_value_in_document
from .utils import find_value_in_dynamodb_document
from .._packages.boto3.types import TypeDeserializer, TypeSerializer


deserializer = TypeDeserializer()
serializer = TypeSerializer()


class WhereParser:
@@ -75,7 +81,10 @@ def parse(self, where_clause: str, parameters) -> Any:
_filters = WhereParser.parse_where_clause(where_clause)

_filters = [
(key, parameters.pop(0) if value == "?" else value)
(
key,
deserializer.deserialize(parameters.pop(0)) if value == "?" else value,
)
for key, value in _filters
]

@@ -84,7 +93,11 @@ def parse(self, where_clause: str, parameters) -> Any:
def filter_rows(self, _filters):
def _filter(row) -> bool:
return all(
[find_value_in_document(keys, row) == value for keys, value in _filters]
[
find_value_in_dynamodb_document(keys, row)
== serializer.serialize(value)
for keys, value in _filters
]
)

return [row for row in self.source_data if _filter(row)]
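
The filter compares values in DynamoDB's typed format on both sides: incoming `?` parameters are unwrapped with `TypeDeserializer`, and the resulting plain values are re-wrapped with `TypeSerializer` before being compared against the typed attribute found in each row. A small sketch of that round trip, using the vendored boto3 types imported above:

```python
from py_partiql_parser._packages.boto3.types import TypeDeserializer, TypeSerializer

deserializer = TypeDeserializer()
serializer = TypeSerializer()

plain = deserializer.deserialize({"S": "b1"})  # -> "b1"
typed = serializer.serialize(plain)            # -> {"S": "b1"}
assert typed == {"S": "b1"}
```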