diff --git a/examples/ex5_dissipative_Lindblad.py b/examples/ex5_dissipative_Lindblad.py index b17cd28..5a9cd6d 100644 --- a/examples/ex5_dissipative_Lindblad.py +++ b/examples/ex5_dissipative_Lindblad.py @@ -56,8 +56,8 @@ def norm_fun(v, df=lambda x: x): biases = jnp.log(prob_dist[1:]) params = copy_dict(psi._param_unflatten(psi.get_parameters())) -params["params"]["outputDense"]["bias"] = biases -params["params"]["outputDense"]["kernel"] = 1e-15 * params["params"]["outputDense"]["kernel"] +params["outputDense"]["bias"] = biases +params["outputDense"]["kernel"] = 1e-15 * params["outputDense"]["kernel"] params = jnp.concatenate([p.ravel() for p in jax.tree_util.tree_flatten(params)[0]]) psi.set_parameters(params) diff --git a/examples/ex6_dissipative_Lindblad_2D.py b/examples/ex6_dissipative_Lindblad_2D.py index abd5a31..6c36c45 100644 --- a/examples/ex6_dissipative_Lindblad_2D.py +++ b/examples/ex6_dissipative_Lindblad_2D.py @@ -60,8 +60,8 @@ def xy_to_id(x, y, L): biases = jnp.log(prob_dist[1:]) params = copy_dict(psi._param_unflatten(psi.get_parameters())) -params["params"]["outputDense"]["bias"] = biases -params["params"]["outputDense"]["kernel"] = 1e-15 * params["params"]["outputDense"]["kernel"] +params["outputDense"]["bias"] = biases +params["outputDense"]["kernel"] = 1e-15 * params["outputDense"]["kernel"] params = jnp.concatenate([p.ravel() for p in jax.tree_util.tree_flatten(params)[0]]) psi.set_parameters(params)