Skip to content

Commit

Permalink
Add missing properties to copied Function objects
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 20, 2022
1 parent 104dc03 commit 9ac18dc
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
2 changes: 2 additions & 0 deletions aesara/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,8 @@ def checkSV(sv_ori, sv_rpl):
f_cpy.finder[swap[in_ori.variable]] = container
in_cpy.variable = swap[in_ori.variable]

f_cpy.trust_input = self.trust_input
f_cpy.unpack_single = self.unpack_single
f_cpy.name = name
f_cpy.maker.fgraph.name = name
return f_cpy
Expand Down
44 changes: 26 additions & 18 deletions tests/compile/function/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def t():
t()

def test_copy(self):
a = scalar() # the a is for 'anonymous' (un-named).
a = scalar()
x, s = scalars("xs")

f = function(
Expand All @@ -312,26 +312,34 @@ def test_copy(self):
)

g = copy.copy(f)
# if they both return, assume that they return equivalent things.

assert f.unpack_single == g.unpack_single
assert f.trust_input == g.trust_input

assert g.container[x].storage is not f.container[x].storage
assert g.container[a].storage is not f.container[a].storage
assert g.container[s].storage is not f.container[s].storage

assert g.value[a] is f.value[a] # should not have been copied
assert (
g.value[s] is not f.value[s]
) # should have been copied because it is mutable.
assert not (g.value[s] != f.value[s]).any() # its contents should be identical
# Should not have been copied
assert g.value[a] is f.value[a]

assert f(2, 1) == g(
2
) # they should be in sync, default value should be copied.
assert f(2, 1) == g(
2
) # they should be in sync, default value should be copied.
f(1, 2) # put them out of sync
assert f(1, 2) != g(1, 2) # they should not be equal anymore.
# Should have been copied because it is mutable
assert g.value[s] is not f.value[s]

# Their contents should be equal, though
assert np.array_equal(g.value[s], f.value[s])

# They should be in sync, default value should be copied
assert np.array_equal(f(2, 1), g(2))

# They should be in sync, default value should be copied
assert np.array_equal(f(2, 1), g(2))

# Put them out of sync
f(1, 2)

# They should not be equal anymore
assert not np.array_equal(f(1, 2), g(1, 2))

def test_copy_share_memory(self):
x = fscalar("x")
Expand Down Expand Up @@ -478,9 +486,9 @@ def test_copy_delete_updates(self):
ori = function([x], out, mode=mode, updates={z: z * 2})
cpy = ori.copy(delete_updates=True)

assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
assert cpy(1) == 4
assert cpy(1) == 4
assert cpy(1) == 4

# Test if unused implicit and explicit inputs from delete_updates
# are ignored as intended.
Expand Down

0 comments on commit 9ac18dc

Please sign in to comment.