Skip to content

Commit

Permalink
support ctor np array from of tensor (#7970)
Browse files Browse the repository at this point in the history
* support ctor np array from of tensor

* add test case constructing np array from tensor

* refine

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
liufengwei0103 and mergify[bot] authored Apr 8, 2022
1 parent 3932e16 commit bc0a9b3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/oneflow/framework/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,7 @@ def RegisterMethods():
Tensor.__invert__ = _invert
Tensor.__float__ = _scalar_float
Tensor.__int__ = _scalar_int
Tensor.__array__ = _numpy
Tensor.uniform_ = _uniform
Tensor.trunc_normal_ = _trunc_normal_
Tensor.kaiming_uniform_ = _kaiming_uniform
Expand Down
9 changes: 9 additions & 0 deletions python/oneflow/test/tensor/test_tensor_part_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def test_construct_from_another_tensor(test_case):
test_case.assertEqual(output.dtype, flow.float32)
test_case.assertTrue(np.allclose(output.numpy(), np_arr))

@flow.unittest.skip_unless_1n1d()
def test_construct_np_array_from_tensor(test_case):
tensor = flow.randn(5)
np_arr = np.array(tensor)
test_case.assertEqual(np_arr.shape, (5,))
test_case.assertEqual(np_arr.dtype, np.float32)
test_case.assertTrue(np.allclose(np_arr, tensor.numpy()))
test_case.assertEqual(str(np_arr), str(tensor.numpy()))

@flow.unittest.skip_unless_1n1d()
@autotest(check_graph=True)
def test_tensor_sign_with_random_data(test_case):
Expand Down

0 comments on commit bc0a9b3

Please sign in to comment.