Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow recursive schemas #104

Merged
merged 3 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 88 additions & 42 deletions jsf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
),
initial_state: Dict[str, Any] = MappingProxyType({}),
allow_none_optionals: confloat(ge=0.0, le=1.0) = 0.5,
max_recursive_depth: int = 10,
):
"""Initializes the JSF generator with the provided schema and
configuration options.
Expand All @@ -62,16 +63,19 @@ def __init__(
context (Dict[str, Any], optional): A dictionary that provides additional utilities for handling the schema, such as a faker for generating fake data, a random number generator, and datetime utilities. It also includes an internal dictionary for handling List, Union, and Tuple types. Defaults to a dictionary with "faker", "random", "datetime", and "__internal__" keys.
initial_state (Dict[str, Any], optional): A dictionary that represents the initial state of the parser. If you wish to extend the state so it can be accesses by your schema you can add any references in here. Defaults to an empty dictionary.
allow_none_optionals (confloat, optional): A parameter that determines the probability of optional fields being set to None. Defaults to 0.5.
max_recursive_depth (int, optional): A parameter that determines the maximum depth when generating a recursive schema. Defaults to 10.
"""
self.root_schema = schema
self.definitions = {}
self.base_state = {
"__counter__": count(start=1),
"__all_json_paths__": [],
"__depth__": 0,
**initial_state,
}
self.base_context = context
self.allow_none_optionals = allow_none_optionals
self.max_recursive_depth = max_recursive_depth

self.root = None
self._parse(schema)
Expand All @@ -89,6 +93,7 @@ def from_json(
),
initial_state: Dict[str, Any] = MappingProxyType({}),
allow_none_optionals: confloat(ge=0.0, le=1.0) = 0.5,
max_recursive_depth: int = 10,
) -> "JSF":
"""Initializes the JSF generator with the provided schema at the given
path and configuration options.
Expand All @@ -98,9 +103,12 @@ def from_json(
context (Dict[str, Any], optional): A dictionary that provides additional utilities for handling the schema, such as a faker for generating fake data, a random number generator, and datetime utilities. It also includes an internal dictionary for handling List, Union, and Tuple types. Defaults to a dictionary with "faker", "random", "datetime", and "__internal__" keys.
initial_state (Dict[str, Any], optional): A dictionary that represents the initial state of the parser. If you wish to extend the state so it can be accesses by your schema you can add any references in here. Defaults to an empty dictionary.
allow_none_optionals (confloat, optional): A parameter that determines the probability of optional fields being set to None. Defaults to 0.5.
max_recursive_depth (int, optional): A parameter that determines the maximum depth when generating a recursive schema. Defaults to 10.
"""
with open(path) as f:
return JSF(json.load(f), context, initial_state, allow_none_optionals)
return JSF(
json.load(f), context, initial_state, allow_none_optionals, max_recursive_depth
)

def __parse_primitive(self, name: str, path: str, schema: Dict[str, Any]) -> PrimitiveTypes:
item_type, is_nullable = self.__is_field_nullable(schema)
Expand All @@ -111,62 +119,79 @@ def __parse_primitive(self, name: str, path: str, schema: Dict[str, Any]) -> Pri
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)

def __parse_object(self, name: str, path: str, schema: Dict[str, Any]) -> Object:
def __parse_object(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> Object:
_, is_nullable = self.__is_field_nullable(schema)
model = Object.from_dict(
{
"name": name,
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)
root = model if root is None else root
props = []
for _name, definition in schema.get("properties", {}).items():
props.append(self.__parse_definition(_name, path=f"{path}/{_name}", schema=definition))
props.append(
self.__parse_definition(_name, path=f"{path}/{_name}", schema=definition, root=root)
)
model.properties = props
pattern_props = []
for _name, definition in schema.get("patternProperties", {}).items():
pattern_props.append(
self.__parse_definition(_name, path=f"{path}/{_name}", schema=definition)
self.__parse_definition(_name, path=f"{path}/{_name}", schema=definition, root=root)
)
model.patternProperties = pattern_props

return model

def __parse_array(self, name: str, path: str, schema: Dict[str, Any]) -> Array:
def __parse_array(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> Array:
_, is_nullable = self.__is_field_nullable(schema)
arr = Array.from_dict(
{
"name": name,
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)
arr.items = self.__parse_definition(name, name, schema["items"])
root = arr if root is None else root
arr.items = self.__parse_definition(name, f"{path}/items", schema["items"], root=root)
return arr

def __parse_tuple(self, name: str, path: str, schema: Dict[str, Any]) -> JSFTuple:
def __parse_tuple(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> JSFTuple:
_, is_nullable = self.__is_field_nullable(schema)
arr = JSFTuple.from_dict(
{
"name": name,
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)
root = arr if root is None else root
arr.items = []
for i, item in enumerate(schema["items"]):
arr.items.append(self.__parse_definition(name, path=f"{name}[{i}]", schema=item))
arr.items.append(
self.__parse_definition(name, path=f"{path}/{name}[{i}]", schema=item, root=root)
)
return arr

def __is_field_nullable(self, schema: Dict[str, Any]) -> Tuple[str, bool]:
Expand All @@ -181,40 +206,55 @@ def __is_field_nullable(self, schema: Dict[str, Any]) -> Tuple[str, bool]:
return random.choice(item_type_deep_copy), False
return item_type, False

def __parse_anyOf(self, name: str, path: str, schema: Dict[str, Any]) -> AnyOf:
def __parse_anyOf(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> AnyOf:
model = AnyOf(name=name, path=path, max_recursive_depth=self.max_recursive_depth, **schema)
root = model if root is None else root
schemas = []
for d in schema["anyOf"]:
schemas.append(self.__parse_definition(name, path, d))
return AnyOf(name=name, path=path, schemas=schemas, **schema)
schemas.append(self.__parse_definition(name, path, d, root=root))
model.schemas = schemas
return model

def __parse_allOf(self, name: str, path: str, schema: Dict[str, Any]) -> AllOf:
def __parse_allOf(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> AllOf:
combined_schema = dict(ChainMap(*schema["allOf"]))
return AllOf(
name=name,
path=path,
combined_schema=self.__parse_definition(name, path, combined_schema),
**schema,
)
model = AllOf(name=name, path=path, max_recursive_depth=self.max_recursive_depth, **schema)
root = model if root is None else root
model.combined_schema = self.__parse_definition(name, path, combined_schema, root=root)
return model

def __parse_oneOf(self, name: str, path: str, schema: Dict[str, Any]) -> OneOf:
def __parse_oneOf(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> OneOf:
model = OneOf(name=name, path=path, max_recursive_depth=self.max_recursive_depth, **schema)
root = model if root is None else root
schemas = []
for d in schema["oneOf"]:
schemas.append(self.__parse_definition(name, path, d))
return OneOf(name=name, path=path, schemas=schemas, **schema)
schemas.append(self.__parse_definition(name, path, d, root=root))
model.schemas = schemas
return model

def __parse_named_definition(self, def_name: str) -> AllTypes:
def __parse_named_definition(self, path: str, def_name: str, root) -> AllTypes:
schema = self.root_schema
parsed_definition = None
for def_tag in ("definitions", "$defs"):
for name, definition in schema.get(def_tag, {}).items():
if name == def_name:
parsed_definition = self.__parse_definition(
name, path=f"#/{def_tag}", schema=definition
)
self.definitions[f"#/{def_tag}/{name}"] = parsed_definition
if path.startswith(f"#/{def_tag}/{def_name}"):
root.is_recursive = True
return root
definition = schema.get(def_tag, {}).get(def_name)
if definition is not None:
parsed_definition = self.__parse_definition(
def_name, path=f"{path}/#/{def_tag}/{def_name}", schema=definition, root=root
)
self.definitions[f"#/{def_tag}/{def_name}"] = parsed_definition
return parsed_definition

def __parse_definition(self, name: str, path: str, schema: Dict[str, Any]) -> AllTypes:
def __parse_definition(
self, name: str, path: str, schema: Dict[str, Any], root: Optional[AllTypes] = None
) -> AllTypes:
self.base_state["__all_json_paths__"].append(path)
item_type, is_nullable = self.__is_field_nullable(schema)
if "const" in schema:
Expand All @@ -232,25 +272,26 @@ def __parse_definition(self, name: str, path: str, schema: Dict[str, Any]) -> Al
"path": path,
"is_nullable": is_nullable,
"allow_none_optionals": self.allow_none_optionals,
"max_recursive_depth": self.max_recursive_depth,
**schema,
}
)
elif "type" in schema:
if item_type == "object" and "properties" in schema:
return self.__parse_object(name, path, schema)
return self.__parse_object(name, path, schema, root)
elif item_type == "object" and "anyOf" in schema:
return self.__parse_anyOf(name, path, schema)
return self.__parse_anyOf(name, path, schema, root)
elif item_type == "object" and "allOf" in schema:
return self.__parse_allOf(name, path, schema)
return self.__parse_allOf(name, path, schema, root)
elif item_type == "object" and "oneOf" in schema:
return self.__parse_oneOf(name, path, schema)
return self.__parse_oneOf(name, path, schema, root)
elif item_type == "array":
if (schema.get("contains") is not None) or isinstance(schema.get("items"), dict):
return self.__parse_array(name, path, schema)
return self.__parse_array(name, path, schema, root)
if isinstance(schema.get("items"), list) and all(
isinstance(x, dict) for x in schema.get("items", [])
):
return self.__parse_tuple(name, path, schema)
return self.__parse_tuple(name, path, schema, root)
else:
return self.__parse_primitive(name, path, schema)
elif "$ref" in schema:
Expand All @@ -261,28 +302,33 @@ def __parse_definition(self, name: str, path: str, schema: Dict[str, Any]) -> Al
else:
# parse referenced definition
ref_name = frag.split("/")[-1]
cls = self.__parse_named_definition(ref_name)
cls = self.__parse_named_definition(path, ref_name, root)
else:
with s_open(ext, "r") as f:
external_jsf = JSF(json.load(f))
cls = deepcopy(external_jsf.definitions.get(f"#{frag}"))
cls.name = name
cls.path = path
if path != "#" and cls == root:
cls.name = name
elif path != "#":
cls.name = name
cls.path = path
return cls
elif "anyOf" in schema:
return self.__parse_anyOf(name, path, schema)
return self.__parse_anyOf(name, path, schema, root)
elif "allOf" in schema:
return self.__parse_allOf(name, path, schema)
return self.__parse_allOf(name, path, schema, root)
elif "oneOf" in schema:
return self.__parse_oneOf(name, path, schema)
return self.__parse_oneOf(name, path, schema, root)
else:
raise ValueError(f"Cannot parse schema {repr(schema)}") # pragma: no cover

def _parse(self, schema: Dict[str, Any]) -> AllTypes:
for def_tag in ("definitions", "$defs"):
for name, definition in schema.get(def_tag, {}).items():
if f"#/{def_tag}/{name}" not in self.definitions:
item = self.__parse_definition(name, path=f"#/{def_tag}", schema=definition)
item = self.__parse_definition(
name, path=f"#/{def_tag}/{name}", schema=definition
)
self.definitions[f"#/{def_tag}/{name}"] = item

self.root = self.__parse_definition(name="root", path="#", schema=schema)
Expand Down
7 changes: 6 additions & 1 deletion jsf/schema_types/_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ def generate(self, context: Dict[str, Any]) -> Optional[List[Tuple]]:
try:
return super().generate(context)
except ProviderNotSetException:
return tuple(item.generate(context) for item in self.items)
depth = context["state"]["__depth__"]
output = []
for item in self.items:
output.append(item.generate(context))
context["state"]["__depth__"] = depth
return tuple(output)

def model(self, context: Dict[str, Any]) -> Tuple[Type, Any]:
_type = eval(
Expand Down
5 changes: 4 additions & 1 deletion jsf/schema_types/anyof.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
try:
return super().generate(context)
except ProviderNotSetException:
return random.choice(self.schemas).generate(context)
filtered_schemas = []
if context["state"]["__depth__"] > self.max_recursive_depth:
filtered_schemas = [schema for schema in self.schemas if not schema.is_recursive]

Check warning on line 20 in jsf/schema_types/anyof.py

View check run for this annotation

Codecov / codecov/patch

jsf/schema_types/anyof.py#L20

Added line #L20 was not covered by tests
return random.choice(filtered_schemas or self.schemas).generate(context)

def model(self, context: Dict[str, Any]) -> None:
pass
11 changes: 7 additions & 4 deletions jsf/schema_types/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,22 @@ def generate(self, context: Dict[str, Any]) -> Optional[List[Any]]:
elif isinstance(self.fixed, int):
self.minItems = self.maxItems = self.fixed

output = [
self.items.generate(context)
for _ in range(random.randint(int(self.minItems), int(self.maxItems)))
]
depth = context["state"]["__depth__"]
output = []
for _ in range(random.randint(int(self.minItems), int(self.maxItems))):
output.append(self.items.generate(context))
context["state"]["__depth__"] = depth
if self.uniqueItems and self.items.type == "object":
output = [dict(s) for s in {frozenset(d.items()) for d in output}]
while len(output) < self.minItems:
output.append(self.items.generate(context))
output = [dict(s) for s in {frozenset(d.items()) for d in output}]
context["state"]["__depth__"] = depth
elif self.uniqueItems:
output = set(output)
while len(output) < self.minItems:
output.add(self.items.generate(context))
context["state"]["__depth__"] = depth
output = list(output)
return output

Expand Down
10 changes: 9 additions & 1 deletion jsf/schema_types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,25 @@ class BaseSchema(BaseModel):
provider: Optional[str] = Field(None, alias="$provider")
set_state: Optional[Dict[str, str]] = Field(None, alias="$state")
is_nullable: bool = False
is_recursive: bool = False
allow_none_optionals: float = Field(0.5, ge=0.0, le=1.0)
max_recursive_depth: int = 10

@classmethod
def from_dict(cls, d: Dict[str, Any]) -> Self:
raise NotImplementedError # pragma: no cover

def generate(self, context: Dict[str, Any]) -> Any:
if self.is_recursive:
context["state"]["__depth__"] += 1

if self.set_state is not None:
context["state"][self.path] = {k: eval(v, context)() for k, v in self.set_state.items()}

if self.is_nullable and random.uniform(0, 1) < self.allow_none_optionals:
if self.is_nullable and (
random.uniform(0, 1) < self.allow_none_optionals
or context["state"]["__depth__"] > self.max_recursive_depth
):
return None
if self.provider is not None:
return eval(self.provider, context)()
Expand Down
Loading
Loading