Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 9, 2024
1 parent 3c002e0 commit 9e833c4
Show file tree
Hide file tree
Showing 16 changed files with 409 additions and 400 deletions.
191 changes: 95 additions & 96 deletions datamodel_code_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
import os
import sys
import yaml
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
Expand All @@ -28,8 +27,10 @@
)
from urllib.parse import ParseResult

import yaml

import datamodel_code_generator.pydantic_patch # noqa: F401
from datamodel_code_generator.format import PythonVersion, DatetimeClassType
from datamodel_code_generator.format import DatetimeClassType, PythonVersion
from datamodel_code_generator.model.pydantic_v2 import UnionMode
from datamodel_code_generator.parser import DefaultPutDict, LiteralType
from datamodel_code_generator.parser.base import Parser
Expand Down Expand Up @@ -59,8 +60,7 @@ def load_yaml_from_path(path: Path, encoding: str) -> Any:

if TYPE_CHECKING:

def get_version() -> str:
...
def get_version() -> str: ...

else:

Expand All @@ -82,15 +82,15 @@ def enable_debug_message() -> None: # pragma: no cover


def snooper_to_methods( # type: ignore
output=None,
watch=(),
watch_explode=(),
depth=1,
prefix='',
overwrite=False,
thread_info=False,
custom_repr=(),
max_variable_length=100,
output=None,
watch=(),
watch_explode=(),
depth=1,
prefix='',
overwrite=False,
thread_info=False,
custom_repr=(),
max_variable_length=100,
) -> Callable[..., Any]:
def inner(cls: Type[T]) -> Type[T]:
if not pysnooper:
Expand Down Expand Up @@ -147,18 +147,18 @@ def is_schema(text: str) -> bool:
return False
schema = data.get('$schema')
if isinstance(schema, str) and any(
schema.startswith(u) for u in JSON_SCHEMA_URLS
schema.startswith(u) for u in JSON_SCHEMA_URLS
): # pragma: no cover
return True
if isinstance(data.get('type'), str):
return True
if any(
isinstance(data.get(o), list)
for o in (
'allOf',
'anyOf',
'oneOf',
)
isinstance(data.get(o), list)
for o in (
'allOf',
'anyOf',
'oneOf',
)
):
return True
if isinstance(data.get('properties'), dict):
Expand Down Expand Up @@ -231,77 +231,77 @@ def get_first_file(path: Path) -> Path: # pragma: no cover


def generate(
input_: Union[Path, str, ParseResult],
*,
input_filename: Optional[str] = None,
input_file_type: InputFileType = InputFileType.Auto,
output: Optional[Path] = None,
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
target_python_version: PythonVersion = PythonVersion.PY_38,
base_class: str = '',
additional_imports: Optional[List[str]] = None,
custom_template_dir: Optional[Path] = None,
extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
validation: bool = False,
field_constraints: bool = False,
snake_case_field: bool = False,
strip_default_none: bool = False,
aliases: Optional[Mapping[str, str]] = None,
disable_timestamp: bool = False,
enable_version_header: bool = False,
allow_population_by_field_name: bool = False,
allow_extra_fields: bool = False,
apply_default_values_for_required_fields: bool = False,
force_optional_for_required_fields: bool = False,
class_name: Optional[str] = None,
use_standard_collections: bool = False,
use_schema_description: bool = False,
use_field_description: bool = False,
use_default_kwarg: bool = False,
reuse_model: bool = False,
encoding: str = 'utf-8',
enum_field_as_literal: Optional[LiteralType] = None,
use_one_literal_as_default: bool = False,
set_default_enum_member: bool = False,
use_subclass_enum: bool = False,
strict_nullable: bool = False,
use_generic_container_types: bool = False,
enable_faux_immutability: bool = False,
disable_appending_item_suffix: bool = False,
strict_types: Optional[Sequence[StrictTypes]] = None,
empty_enum_field_name: Optional[str] = None,
custom_class_name_generator: Optional[Callable[[str], str]] = None,
field_extra_keys: Optional[Set[str]] = None,
field_include_all_keys: bool = False,
field_extra_keys_without_x_prefix: Optional[Set[str]] = None,
openapi_scopes: Optional[List[OpenAPIScope]] = None,
graphql_scopes: Optional[List[GraphQLScope]] = None,
wrap_string_literal: Optional[bool] = None,
use_title_as_name: bool = False,
use_operation_id_as_name: bool = False,
use_unique_items_as_set: bool = False,
http_headers: Optional[Sequence[Tuple[str, str]]] = None,
http_ignore_tls: bool = False,
use_annotated: bool = False,
use_non_positive_negative_number_constrained_types: bool = False,
original_field_name_delimiter: Optional[str] = None,
use_double_quotes: bool = False,
use_union_operator: bool = False,
collapse_root_models: bool = False,
special_field_name_prefix: Optional[str] = None,
remove_special_field_name_prefix: bool = False,
capitalise_enum_members: bool = False,
keep_model_order: bool = False,
custom_file_header: Optional[str] = None,
custom_file_header_path: Optional[Path] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
use_pendulum: bool = False,
http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None,
treat_dots_as_module: bool = False,
use_exact_imports: bool = False,
union_mode: Optional[UnionMode] = None,
output_datetime_class: DataModelType = DatetimeClassType.Datetime
input_: Union[Path, str, ParseResult],
*,
input_filename: Optional[str] = None,
input_file_type: InputFileType = InputFileType.Auto,
output: Optional[Path] = None,
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
target_python_version: PythonVersion = PythonVersion.PY_38,
base_class: str = '',
additional_imports: Optional[List[str]] = None,
custom_template_dir: Optional[Path] = None,
extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
validation: bool = False,
field_constraints: bool = False,
snake_case_field: bool = False,
strip_default_none: bool = False,
aliases: Optional[Mapping[str, str]] = None,
disable_timestamp: bool = False,
enable_version_header: bool = False,
allow_population_by_field_name: bool = False,
allow_extra_fields: bool = False,
apply_default_values_for_required_fields: bool = False,
force_optional_for_required_fields: bool = False,
class_name: Optional[str] = None,
use_standard_collections: bool = False,
use_schema_description: bool = False,
use_field_description: bool = False,
use_default_kwarg: bool = False,
reuse_model: bool = False,
encoding: str = 'utf-8',
enum_field_as_literal: Optional[LiteralType] = None,
use_one_literal_as_default: bool = False,
set_default_enum_member: bool = False,
use_subclass_enum: bool = False,
strict_nullable: bool = False,
use_generic_container_types: bool = False,
enable_faux_immutability: bool = False,
disable_appending_item_suffix: bool = False,
strict_types: Optional[Sequence[StrictTypes]] = None,
empty_enum_field_name: Optional[str] = None,
custom_class_name_generator: Optional[Callable[[str], str]] = None,
field_extra_keys: Optional[Set[str]] = None,
field_include_all_keys: bool = False,
field_extra_keys_without_x_prefix: Optional[Set[str]] = None,
openapi_scopes: Optional[List[OpenAPIScope]] = None,
graphql_scopes: Optional[List[GraphQLScope]] = None,
wrap_string_literal: Optional[bool] = None,
use_title_as_name: bool = False,
use_operation_id_as_name: bool = False,
use_unique_items_as_set: bool = False,
http_headers: Optional[Sequence[Tuple[str, str]]] = None,
http_ignore_tls: bool = False,
use_annotated: bool = False,
use_non_positive_negative_number_constrained_types: bool = False,
original_field_name_delimiter: Optional[str] = None,
use_double_quotes: bool = False,
use_union_operator: bool = False,
collapse_root_models: bool = False,
special_field_name_prefix: Optional[str] = None,
remove_special_field_name_prefix: bool = False,
capitalise_enum_members: bool = False,
keep_model_order: bool = False,
custom_file_header: Optional[str] = None,
custom_file_header_path: Optional[Path] = None,
custom_formatters: Optional[List[str]] = None,
custom_formatters_kwargs: Optional[Dict[str, Any]] = None,
use_pendulum: bool = False,
http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None,
treat_dots_as_module: bool = False,
use_exact_imports: bool = False,
union_mode: Optional[UnionMode] = None,
output_datetime_class: DataModelType = DatetimeClassType.Datetime,
) -> None:
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
if isinstance(input_, str):
Expand Down Expand Up @@ -361,8 +361,7 @@ def generate(

def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]:
csv_reader = csv.DictReader(csv_file)
return dict(
zip(csv_reader.fieldnames, next(csv_reader))) # type: ignore
return dict(zip(csv_reader.fieldnames, next(csv_reader))) # type: ignore

if isinstance(input_, Path):
with input_.open(encoding=encoding) as f:
Expand Down Expand Up @@ -398,11 +397,11 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]:
else:
default_field_extras = None


from datamodel_code_generator.model import get_data_model_types

data_model_types = get_data_model_types(output_model_type, target_python_version,
output_datetime_class)
data_model_types = get_data_model_types(
output_model_type, target_python_version, output_datetime_class
)
parser = parser_class(
source=input_text or input_,
data_model_type=data_model_types.data_model,
Expand Down
33 changes: 16 additions & 17 deletions datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from __future__ import annotations

import argcomplete
import black
import json
import signal
import sys
Expand All @@ -16,7 +14,6 @@
from enum import IntEnum
from io import TextIOBase
from pathlib import Path
from pydantic import BaseModel
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -32,6 +29,10 @@
)
from urllib.parse import ParseResult, urlparse

import argcomplete
import black
from pydantic import BaseModel

from datamodel_code_generator.model.pydantic_v2 import UnionMode

if TYPE_CHECKING:
Expand All @@ -50,10 +51,10 @@
)
from datamodel_code_generator.arguments import DEFAULT_ENCODING, arg_parser, namespace
from datamodel_code_generator.format import (
DatetimeClassType,
PythonVersion,
black_find_project_root,
is_supported_in_black,
DatetimeClassType
)
from datamodel_code_generator.parser import LiteralType
from datamodel_code_generator.reference import is_url
Expand Down Expand Up @@ -96,8 +97,7 @@ def __getitem__(self, item: str) -> Any:
if TYPE_CHECKING:

@classmethod
def get_fields(cls) -> Dict[str, Any]:
...
def get_fields(cls) -> Dict[str, Any]: ...

else:

Expand All @@ -117,6 +117,7 @@ class Config:
arbitrary_types_allowed = (TextIOBase,)

if not TYPE_CHECKING:

@classmethod
def get_fields(cls) -> Dict[str, Any]:
return cls.__fields__
Expand Down Expand Up @@ -153,7 +154,7 @@ def validate_url(cls, value: Any) -> Optional[ParseResult]:

@model_validator(mode='after')
def validate_use_generic_container_types(
cls, values: Dict[str, Any]
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
if values.get('use_generic_container_types'):
target_python_version: PythonVersion = values['target_python_version']
Expand All @@ -166,7 +167,7 @@ def validate_use_generic_container_types(

@model_validator(mode='after')
def validate_original_field_name_delimiter(
cls, values: Dict[str, Any]
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
if values.get('original_field_name_delimiter') is not None:
if not values.get('snake_case_field'):
Expand All @@ -189,8 +190,7 @@ def validate_http_headers(cls, value: Any) -> Optional[List[Tuple[str, str]]]:
def validate_each_item(each_item: Any) -> Tuple[str, str]:
if isinstance(each_item, str): # pragma: no cover
try:
field_name, field_value = each_item.split(':',
maxsplit=1) # type: str, str
field_name, field_value = each_item.split(':', maxsplit=1) # type: str, str
return field_name, field_value.lstrip()
except ValueError:
raise Error(f'Invalid http header: {each_item!r}')
Expand All @@ -202,13 +202,12 @@ def validate_each_item(each_item: Any) -> Tuple[str, str]:

@field_validator('http_query_parameters', mode='before')
def validate_http_query_parameters(
cls, value: Any
cls, value: Any
) -> Optional[List[Tuple[str, str]]]:
def validate_each_item(each_item: Any) -> Tuple[str, str]:
if isinstance(each_item, str): # pragma: no cover
try:
field_name, field_value = each_item.split('=',
maxsplit=1) # type: str, str
field_name, field_value = each_item.split('=', maxsplit=1) # type: str, str
return field_name, field_value.lstrip()
except ValueError:
raise Error(f'Invalid http query parameter: {each_item!r}')
Expand Down Expand Up @@ -417,7 +416,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
print(f'Unable to load alias mapping: {e}', file=sys.stderr)
return Exit.ERROR
if not isinstance(aliases, dict) or not all(
isinstance(k, str) and isinstance(v, str) for k, v in aliases.items()
isinstance(k, str) and isinstance(v, str) for k, v in aliases.items()
):
print(
'Alias mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
Expand All @@ -438,8 +437,8 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
)
return Exit.ERROR
if not isinstance(custom_formatters_kwargs, dict) or not all(
isinstance(k, str) and isinstance(v, str)
for k, v in custom_formatters_kwargs.items()
isinstance(k, str) and isinstance(v, str)
for k, v in custom_formatters_kwargs.items()
): # pragma: no cover
print(
'Custom formatters kwargs mapping must be a JSON string mapping (e.g. {"from": "to", ...})',
Expand Down Expand Up @@ -515,7 +514,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
treat_dots_as_module=config.treat_dot_as_module,
use_exact_imports=config.use_exact_imports,
union_mode=config.union_mode,
output_datetime_class=config.output_datetime_class
output_datetime_class=config.output_datetime_class,
)
return Exit.OK
except InvalidClassNameError as e:
Expand Down
Loading

0 comments on commit 9e833c4

Please sign in to comment.