Skip to content

Commit

Permalink
fixes inverse for thetas > 0
Browse files Browse the repository at this point in the history
  • Loading branch information
MArpogaus committed Oct 2, 2023
1 parent ff84814 commit 35cd580
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/bernstein_flow/bijectors/bernstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ def root_search_fn(objective_fn, _, max_iterations=None):
iteration,
) = tfp.math.find_root_chandrupatla(
objective_fn,
low=self.z_min,
high=self.z_max,
low=tf.convert_to_tensor(0, dtype=dtype),
high=tf.convert_to_tensor(1, dtype=dtype),
position_tolerance=1e-6,
# value_tolerance=1e-7,
max_iterations=max_iterations,
Expand All @@ -184,9 +184,5 @@ def _forward_log_det_jacobian(self, y):
ldj = tf.math.log(dz_dy)
return reshape_out(batch_shape, sample_shape, ldj)

def inverse(self, z):
y = super().inverse(z)
return tf.clip_by_value(y, self.clip_inverse, 1.0 - self.clip_inverse)

def _is_increasing(self, **kwargs):
return tf.reduce_all(self.thetas[..., 1:] >= self.thetas[..., :-1])

2 comments on commit 35cd580

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Old Faithful

Learning Curve

Metrics

Min of loss: -0.7146248817443848

Parameter Vector

a1 = array([6.957123], dtype=float32)
b1 = array([-0.5075588], dtype=float32)
thetas = array([-3.4682438 , -0.2741778 , -0.24878922, -0.2481467 , -0.24785   ,
   -0.24761055, -0.24731891, -0.24650416,  2.5771663 ], dtype=float32)
a2 = array([1.7397425], dtype=float32)

Results

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bimodal Model

Learning Curve

Learning Curve

Metrics

loss: -0.8930790424346924
val_loss: -0.9371266961097717

Results

Parameter Vector for x = 1

BernsteinFlow:
invert_chain_of_bpoly_of_scale1_of_shift1:
chain_of_bpoly_of_scale1_of_shift1:
bpoly: [-3.0000067e+00 -2.1729879e+00 -1.3459691e+00 -1.3355074e+00
-1.8175721e-02 -6.7576133e-03 -6.7430581e-03 -6.7285029e-03
-6.7139477e-03 -6.6993926e-03 -6.6848374e-03 -6.6702822e-03
-6.6557270e-03 -6.6411719e-03 -6.6266167e-03 -6.6120615e-03
-2.5624121e-03 6.3137040e+00 1.2629971e+01]
scale1: 0.47052833437919617
shift1: 0.672807514667511

Flow



Bijector


Please sign in to comment.