Skip to content

Commit

Permalink
Remove use of field_names_values_metadata, etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Dec 8, 2021
1 parent e8db83e commit 8b99b35
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
31 changes: 20 additions & 11 deletions efax/_src/parametrization.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from dataclasses import fields
from functools import partial, reduce
from itertools import count
from typing import TYPE_CHECKING, Any, Dict, Iterable, Tuple, Type, TypeVar

import jax.numpy as jnp
from tjax import ComplexArray, RealArray, Shape, custom_jvp, jit
from tjax.dataclasses import field_names_values_metadata, fields

from .parameter import Support
from .tools import parameters_dot_product
Expand Down Expand Up @@ -96,15 +96,17 @@ def total_elements(dimensions: int) -> int:
return cls(**kwargs) # type: ignore

def fixed_parameters_mapping(self) -> Dict[str, Any]:
return {name: value
for name, value, metadata in field_names_values_metadata(self)
if metadata['fixed']}
return {field.name: getattr(self, field.name)
for field in fields(self)
if field.metadata['fixed']}

def parameters_value_support(self) -> Iterable[Tuple[ComplexArray, Support]]:
"""
Returns: The value and support of each variable parameter.
"""
for _, value, metadata in field_names_values_metadata(self):
for field in fields(self):
value = getattr(self, field.name)
metadata = field.metadata
if metadata['fixed']:
continue
support = metadata['support']
Expand All @@ -116,7 +118,10 @@ def parameters_name_value(self) -> Iterable[Tuple[str, ComplexArray]]:
"""
Returns: The name and value of each variable parameter.
"""
for name, value, metadata in field_names_values_metadata(self):
for field in fields(self):
name = field.name
value = getattr(self, name)
metadata = field.metadata
if metadata['fixed']:
continue
yield name, value
Expand All @@ -125,7 +130,10 @@ def parameters_name_value_support(self) -> Iterable[Tuple[str, ComplexArray, Sup
"""
Returns: The name, value, and support of each variable parameter.
"""
for name, value, metadata in field_names_values_metadata(self):
for field in fields(self):
name = field.name
value = getattr(self, name)
metadata = field.metadata
if metadata['fixed']:
continue
support = metadata['support']
Expand All @@ -138,13 +146,14 @@ def parameters_name_support(cls) -> Iterable[Tuple[str, Support]]:
"""
Returns: The name and support of each variable parameter.
"""
for this_field in fields(cls):
if this_field.metadata['fixed']:
for field in fields(cls):
metadata = field.metadata
if metadata['fixed']:
continue
support = this_field.metadata['support']
support = metadata['support']
if not isinstance(support, Support):
raise TypeError
yield this_field.name, support
yield field.name, support

# Abstract methods -----------------------------------------------------------------------------
@property
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = 'poetry.core.masonry.api'

[tool.poetry]
name = 'efax'
version = "1.4.8"
version = "1.4.10"
description = "Exponential families for JAX"
license = 'MIT'
authors = ['Neil Girdhar <mistersheik@gmail.com>']
Expand Down

0 comments on commit 8b99b35

Please sign in to comment.