Pass dtype directly to zeros_like
ricardoV94 committed Oct 8, 2024
1 parent 935ce79 commit aa616e6
Showing 4 changed files with 28 additions and 28 deletions.
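The pattern is the same in every hunk below: a zero gradient for a discrete (or boolean) input used to be built with `zeros_like()` and then cast with `.astype(config.floatX)`, and now the target dtype is passed to `zeros_like` directly. A minimal sketch of the before/after, using a hypothetical integer vector that is not part of the commit:

import pytensor.tensor as pt
from pytensor import config

# Hypothetical discrete-dtype input, standing in for the x/y/cond variables in the diff.
x = pt.ivector("x")

# Old pattern: create integer zeros, then cast them to floatX (zeros_like followed by a cast).
old_zero_grad = x.zeros_like().astype(config.floatX)

# New pattern from this commit: request the floatX dtype directly from zeros_like.
new_zero_grad = x.zeros_like(dtype=config.floatX)

# Both variables end up with the same dtype; the new form just skips the separate cast.
assert old_zero_grad.dtype == new_zero_grad.dtype == config.floatX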
2 changes: 1 addition & 1 deletion pytensor/ifelse.py
@@ -273,7 +273,7 @@ def grad(self, ins, grads):
# `condition` does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
-condition_grad = condition.zeros_like().astype(config.floatX)
+condition_grad = condition.zeros_like(dtype=config.floatX)

return [
condition_grad,
48 changes: 24 additions & 24 deletions pytensor/scalar/basic.py
@@ -1323,8 +1323,8 @@ def L_op(self, inputs, outputs, output_gradients):
x, y = inputs
assert outputs[0].type == bool
return [
-x.zeros_like().astype(config.floatX),
-y.zeros_like().astype(config.floatX),
+x.zeros_like(dtype=config.floatX),
+y.zeros_like(dtype=config.floatX),
]

def c_code_cache_version(self):
@@ -1358,7 +1358,7 @@ def output_types(self, *input_dtypes):
def L_op(self, inputs, outputs, output_gradients):
(x,) = inputs
assert outputs[0].type == bool
-return [x.zeros_like().astype(config.floatX)]
+return [x.zeros_like(dtype=config.floatX)]

def c_code_cache_version(self):
super_version = super().c_code_cache_version()
@@ -1577,7 +1577,7 @@ def get_grad(self, elem):
)
raise NotImplementedError(msg)
elif elem.type in discrete_types:
-return elem.zeros_like().astype(config.floatX)
+return elem.zeros_like(dtype=config.floatX)
else:
return elem.zeros_like()

@@ -1611,13 +1611,13 @@ def L_op(self, inputs, outputs, gout):
second_part = switch(cond, 0.0, gz)

if outputs[0].type in discrete_types:
-first_part = ift.zeros_like(config.floatX)
-second_part = iff.zeros_like(config.floatX)
+first_part = ift.zeros_like(dtype=config.floatX)
+second_part = iff.zeros_like(dtype=config.floatX)

# cond does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
-condition_grad = cond.zeros_like().astype(config.floatX)
+condition_grad = cond.zeros_like(dtype=config.floatX)

return (condition_grad, first_part, second_part)

@@ -1644,7 +1644,7 @@ def output_types(self, *input_types):
return upcast_out(*input_types[0])

def grad(self, inputs, output_gradients):
-return [inputs[0].zeros_like().astype(config.floatX)]
+return [inputs[0].zeros_like(dtype=config.floatX)]


class BinaryBitOp(BinaryScalarOp):
@@ -1664,8 +1664,8 @@ def output_types(self, *input_types):
def grad(self, inputs, output_gradients):
a, b = inputs
return [
-a.zeros_like().astype(config.floatX),
-b.zeros_like().astype(config.floatX),
+a.zeros_like(dtype=config.floatX),
+b.zeros_like(dtype=config.floatX),
]


@@ -1776,8 +1776,8 @@ def L_op(self, inputs, outputs, gout):

if outputs[0].type in discrete_types:
return [
-x.zeros_like().astype(config.floatX),
-y.zeros_like().astype(config.floatX),
+x.zeros_like(dtype=config.floatX),
+y.zeros_like(dtype=config.floatX),
]
# This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0.
@@ -1818,8 +1818,8 @@ def L_op(self, inputs, outputs, gout):

if outputs[0].type in discrete_types:
return [
-x.zeros_like().astype(config.floatX),
-y.zeros_like().astype(config.floatX),
+x.zeros_like(dtype=config.floatX),
+y.zeros_like(dtype=config.floatX),
]
# This form handle the case when both value are the same.
# In that case, gx will be gz, gy will be 0.
@@ -1861,7 +1861,7 @@ def L_op(self, inputs, outputs, gout):
retval = []
for ii, inp in enumerate(inputs):
if hasattr(inp, "zeros_like"):
-retval.append(inp.zeros_like().astype(config.floatX))
+retval.append(inp.zeros_like(dtype=config.floatX))
else:
retval.append(grad_undefined(self, ii, inp))
else:
@@ -1937,7 +1937,7 @@ def grad(self, inputs, gout):
)

if output_type in discrete_types:
-return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
+return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]

for input in inputs:
if gz.type in complex_types:
@@ -1980,8 +1980,8 @@ def L_op(self, inputs, outputs, gout):
raise NotImplementedError()
if outputs[0].type in discrete_types:
return [
-x.zeros_like().astype(config.floatX),
-y.zeros_like().astype(config.floatX),
+x.zeros_like(dtype=config.floatX),
+y.zeros_like(dtype=config.floatX),
]

first_part = gz
@@ -2293,8 +2293,8 @@ def L_op(self, inputs, outputs, gout):

if outputs[0].type in discrete_types:
return [
-x.zeros_like().astype(config.floatX),
-y.zeros_like().astype(config.floatX),
+x.zeros_like(dtype=config.floatX),
+y.zeros_like(dtype=config.floatX),
]

first_part = gz * y * x ** (y - 1)
@@ -2385,7 +2385,7 @@ def L_op(self, inputs, outputs, gout):

def handle_int(v):
if outputs[0].type in int_types:
-return v.zeros_like().astype(config.floatX)
+return v.zeros_like(dtype=config.floatX)
return v

return list(map(handle_int, [gx, gmn, gmx]))
@@ -2422,7 +2422,7 @@ def grad(self, inputs, gout):
# to deal with real-valued inputs by rounding them to the
# nearest integer. f(x+eps) thus equals f(x) so the gradient
# is zero, not disconnected or undefined
-return DisconnectedType()(), y.zeros_like()
+return DisconnectedType()(), y.zeros_like(dtype=config.floatX)


second = Second(transfer_type(1), name="second")
@@ -2494,7 +2494,7 @@ def grad(self, inputs, gout):
if self.o_type in continuous_types:
return [gz]
else:
-return [x.zeros_like().astype(config.floatX)]
+return [x.zeros_like(dtype=config.floatX)]

def c_code_cache_version(self):
s = super().c_code_cache_version()
@@ -2715,7 +2715,7 @@ def impl(self, x):
def grad(self, inputs, gout):
(x,) = inputs
(gz,) = gout
-return [x.zeros_like().astype(config.floatX)]
+return [x.zeros_like(dtype=config.floatX)]

def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
4 changes: 2 additions & 2 deletions pytensor/tensor/basic.py
@@ -589,7 +589,7 @@ def grad(self, inp, grads):
# Currently, pytensor.grad insists that the dtype of the returned
# gradient has a float dtype, so we use floatX.
if s.type.dtype in discrete_dtypes:
-return [s.zeros_like().astype(config.floatX)]
+return [s.zeros_like(dtype=config.floatX)]

raise NotImplementedError("grad not implemented for complex dtypes")

@@ -1876,7 +1876,7 @@ def infer_shape(self, fgraph, node, ishapes):
def grad(self, inputs, output_gradients):
# If the output is of an integer dtype, no gradient shall pass
if self.dtype in discrete_dtypes:
-return [ipt.zeros_like().astype(config.floatX) for ipt in inputs]
+return [ipt.zeros_like(dtype=config.floatX) for ipt in inputs]

grads = [output_gradients[0][i] for i in range(len(inputs))]
return grads
2 changes: 1 addition & 1 deletion pytensor/tensor/subtensor.py
@@ -946,7 +946,7 @@ def grad(self, inputs, grads):
x = inputs[0]
rest = inputs[1:]
if x.dtype in discrete_dtypes:
-first = x.zeros_like().astype(config.floatX)
+first = x.zeros_like(dtype=config.floatX)
else:
# For best optimization, we let this as an inc.
# This allow the opt local_IncSubtensor_serialize to apply first.