Skip to content

Commit

Permalink
Update lock and fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Dec 7, 2024
1 parent dd67372 commit a3e8868
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 170 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ repos:
- id: check-yaml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
rev: v0.8.2
hooks:
- id: ruff

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.389
rev: v1.1.390
hooks:
- id: pyright

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ dev-dependencies = [
'pylint >= 3.3.1',
'pyright >= 0.0.13',
'pytest >= 8',
'ruff >= 0.7',
'ruff >= 0.8.1',
]

[tool.isort]
Expand Down
6 changes: 3 additions & 3 deletions tjax/_src/cotangent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def _cotangent_combinator_bwd(f: Callable[..., tuple[XT, Y]],
strict=True)):
scaled_y_bar = tree.map(lambda y_bar_i, scale=aux_cotangent_scale: y_bar_i * scale,
y_bar)
this_xs_bar = cast(XT, (xs_zero[:i]
+ (x_bar,)
+ xs_zero[i + 1:]))
this_xs_bar = cast('XT', (xs_zero[:i]
+ (x_bar,)
+ xs_zero[i + 1:]))
this_result_bar = (this_xs_bar, scaled_y_bar)
this_args_bar = f_vjp(this_result_bar)
all_args_bar.append(this_args_bar)
Expand Down
2 changes: 1 addition & 1 deletion tjax/_src/dataclasses/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def f(x: type[Any], /) -> type[TDataclassInstance]:

# Apply dataclass function to cls.
data_clz: type[TDataclassInstance] = cast(
type[TDataclassInstance],
'type[TDataclassInstance]',
dataclasses.dataclass(init=init, repr=repr, eq=eq, order=order, frozen=frozen)(cls))

# Partition fields into static, and dynamic; and assign these to the class.
Expand Down
6 changes: 3 additions & 3 deletions tjax/_src/graph/register_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def flatten_with_keys(graph: nx.Graph[Any], /
return (zip(keys, values, strict=True), keys)

flatten_with_keys_ = cast(
Callable[[Graph], tuple[Iterable[tuple[Hashable, Any]], Hashable]],
"Callable[[Graph], tuple[Iterable[tuple[Hashable, Any]], Hashable]]",
flatten_with_keys)
unflatten_tree_ = cast(Callable[[Hashable, Any], Graph], unflatten_tree)
flatten_tree_ = cast(Callable[[Graph], tuple[Iterable[Any], Hashable]],
unflatten_tree_ = cast("Callable[[Hashable, Any], Graph]", unflatten_tree)
flatten_tree_ = cast("Callable[[Graph], tuple[Iterable[Any], Hashable]]",
flatten_tree)
register_pytree_with_keys(graph_type, flatten_with_keys_, unflatten_tree_, flatten_tree_)
2 changes: 1 addition & 1 deletion tjax/_src/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def tree_unflatten(cls,
if not isinstance(dynamic_kwargs, dict):
raise RuntimeError # noqa: TRY004

dynamic_kwargs = cast(dict[str, Any], dynamic_kwargs)
dynamic_kwargs = cast('dict[str, Any]', dynamic_kwargs)
args = cls._unpartition_args(static_argnums, static_args, dynamic_args,
callable_is_static=callable_is_static)

Expand Down
328 changes: 169 additions & 159 deletions uv.lock

Large diffs are not rendered by default.

0 comments on commit a3e8868

Please sign in to comment.