Skip to content

Commit

Permalink
enhance(sdk): add map type support
Browse files Browse the repository at this point in the history
  • Loading branch information
xuchuan committed Nov 14, 2022
1 parent fc5e2ad commit 5161d30
Show file tree
Hide file tree
Showing 2 changed files with 289 additions and 1 deletion.
71 changes: 71 additions & 0 deletions client/starwhale/api/_impl/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def encode_schema(type: "SwType") -> Dict[str, Any]:
"type": "TUPLE",
"elementType": SwType.encode_schema(type.element_type),
}
if isinstance(type, SwMapType):
return {
"type": "MAP",
"keyType": SwType.encode_schema(type.key_type),
"valueType": SwType.encode_schema(type.value_type),
}
if isinstance(type, SwObjectType):
ret = {
"type": "OBJECT",
Expand Down Expand Up @@ -117,6 +123,16 @@ def decode_schema(schema: Dict[str, Any]) -> "SwType":
if element_type is None:
raise RuntimeError("no element type found for type TUPLE")
return SwTupleType(SwType.decode_schema(element_type))
if type_name == "MAP":
key_type = schema.get("keyType", None)
value_type = schema.get("valueType", None)
if key_type is None:
raise RuntimeError("no key type found for type MAP")
if value_type is None:
raise RuntimeError("no value type found for type MAP")
return SwMapType(
SwType.decode_schema(key_type), SwType.decode_schema(value_type)
)
if type_name == "OBJECT":
raw_type_name = schema.get("pythonType", None)
if raw_type_name is None:
Expand Down Expand Up @@ -334,6 +350,54 @@ def __eq__(self, other: Any) -> bool:
return False


class SwMapType(SwCompositeType):
def __init__(self, key_type: SwType, value_type: SwType) -> None:
super().__init__("map")
self.key_type = key_type
self.value_type = value_type

def merge(self, type: SwType) -> SwType:
if isinstance(type, SwMapType):
kt = self.key_type.merge(type.key_type)
vt = self.value_type.merge(type.value_type)
if kt is self.key_type and vt is self.value_type:
return self
if kt is type.key_type and vt is type.value_type:
return type
return SwMapType(kt, vt)
raise RuntimeError(f"conflicting type {self} and {type}")

def encode(self, value: Any) -> Any:
if value is None:
return None
if isinstance(value, dict):
return {
self.key_type.encode(k): self.value_type.encode(v)
for k, v in value.items()
}
raise RuntimeError(f"value should be a dict: {value}")

def decode(self, value: Any) -> Any:
if value is None:
return None
if isinstance(value, dict):
return {
self.key_type.decode(k): self.value_type.decode(v)
for k, v in value.items()
}
raise RuntimeError(f"value should be a dict: {value}")

def __str__(self) -> str:
return f"{{{self.key_type}:{self.value_type}}}"

def __eq__(self, other: Any) -> bool:
if isinstance(other, SwMapType):
return (
self.key_type == other.key_type and self.value_type == other.value_type
)
return False


class SwObjectType(SwCompositeType):
def __init__(self, raw_type: Type, attrs: Dict[str, SwType]) -> None:
super().__init__("object")
Expand Down Expand Up @@ -444,6 +508,13 @@ def _get_type(obj: Any) -> SwType:
for element in obj:
element_type = element_type.merge(_get_type(element))
return SwTupleType(element_type)
if isinstance(obj, dict):
key_type: SwType = UNKNOWN
value_type: SwType = UNKNOWN
for k, v in obj.items():
key_type = key_type.merge(_get_type(k))
value_type = value_type.merge(_get_type(v))
return SwMapType(key_type, value_type)
if isinstance(obj, SwObject):
attrs = {}
for k, v in obj.__dict__.items():
Expand Down
Loading

0 comments on commit 5161d30

Please sign in to comment.