diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9d27d4358..d18db9b62 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 'v0.4.8' + rev: 'v0.4.9' hooks: - id: ruff files: "^datamodel_code_generator|^tests" diff --git a/README.md b/README.md index 7bbb143c9..d7d77dfdd 100644 --- a/README.md +++ b/README.md @@ -457,6 +457,9 @@ Model customization: --use-schema-description Use schema description to populate class docstring --use-title-as-name use titles as class names of models + --use-exact-imports Import exact types instead of modules, for example: + `from .foo import Bar` instead of + `from . import foo` with `foo.Bar` Template customization: --aliases ALIASES Alias mapping file diff --git a/datamodel_code_generator/__init__.py b/datamodel_code_generator/__init__.py index c883717e1..4a4f76a93 100644 --- a/datamodel_code_generator/__init__.py +++ b/datamodel_code_generator/__init__.py @@ -302,6 +302,7 @@ def generate( use_pendulum: bool = False, http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None, treat_dots_as_module: bool = False, + use_exact_imports: bool = False, ) -> None: remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict() if isinstance(input_, str): @@ -463,6 +464,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]: use_pendulum=use_pendulum, http_query_parameters=http_query_parameters, treat_dots_as_module=treat_dots_as_module, + use_exact_imports=use_exact_imports, **kwargs, ) diff --git a/datamodel_code_generator/__main__.py b/datamodel_code_generator/__main__.py index 27933e7b0..7bb4b1ea2 100644 --- a/datamodel_code_generator/__main__.py +++ b/datamodel_code_generator/__main__.py @@ -313,6 +313,7 @@ def validate_root(cls, values: Any) -> Any: use_pendulum: bool = False http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None treat_dot_as_module: bool = False + use_exact_imports: bool = False def merge_args(self, args: Namespace) -> None: set_args = { @@ -510,6 +511,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit: use_pendulum=config.use_pendulum, http_query_parameters=config.http_query_parameters, treat_dots_as_module=config.treat_dot_as_module, + use_exact_imports=config.use_exact_imports, ) return Exit.OK except InvalidClassNameError as e: diff --git a/datamodel_code_generator/arguments.py b/datamodel_code_generator/arguments.py index 562bc126c..83dbc6074 100644 --- a/datamodel_code_generator/arguments.py +++ b/datamodel_code_generator/arguments.py @@ -184,6 +184,13 @@ def start_section(self, heading: Optional[str]) -> None: action='store_true', default=False, ) +model_options.add_argument( + '--use-exact-imports', + help='import exact types instead of modules, for example: "from .foo import Bar" instead of ' + '"from . import foo" with "foo.Bar"', + action='store_true', + default=False, +) # ====================================================================================== # Typing options for generated models diff --git a/datamodel_code_generator/imports.py b/datamodel_code_generator/imports.py index bef598ad7..0b4564dba 100644 --- a/datamodel_code_generator/imports.py +++ b/datamodel_code_generator/imports.py @@ -26,11 +26,12 @@ class Imports(DefaultDict[Optional[str], Set[str]]): def __str__(self) -> str: return self.dump() - def __init__(self) -> None: + def __init__(self, use_exact: bool = False) -> None: super().__init__(set) self.alias: DefaultDict[Optional[str], Dict[str, str]] = defaultdict(dict) self.counter: Dict[Tuple[Optional[str], str], int] = defaultdict(int) self.reference_paths: Dict[str, Import] = {} + self.use_exact: bool = use_exact def _set_alias(self, from_: Optional[str], imports: Set[str]) -> List[str]: return [ diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index b28d59949..5544bb0f5 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -238,6 +238,14 @@ def relative(current_module: str, reference: str) -> Tuple[str, str]: return left, right +def exact_import(from_: str, import_: str, short_name: str) -> Tuple[str, str]: + if from_ == '.': + # Prevents "from . import foo" becoming "from ..foo import Foo" + # when our imported module has the same parent + return f'.{import_}', short_name + return f'{from_}.{import_}', short_name + + @runtime_checkable class Child(Protocol): @property @@ -393,6 +401,7 @@ def __init__( use_pendulum: bool = False, http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None, treat_dots_as_module: bool = False, + use_exact_imports: bool = False, ) -> None: self.data_type_manager: DataTypeManager = data_type_manager_type( python_version=target_python_version, @@ -406,7 +415,8 @@ def __init__( self.data_model_root_type: Type[DataModel] = data_model_root_type self.data_model_field_type: Type[DataModelFieldBase] = data_model_field_type - self.imports: Imports = Imports() + self.imports: Imports = Imports(use_exact_imports) + self.use_exact_imports: bool = use_exact_imports self._append_additional_imports(additional_imports=additional_imports) self.base_class: Optional[str] = base_class @@ -701,6 +711,10 @@ def __change_from_import( from_, import_ = full_path = relative( model.module_name, data_type.full_name ) + if imports.use_exact: + from_, import_ = exact_import( + from_, import_, data_type.reference.short_name + ) import_ = import_.replace('-', '_') if ( len(model.module_path) > 1 @@ -1287,7 +1301,7 @@ class Processed(NamedTuple): processed_models: List[Processed] = [] for module, models in module_models: - imports = module_to_import[module] = Imports() + imports = module_to_import[module] = Imports(self.use_exact_imports) init = False if module: parent = (*module[:-1], '__init__.py') diff --git a/datamodel_code_generator/parser/graphql.py b/datamodel_code_generator/parser/graphql.py index 51cdf8a90..9eb55d6c8 100644 --- a/datamodel_code_generator/parser/graphql.py +++ b/datamodel_code_generator/parser/graphql.py @@ -159,6 +159,7 @@ def __init__( use_pendulum: bool = False, http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None, treat_dots_as_module: bool = False, + use_exact_imports: bool = False, ) -> None: super().__init__( source=source, @@ -227,6 +228,7 @@ def __init__( use_pendulum=use_pendulum, http_query_parameters=http_query_parameters, treat_dots_as_module=treat_dots_as_module, + use_exact_imports=use_exact_imports, ) self.data_model_scalar_type = data_model_scalar_type diff --git a/datamodel_code_generator/parser/jsonschema.py b/datamodel_code_generator/parser/jsonschema.py index 495afd07f..c849ad3c6 100644 --- a/datamodel_code_generator/parser/jsonschema.py +++ b/datamodel_code_generator/parser/jsonschema.py @@ -441,6 +441,7 @@ def __init__( use_pendulum: bool = False, http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None, treat_dots_as_module: bool = False, + use_exact_imports: bool = False, ) -> None: super().__init__( source=source, @@ -509,6 +510,7 @@ def __init__( use_pendulum=use_pendulum, http_query_parameters=http_query_parameters, treat_dots_as_module=treat_dots_as_module, + use_exact_imports=use_exact_imports, ) self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict() diff --git a/datamodel_code_generator/parser/openapi.py b/datamodel_code_generator/parser/openapi.py index 0cc59b05d..89ad80314 100644 --- a/datamodel_code_generator/parser/openapi.py +++ b/datamodel_code_generator/parser/openapi.py @@ -223,6 +223,7 @@ def __init__( use_pendulum: bool = False, http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None, treat_dots_as_module: bool = False, + use_exact_imports: bool = False, ): super().__init__( source=source, @@ -291,6 +292,7 @@ def __init__( use_pendulum=use_pendulum, http_query_parameters=http_query_parameters, treat_dots_as_module=treat_dots_as_module, + use_exact_imports=use_exact_imports, ) self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [ OpenAPIScope.Schemas diff --git a/tests/parser/test_base.py b/tests/parser/test_base.py index 1a980891b..77316a4ed 100644 --- a/tests/parser/test_base.py +++ b/tests/parser/test_base.py @@ -5,7 +5,12 @@ from datamodel_code_generator.model import DataModel, DataModelFieldBase from datamodel_code_generator.model.pydantic import BaseModel, DataModelField -from datamodel_code_generator.parser.base import Parser, relative, sort_data_models +from datamodel_code_generator.parser.base import ( + Parser, + exact_import, + relative, + sort_data_models, +) from datamodel_code_generator.reference import Reference, snake_to_upper_camel from datamodel_code_generator.types import DataType @@ -183,6 +188,19 @@ def test_relative(current_module: str, reference: str, val: Tuple[str, str]): assert relative(current_module, reference) == val +@pytest.mark.parametrize( + 'from_,import_,name,val', + [ + ('.', 'mod', 'Foo', ('.mod', 'Foo')), + ('.a', 'mod', 'Foo', ('.a.mod', 'Foo')), + ('..a', 'mod', 'Foo', ('..a.mod', 'Foo')), + ('..a.b', 'mod', 'Foo', ('..a.b.mod', 'Foo')), + ], +) +def test_exact_import(from_: str, import_: str, name: str, val: Tuple[str, str]): + assert exact_import(from_, import_, name) == val + + @pytest.mark.parametrize( 'word,expected', [