Skip to content

Commit

Permalink
Merge pull request #18 from kengz/refactor
Browse files Browse the repository at this point in the history
refactor: use simpler BaseModel for Sequential, compact
  • Loading branch information
kengz authored Jan 10, 2025
2 parents 81fa559 + 5ddfad0 commit ff8a47a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "torcharc"
version = "2.1.1"
version = "2.1.2"
description = "Build PyTorch models by specifying architectures."
readme = "README.md"
requires-python = ">=3.12"
Expand Down
40 changes: 11 additions & 29 deletions torcharc/validator/modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Pydantic validation for modules spec
from pydantic import BaseModel, Field, RootModel, field_validator
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator
from torch import nn


Expand Down Expand Up @@ -47,7 +47,7 @@ def build(self) -> nn.Module:
return cls(**kwargs)


class SequentialSpec(RootModel):
class SequentialSpec(BaseModel):
"""
Sequential spec where key = Sequential and value = list of NNSpecs.
E.g.
Expand All @@ -61,7 +61,9 @@ class SequentialSpec(RootModel):
out_features: 10
"""

root: dict[str, list[NNSpec]] = Field(
model_config = ConfigDict(extra="forbid")

Sequential: list[NNSpec] = Field(
description="Sequential module spec where value is a list of NNSpec.",
examples=[
{
Expand All @@ -74,20 +76,9 @@ class SequentialSpec(RootModel):
],
)

@field_validator("root", mode="before")
def is_single_key_dict(value: dict) -> dict:
return NNSpec.is_single_key_dict(value)

@field_validator("root", mode="before")
def key_is_sequential(value: dict) -> dict:
assert (
next(iter(value)) == "Sequential"
), "Key must be 'Sequential' if using SequentialSpec."
return value

def build(self) -> nn.Sequential:
"""Build nn.Sequential from sequential spec."""
nn_specs = next(iter(self.root.values()))
nn_specs = self.Sequential
return nn.Sequential(*[nn_spec.build() for nn_spec in nn_specs])


Expand Down Expand Up @@ -138,7 +129,7 @@ class CompactValueSpec(BaseModel):
)


class CompactSpec(RootModel):
class CompactSpec(BaseModel):
"""
Higher level compact spec that expands into Sequential spec. This is useful for architecture search.
Compact spec has the format:
Expand Down Expand Up @@ -174,7 +165,9 @@ class CompactSpec(RootModel):
p: 0.1
"""

root: dict[str, CompactValueSpec] = Field(
model_config = ConfigDict(extra="forbid")

compact: CompactValueSpec = Field(
description="Higher level compact spec that expands into Sequential spec.",
examples=[
{
Expand All @@ -201,17 +194,6 @@ class CompactSpec(RootModel):
],
)

@field_validator("root", mode="before")
def is_single_key_dict(value: dict) -> dict:
return NNSpec.is_single_key_dict(value)

@field_validator("root", mode="before")
def key_is_compact(value: dict) -> dict:
assert (
next(iter(value)) == "compact"
), "Key must be 'compact' if using CompactSpec."
return value

def __expand_spec(self, compact_layer: dict) -> list[dict]:
class_name = compact_layer["type"]
keys = compact_layer["keys"]
Expand All @@ -223,7 +205,7 @@ def __expand_spec(self, compact_layer: dict) -> list[dict]:
return nn_specs

def expand_to_sequential_spec(self) -> SequentialSpec:
compact_spec = next(iter(self.root.values())).model_dump()
compact_spec = self.compact.model_dump()
prelayer = compact_spec.get("prelayer")
postlayer = compact_spec.get("postlayer")
nn_specs = []
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ff8a47a

Please sign in to comment.