Skip to content
This repository has been archived by the owner on Nov 7, 2024. It is now read-only.

Add power function and test to backend #846

Merged
merged 2 commits into from
Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tensornetwork/backends/abstract_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,3 +988,26 @@ def deserialize_tensor(self, s: str) -> Tensor:
raise NotImplementedError(
"Backend '{}' has not implemented deserialize_tensor.".format(
self.name))

def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor:
"""
Returns the exponentiation of tensor a raised to b.
If b is a tensor, then the exponentiation is element-wise
between the two tensors, with a as the base and b as the power.
Note that a and b must be broadcastable to the same shape if
b is a tensor.
If b is a scalar, then the exponentiation is each value in a
raised to the power of b.

Args:
a: The tensor containing the bases.
alewis marked this conversation as resolved.
Show resolved Hide resolved
b: The tensor containing the powers; or a single scalar as the power.

Returns:
The tensor that is each element of a raised to the
power of b. Note that the shape of the returned tensor
is that produced by the broadcast of a and b.
"""
raise NotImplementedError(
f"Backend {self.name} has not implemented power.")

5 changes: 5 additions & 0 deletions tensornetwork/backends/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,8 @@ def test_pivot_not_implemented():
backend = AbstractBackend()
with pytest.raises(NotImplementedError):
backend.pivot(np.ones((2, 2)))

def test_power_not_implemented():
backend = AbstractBackend()
with pytest.raises(NotImplementedError):
backend.power(np.array([1, 2]), np.array([1, 2]))