diff --git a/xarray_jax/register_pytrees.py b/xarray_jax/register_pytrees.py index 17ae3f1..2349e58 100644 --- a/xarray_jax/register_pytrees.py +++ b/xarray_jax/register_pytrees.py @@ -6,9 +6,9 @@ import contextvars VarChangeFn = Callable[[xarray.Variable], xarray.Variable] -_VAR_CHANGE_ON_UNFLATTEN_FN: contextvars.ContextVar[VarChangeFn] = ( - contextvars.ContextVar("var_change_on_unflatten_fn") -) +_VAR_CHANGE_ON_UNFLATTEN_FN: contextvars.ContextVar[ + VarChangeFn +] = contextvars.ContextVar("var_change_on_unflatten_fn") @contextlib.contextmanager @@ -27,7 +27,7 @@ def _flatten_variable( children = (v._data,) aux = ( v._dims, - # Xarray will sometimes turn None into empty dictionaries. To maintain consistent tree structures, we convert empty dictionaries to None. + # Xarray will sometimes turn None into empty dictionaries. To maintain consistent tree structures, we convert None to empty dictionaries. # https://github.com/pydata/xarray/issues/9560 {} if v._attrs is None else v._attrs, )