-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Features/125 modf #402
Features/125 modf #402
Changes from 15 commits
ec6c508
f70fb61
b3d1c71
71b8832
828cb50
9a8d6f3
4ae4be5
2e57691
b045297
3d23b9d
b99d531
11aa1f1
bbe33a6
e123b3c
0ce6890
f25a448
5d088b9
db796cc
dac92bd
efcefbd
46b001c
4c9e548
f456d7b
2921012
e4ed000
1ecdeba
6a82e9c
342e661
0f3d6e4
7ad1fb3
e5aa64e
387d1f4
0af3803
9ef2614
865b868
d0613dc
bdc5be2
30464c3
c5810a6
3ecd850
c383c78
7f5e34b
567c2e4
05b319c
58ac9db
e85e5c3
9356e81
8bd1ae7
cfa1f74
2b52855
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
import torch | ||
import heat as ht | ||
|
||
from . import operations | ||
from . import dndarray | ||
from . import types | ||
|
||
__all__ = ["abs", "absolute", "ceil", "clip", "fabs", "floor", "trunc"] | ||
__all__ = ["abs", "absolute", "ceil", "clip", "fabs", "floor", "modf", "round", "trunc"] | ||
|
||
|
||
def abs(x, out=None, dtype=None): | ||
|
@@ -178,6 +179,109 @@ def floor(x, out=None): | |
return operations.__local_op(torch.floor, x, out) | ||
|
||
|
||
def modf(x, out=None):
    """
    Return the fractional and integral parts of a tensor, element-wise.

    The fractional and integral parts are negative if the given number is negative.

    Parameters
    ----------
    x : ht.DNDarray
        Input tensor
    out : tuple(ht.DNDarray, ht.DNDarray), optional
        A location into which the result is stored. If provided, it must have a shape that the
        inputs broadcast to. If not provided or None, freshly-allocated tensors are returned.

    Returns
    -------
    tuple(fractionalParts, integralParts)

    fractionalParts : ht.DNDarray
        Fractional part of x. This is a scalar if x is a scalar.

    integralParts : ht.DNDarray
        Integral part of x. This is a scalar if x is a scalar.

    Examples
    --------
    >>> ht.modf(ht.arange(-2.0, 2.0, 0.4))
    (tensor([ 0.0000, -0.6000, -0.2000, -0.8000, -0.4000,  0.0000,  0.4000,  0.8000,  0.2000,  0.6000]),
     tensor([-2., -1., -1., -0., -0.,  0.,  0.,  0.,  1.,  1.]))
    """
    # Use the module-level trunc (and dndarray.DNDarray below) instead of going
    # through the ht namespace, which would create a cyclic import.
    integralParts = trunc(x)
    fractionalParts = x - integralParts

    if out is not None:
        if not isinstance(out, tuple):
            raise TypeError(
                "expected out to be None or a tuple of ht.DNDarray, but was {}".format(type(out))
            )
        if len(out) != 2:
            raise ValueError(
                "expected out to be a tuple of length 2, but was of length {}".format(len(out))
            )
        if (not isinstance(out[0], dndarray.DNDarray)) or (
            not isinstance(out[1], dndarray.DNDarray)
        ):
            raise TypeError(
                "expected out to be None or a tuple of ht.DNDarray, but was ({}, {})".format(
                    type(out[0]), type(out[1])
                )
            )
        # Write results into the provided output tensors in place.
        out[0]._DNDarray__array = fractionalParts._DNDarray__array
        out[1]._DNDarray__array = integralParts._DNDarray__array
        return out

    return (fractionalParts, integralParts)
|
||
|
||
def round(x, decimals=0, out=None, dtype=None):
    """
    Calculate the rounded value element-wise.

    Parameters
    ----------
    x : ht.DNDarray
        The values for which to compute the rounded value.
    decimals : int, optional
        Number of decimal places to round to (default: 0).
        If decimals is negative, it specifies the number of positions to the left of the decimal point.
    out : ht.DNDarray, optional
        A location into which the result is stored. If provided, it must have a shape that the
        inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
    dtype : ht.type, optional
        Determines the data type of the output array. The values are cast to this type with
        potential loss of precision.

    Returns
    -------
    rounded_values : ht.DNDarray
        A tensor containing the rounded value of each element in x.

    Examples
    --------
    >>> ht.round(ht.arange(-2.0, 2.0, 0.4))
    tensor([-2., -2., -1., -1., -0., 0., 0., 1., 1., 2.])
    """
    if dtype is not None and not issubclass(dtype, types.generic):
        raise TypeError("dtype must be a heat data type")

    if decimals != 0:
        # Scale out-of-place: the original in-place `x *= ...` mutated the
        # caller's input tensor as a side effect of rounding.
        x = x * (10 ** decimals)

    rounded_values = operations.__local_op(torch.round, x, out)

    if decimals != 0:
        # Undo the scaling on the freshly computed result only.
        rounded_values /= 10 ** decimals

    if dtype is not None:
        # Cast the underlying torch tensor and record the new heat dtype.
        rounded_values._DNDarray__array = rounded_values._DNDarray__array.type(dtype.torch_type())
        rounded_values._DNDarray__dtype = dtype

    return rounded_values
|
||
|
||
def trunc(x, out=None): | ||
""" | ||
Return the trunc of the input, element-wise. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -162,6 +162,70 @@ def test_floor(self): | |
with self.assertRaises(TypeError): | ||
ht.floor(object()) | ||
|
||
def test_modf(self):
    """Test modf against numpy's reference implementation for float32/float64, plus error cases."""
    start, end, step = -5.0, 5.0, 1.4
    comparison = np.modf(np.arange(start, end, step, np.float32))

    # modf of float32
    float32_tensor = ht.arange(start, end, step, dtype=ht.float32)
    float32_modf = float32_tensor.modf()
    self.assertIsInstance(float32_modf[0], ht.DNDarray)
    self.assertIsInstance(float32_modf[1], ht.DNDarray)
    self.assertEqual(float32_modf[0].dtype, ht.float32)
    self.assertEqual(float32_modf[1].dtype, ht.float32)
    # The original generator-expression comparisons were always truthy and could
    # never fail; compare element-wise via numpy instead.
    self.assertTrue(np.allclose(float32_modf[0]._DNDarray__array.numpy(), comparison[0]))
    self.assertTrue(np.allclose(float32_modf[1]._DNDarray__array.numpy(), comparison[1]))

    # modf of float64
    comparison = np.modf(np.arange(start, end, step, np.float64))

    float64_tensor = ht.arange(start, end, step, dtype=ht.float64)
    float64_modf = float64_tensor.modf()
    self.assertIsInstance(float64_modf[0], ht.DNDarray)
    self.assertIsInstance(float64_modf[1], ht.DNDarray)
    self.assertEqual(float64_modf[0].dtype, ht.float64)
    self.assertEqual(float64_modf[1].dtype, ht.float64)
    # Bug fix: the original checked float32_modf here instead of float64_modf.
    self.assertTrue(np.allclose(float64_modf[0]._DNDarray__array.numpy(), comparison[0]))
    self.assertTrue(np.allclose(float64_modf[1]._DNDarray__array.numpy(), comparison[1]))

    # check exceptions
    with self.assertRaises(TypeError):
        ht.modf([0, 1, 2, 3])
    with self.assertRaises(TypeError):
        ht.modf(object())
    with self.assertRaises(TypeError):
        ht.modf(float32_tensor, 1)
    with self.assertRaises(ValueError):
        ht.modf(float32_tensor, (float32_tensor, float32_tensor, float64_tensor))
    with self.assertRaises(TypeError):
        ht.modf(float32_tensor, (float32_tensor, 2))
|
||
def test_round(self):
    """Test round against torch's reference implementation for float32/float64, plus error cases."""
    start, end, step = -5.0, 5.0, 1.4
    comparison = torch.arange(start, end, step, dtype=torch.float64).round()

    # round of float32
    float32_tensor = ht.arange(start, end, step, dtype=ht.float32)
    float32_round = float32_tensor.round()
    self.assertIsInstance(float32_round, ht.DNDarray)
    # (original duplicated this dtype assertion twice; once is sufficient)
    self.assertEqual(float32_round.dtype, ht.float32)
    self.assertTrue((float32_round._DNDarray__array == comparison.float()).all())

    # round of float64
    float64_tensor = ht.arange(start, end, step, dtype=ht.float64)
    float64_round = float64_tensor.round()
    self.assertIsInstance(float64_round, ht.DNDarray)
    self.assertEqual(float64_round.dtype, ht.float64)
    self.assertTrue((float64_round._DNDarray__array == comparison).all())

    # check exceptions
    with self.assertRaises(TypeError):
        ht.round([0, 1, 2, 3])
    with self.assertRaises(TypeError):
        ht.round(object())
|
||
def test_trunc(self): | ||
base_array = np.random.randn(20) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cyclic import