diff --git a/pyproject.toml b/pyproject.toml index b97ec97..1db1b6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "torcharc" -version = "2.1.1" +version = "2.1.2" description = "Build PyTorch models by specifying architectures." readme = "README.md" requires-python = ">=3.12" diff --git a/torcharc/validator/modules.py b/torcharc/validator/modules.py index 95e1fa1..0c6023e 100644 --- a/torcharc/validator/modules.py +++ b/torcharc/validator/modules.py @@ -1,5 +1,5 @@ # Pydantic validation for modules spec -from pydantic import BaseModel, Field, RootModel, field_validator +from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator from torch import nn @@ -47,7 +47,7 @@ def build(self) -> nn.Module: return cls(**kwargs) -class SequentialSpec(RootModel): +class SequentialSpec(BaseModel): """ Sequential spec where key = Sequential and value = list of NNSpecs. E.g. @@ -61,7 +61,9 @@ class SequentialSpec(RootModel): out_features: 10 """ - root: dict[str, list[NNSpec]] = Field( + model_config = ConfigDict(extra="forbid") + + Sequential: list[NNSpec] = Field( description="Sequential module spec where value is a list of NNSpec.", examples=[ { @@ -74,20 +76,9 @@ class SequentialSpec(RootModel): ], ) - @field_validator("root", mode="before") - def is_single_key_dict(value: dict) -> dict: - return NNSpec.is_single_key_dict(value) - - @field_validator("root", mode="before") - def key_is_sequential(value: dict) -> dict: - assert ( - next(iter(value)) == "Sequential" - ), "Key must be 'Sequential' if using SequentialSpec." - return value - def build(self) -> nn.Sequential: """Build nn.Sequential from sequential spec.""" - nn_specs = next(iter(self.root.values())) + nn_specs = self.Sequential return nn.Sequential(*[nn_spec.build() for nn_spec in nn_specs]) @@ -138,7 +129,7 @@ class CompactValueSpec(BaseModel): ) -class CompactSpec(RootModel): +class CompactSpec(BaseModel): """ Higher level compact spec that expands into Sequential spec. This is useful for architecture search. Compact spec has the format: @@ -174,7 +165,9 @@ class CompactSpec(RootModel): p: 0.1 """ - root: dict[str, CompactValueSpec] = Field( + model_config = ConfigDict(extra="forbid") + + compact: CompactValueSpec = Field( description="Higher level compact spec that expands into Sequential spec.", examples=[ { @@ -201,17 +194,6 @@ class CompactSpec(RootModel): ], ) - @field_validator("root", mode="before") - def is_single_key_dict(value: dict) -> dict: - return NNSpec.is_single_key_dict(value) - - @field_validator("root", mode="before") - def key_is_compact(value: dict) -> dict: - assert ( - next(iter(value)) == "compact" - ), "Key must be 'compact' if using CompactSpec." - return value - def __expand_spec(self, compact_layer: dict) -> list[dict]: class_name = compact_layer["type"] keys = compact_layer["keys"] @@ -223,7 +205,7 @@ def __expand_spec(self, compact_layer: dict) -> list[dict]: return nn_specs def expand_to_sequential_spec(self) -> SequentialSpec: - compact_spec = next(iter(self.root.values())).model_dump() + compact_spec = self.compact.model_dump() prelayer = compact_spec.get("prelayer") postlayer = compact_spec.get("postlayer") nn_specs = [] diff --git a/uv.lock b/uv.lock index a394c96..4b3efd2 100644 --- a/uv.lock +++ b/uv.lock @@ -1354,7 +1354,7 @@ wheels = [ [[package]] name = "torcharc" -version = "2.1.1" +version = "2.1.2" source = { editable = "." } dependencies = [ { name = "pydantic" },