Commit b9e3611
Fixes exponentiation support in auto_diff.
PTNobel committed Apr 19, 2020
1 parent 0efca0f commit b9e3611
Showing 2 changed files with 50 additions and 9 deletions.
33 changes: 33 additions & 0 deletions auto_diff/vecvalder.py
@@ -4,6 +4,11 @@

_HANDLED_FUNCS_AND_UFUNCS = {}

def _defer_to_val(f):
    def fn(self, *args, **kwargs):
        return getattr(self.val, f)(*args, **kwargs)
    fn.__name__ = f
    return fn

class VecValDer(np.lib.mixins.NDArrayOperatorsMixin):
    __slots__ = 'val', 'der'
@@ -25,6 +30,34 @@ def transpose(self, *axes):
            axes = None
        return np.transpose(self, axes)

    all = _defer_to_val('all')
    any = _defer_to_val('any')
    argmax = _defer_to_val('argmax')
    argmin = _defer_to_val('argmin')
    argpartition = _defer_to_val('argpartition')
    argsort = _defer_to_val('argsort')
    nonzero = _defer_to_val('nonzero')

    def copy(self):
        return VecValDer(self.val.copy(), self.der.copy())

    def fill(self, value):
        if isinstance(value, VecValDer):
            self.val.fill(value.val)
            self.der[:] = value.der
        else:
            self.val.fill(value)
            self.der.fill(0.0)

    def reshape(self, shape):
        # ndarray.reshape returns a new view rather than reshaping in place,
        # so build and return a new VecValDer instead of discarding the results.
        der_dim_shape = self.der.shape[len(self.val.shape):]
        new_der_shape = shape + der_dim_shape
        return VecValDer(self.val.reshape(shape),
                         self.der.reshape(new_der_shape))

    def trace(self, *args, **kwargs):
        # Pass self through to NumPy; omitting it would trace nothing.
        return np.trace(self, *args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        if method == '__call__' and ufunc in _HANDLED_FUNCS_AND_UFUNCS:
            return _HANDLED_FUNCS_AND_UFUNCS[ufunc](*args, **kwargs)
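
For orientation (not part of the commit): _defer_to_val generates methods whose results depend only on the primal values, forwarding each call to self.val. A minimal standalone sketch, with a hypothetical Wrapper class standing in for VecValDer:

import numpy as np

def _defer_to_val(f):
    # Build a method that forwards the named call to the wrapped value array.
    def fn(self, *args, **kwargs):
        return getattr(self.val, f)(*args, **kwargs)
    fn.__name__ = f
    return fn

class Wrapper:
    # Hypothetical stand-in for VecValDer, holding only the value part.
    def __init__(self, val):
        self.val = val
    argmax = _defer_to_val('argmax')
    any = _defer_to_val('any')

w = Wrapper(np.array([1.0, 5.0, 3.0]))
print(w.argmax())  # 1, identical to w.val.argmax()
print(w.any())     # True

Queries such as argmax, nonzero, and any return indices or booleans rather than differentiable values, so forwarding them to the value array discards no derivative information.
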
26 changes: 17 additions & 9 deletions auto_diff/vecvalder_funcs_and_ufuncs.py
@@ -43,7 +43,7 @@ def _chain_rule(f_x, dx, out=None):
    if out is None:
        out = np.ndarray(dx.shape)  # Uninitialized memory is fine because
        # we're about to overwrite each element. If we do compression of the for
        # loop in the future be sure to switch to np.zeros.
    for index, y in np.ndenumerate(f_x):
        out[index] = y * dx[index]
    return out
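
As a standalone illustration (shapes assumed, not taken from the diff): _chain_rule scales each gradient slice of dx by the matching entry of the pointwise derivative f'(x), where the trailing axes of dx index the independent variables.

import numpy as np

fp = np.array([2.0, 3.0])                # f'(x) at each element of x
dx = np.array([[1.0, 0.0], [0.0, 1.0]])  # derivative array; row i holds the
                                         # gradient of x[i]
out = np.ndarray(dx.shape)               # overwritten below, like the helper
for index, y in np.ndenumerate(fp):
    out[index] = y * dx[index]           # scales a whole gradient slice
print(out)  # [[2. 0.]
            #  [0. 3.]]
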
@@ -284,25 +284,33 @@ def true_divide(x1, x2, /, out):
        raise RuntimeError("This should not be occurring.")


# Tested
@register(np.float_power)
@_add_out_support
def float_power(x1, x2, /, out):
    if isinstance(x1, cls) and isinstance(x2, cls):
        return cls(np.float_power(x1.val, x2.val, out=out.val),
                   np.multiply(x1.val**(x2.val - 1),
                               x2.val * x1.der + x1.val * np.log(x1.val) * x2.der,
                               out=out.der))
    elif isinstance(x1, cls):
        return cls(np.float_power(x1.val, x2, out=out.val),
                   np.multiply(x1.val**(x2 - 1) * x2, x1.der, out=out.der))
    elif isinstance(x2, cls):
        # x1 is a plain array here, so take np.log of x1 itself, not x1.val.
        return cls(np.float_power(x1, x2.val, out=out.val),
                   np.multiply(x1**x2.val * np.log(x1), x2.der, out=out.der))
    else:
        raise RuntimeError("This should not be occurring.")
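
The derivative rule implemented above follows from d(u**v) = u**(v - 1) * (v*du + u*log(u)*dv). A quick finite-difference sanity check with illustrative values (not taken from the test suite):

import numpy as np

u, v = 2.0, 3.0
du, dv = 0.7, -0.2        # arbitrary tangent directions
analytic = u**(v - 1) * (v * du + u * np.log(u) * dv)
eps = 1e-7
numeric = ((u + eps * du)**(v + eps * dv) - u**v) / eps
assert abs(analytic - numeric) < 1e-4
print(analytic, numeric)  # both approximately 7.291
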


@register(np.power)
@_add_out_support
def power(x1, x2, /, out):
    if isinstance(x1, cls) and isinstance(x2, cls):
        return cls(np.power(x1.val, x2.val, out=out.val),
                   np.multiply(x1.val**(x2.val - 1),
                               x2.val * x1.der + x1.val * np.log(x1.val) * x2.der,
                               out=out.der))
    elif isinstance(x1, cls):
        return cls(np.power(x1.val, x2, out=out.val),
                   np.multiply(x1.val**(x2 - 1) * x2, x1.der, out=out.der))
    elif isinstance(x2, cls):
        # As in float_power, x1 is a plain array here, so use np.log(x1).
        return cls(np.power(x1, x2.val, out=out.val),
                   np.multiply(x1**x2.val * np.log(x1), x2.der, out=out.der))
    else:
        raise RuntimeError("This should not be occurring.")



# Partially Tested
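
A sketch of what the fix enables; the context-manager usage here is assumed from the project's README rather than shown in this diff:

import numpy as np
import auto_diff  # usage assumed from the project's README

x = np.array([[2.0], [3.0]])
with auto_diff.AutoDiff(x) as x:
    y = x ** np.array([[3.0], [2.0]])  # dispatches to the power ufunc above
    val, jac = auto_diff.get_value_and_jacobian(y)
print(val)  # [[8.], [9.]]
print(jac)  # diagonal Jacobian: d(x**3)/dx = 12 at x = 2, d(x**2)/dx = 6 at x = 3
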
