Skip to content

Commit

Permalink
Merge branch 'main' into feat/dots-in-path
Browse files Browse the repository at this point in the history
  • Loading branch information
luca-knaack-webcom authored Jun 18, 2024
2 parents deee870 + 28be37d commit 17c356a
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.4.8'
rev: 'v0.4.9'
hooks:
- id: ruff
files: "^datamodel_code_generator|^tests"
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,9 @@ Model customization:
--use-schema-description
Use schema description to populate class docstring
--use-title-as-name use titles as class names of models
--use-exact-imports Import exact types instead of modules, for example:
`from .foo import Bar` instead of
`from . import foo` with `foo.Bar`

Template customization:
--aliases ALIASES Alias mapping file
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def generate(
use_pendulum: bool = False,
http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None,
treat_dots_as_module: bool = False,
use_exact_imports: bool = False,
) -> None:
remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict()
if isinstance(input_, str):
Expand Down Expand Up @@ -463,6 +464,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]:
use_pendulum=use_pendulum,
http_query_parameters=http_query_parameters,
treat_dots_as_module=treat_dots_as_module,
use_exact_imports=use_exact_imports,
**kwargs,
)

Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def validate_root(cls, values: Any) -> Any:
use_pendulum: bool = False
http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None
treat_dot_as_module: bool = False
use_exact_imports: bool = False

def merge_args(self, args: Namespace) -> None:
set_args = {
Expand Down Expand Up @@ -510,6 +511,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
use_pendulum=config.use_pendulum,
http_query_parameters=config.http_query_parameters,
treat_dots_as_module=config.treat_dot_as_module,
use_exact_imports=config.use_exact_imports,
)
return Exit.OK
except InvalidClassNameError as e:
Expand Down
7 changes: 7 additions & 0 deletions datamodel_code_generator/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ def start_section(self, heading: Optional[str]) -> None:
action='store_true',
default=False,
)
model_options.add_argument(
'--use-exact-imports',
help='import exact types instead of modules, for example: "from .foo import Bar" instead of '
'"from . import foo" with "foo.Bar"',
action='store_true',
default=False,
)

# ======================================================================================
# Typing options for generated models
Expand Down
3 changes: 2 additions & 1 deletion datamodel_code_generator/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ class Imports(DefaultDict[Optional[str], Set[str]]):
def __str__(self) -> str:
return self.dump()

def __init__(self) -> None:
def __init__(self, use_exact: bool = False) -> None:
super().__init__(set)
self.alias: DefaultDict[Optional[str], Dict[str, str]] = defaultdict(dict)
self.counter: Dict[Tuple[Optional[str], str], int] = defaultdict(int)
self.reference_paths: Dict[str, Import] = {}
self.use_exact: bool = use_exact

def _set_alias(self, from_: Optional[str], imports: Set[str]) -> List[str]:
return [
Expand Down
18 changes: 16 additions & 2 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,14 @@ def relative(current_module: str, reference: str) -> Tuple[str, str]:
return left, right


def exact_import(from_: str, import_: str, short_name: str) -> Tuple[str, str]:
if from_ == '.':
# Prevents "from . import foo" becoming "from ..foo import Foo"
# when our imported module has the same parent
return f'.{import_}', short_name
return f'{from_}.{import_}', short_name


@runtime_checkable
class Child(Protocol):
@property
Expand Down Expand Up @@ -393,6 +401,7 @@ def __init__(
use_pendulum: bool = False,
http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None,
treat_dots_as_module: bool = False,
use_exact_imports: bool = False,
) -> None:
self.data_type_manager: DataTypeManager = data_type_manager_type(
python_version=target_python_version,
Expand All @@ -406,7 +415,8 @@ def __init__(
self.data_model_root_type: Type[DataModel] = data_model_root_type
self.data_model_field_type: Type[DataModelFieldBase] = data_model_field_type

self.imports: Imports = Imports()
self.imports: Imports = Imports(use_exact_imports)
self.use_exact_imports: bool = use_exact_imports
self._append_additional_imports(additional_imports=additional_imports)

self.base_class: Optional[str] = base_class
Expand Down Expand Up @@ -701,6 +711,10 @@ def __change_from_import(
from_, import_ = full_path = relative(
model.module_name, data_type.full_name
)
if imports.use_exact:
from_, import_ = exact_import(
from_, import_, data_type.reference.short_name
)
import_ = import_.replace('-', '_')
if (
len(model.module_path) > 1
Expand Down Expand Up @@ -1287,7 +1301,7 @@ class Processed(NamedTuple):
processed_models: List[Processed] = []

for module, models in module_models:
imports = module_to_import[module] = Imports()
imports = module_to_import[module] = Imports(self.use_exact_imports)
init = False
if module:
parent = (*module[:-1], '__init__.py')
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/parser/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
use_pendulum: bool = False,
http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None,
treat_dots_as_module: bool = False,
use_exact_imports: bool = False,
) -> None:
super().__init__(
source=source,
Expand Down Expand Up @@ -227,6 +228,7 @@ def __init__(
use_pendulum=use_pendulum,
http_query_parameters=http_query_parameters,
treat_dots_as_module=treat_dots_as_module,
use_exact_imports=use_exact_imports,
)

self.data_model_scalar_type = data_model_scalar_type
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ def __init__(
use_pendulum: bool = False,
http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None,
treat_dots_as_module: bool = False,
use_exact_imports: bool = False,
) -> None:
super().__init__(
source=source,
Expand Down Expand Up @@ -509,6 +510,7 @@ def __init__(
use_pendulum=use_pendulum,
http_query_parameters=http_query_parameters,
treat_dots_as_module=treat_dots_as_module,
use_exact_imports=use_exact_imports,
)

self.remote_object_cache: DefaultPutDict[str, Dict[str, Any]] = DefaultPutDict()
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(
use_pendulum: bool = False,
http_query_parameters: Optional[Sequence[Tuple[str, str]]] = None,
treat_dots_as_module: bool = False,
use_exact_imports: bool = False,
):
super().__init__(
source=source,
Expand Down Expand Up @@ -291,6 +292,7 @@ def __init__(
use_pendulum=use_pendulum,
http_query_parameters=http_query_parameters,
treat_dots_as_module=treat_dots_as_module,
use_exact_imports=use_exact_imports,
)
self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [
OpenAPIScope.Schemas
Expand Down
20 changes: 19 additions & 1 deletion tests/parser/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.pydantic import BaseModel, DataModelField
from datamodel_code_generator.parser.base import Parser, relative, sort_data_models
from datamodel_code_generator.parser.base import (
Parser,
exact_import,
relative,
sort_data_models,
)
from datamodel_code_generator.reference import Reference, snake_to_upper_camel
from datamodel_code_generator.types import DataType

Expand Down Expand Up @@ -183,6 +188,19 @@ def test_relative(current_module: str, reference: str, val: Tuple[str, str]):
assert relative(current_module, reference) == val


@pytest.mark.parametrize(
'from_,import_,name,val',
[
('.', 'mod', 'Foo', ('.mod', 'Foo')),
('.a', 'mod', 'Foo', ('.a.mod', 'Foo')),
('..a', 'mod', 'Foo', ('..a.mod', 'Foo')),
('..a.b', 'mod', 'Foo', ('..a.b.mod', 'Foo')),
],
)
def test_exact_import(from_: str, import_: str, name: str, val: Tuple[str, str]):
assert exact_import(from_, import_, name) == val


@pytest.mark.parametrize(
'word,expected',
[
Expand Down

0 comments on commit 17c356a

Please sign in to comment.