Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torch tensors #7

Merged
merged 7 commits into from
Mar 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions imgcat/imgcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,20 @@ def to_content_buf(data):
Image.fromarray(im, mode=mode).save(buf, format='png')
return buf.getvalue()

elif 'torch' in sys.modules and isinstance(data, sys.modules['torch'].Tensor):
# pytorch tensor: convert to png
im = data
try:
from torchvision import transforms
except ImportError as e:
raise ImportError(e.msg +
"\nTo draw torch tensor, we require torchvision. " +
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you fix grammar issues here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is grammar issues?

"(pip install torchvision)")

with io.BytesIO() as buf:
transforms.ToPILImage()(im).save(buf, format='png')
return buf.getvalue()

elif 'PIL.Image' in sys.modules and isinstance(data, sys.modules['PIL.Image'].Image):
# PIL/Pillow images
img = data
Expand Down
15 changes: 15 additions & 0 deletions imgcat/test_imgcat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,21 @@ def test_numpy(self):
a[:, :, 0] = 255 # (255, 0, 0): red
imgcat(a)

@unittest.skipIf(sys.version_info < (3, 5), "Only in Python 3.5+")
def test_torch(self):
import torch

# uint8, grayscale
a = torch.ones([1, 32, 32], dtype=torch.uint8)
imgcat(a)

a = torch.ones([1, 32, 32], dtype=torch.float32)
imgcat(a)

# uint8, color image
a = torch.ones([3, 32, 32], dtype=torch.uint8) * 0
imgcat(a)

def test_matplotlib(self):
# plt
import matplotlib.pyplot as plt
Expand Down
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,20 @@ def read_version():
'pytest<5.0',
'numpy',
]
if sys.version_info >= (3, 5):
tests_requires += ['torch', 'torchvision']

if sys.version_info >= (3, 6):
tests_requires += ['matplotlib>=3.1', 'Pillow']
elif sys.version_info >= (3, 5):
tests_requires += ['matplotlib~=3.0.3', 'Pillow']
else: # <= Python 3.4
tests_requires += ['matplotlib<3.0', 'Pillow<6.0']

# pytorch: python 2.7 require future
if sys.version_info < (3, 0):
tests_requires += ['future']
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate on why future is needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python2.7 import error.
No module named builtins
for avoid this, Install future.

see.
https://discuss.pytorch.org/t/building-pytorch-from-source-in-conda-fails-in-pytorch-caffe2-operators-crash-op-cc/42859/3
hyperopt/hyperopt#273

root@1f2b2ea83553:~# pip install torch
DEPRECATION: Python 2.7 reached the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 is no longer maintained. A future version of pip will drop support for Python 2.7. More details about Python 2 support in pip, can be found at https://pip.pypa.io/en/latest/development/release-process/#python-2-support
Collecting torch
  Downloading torch-1.4.0-cp27-cp27mu-manylinux1_x86_64.whl (753.4 MB)
     |████████████████████████████████| 753.4 MB 2.5 kB/s
Installing collected packages: torch
Successfully installed torch-1.4.0
root@1f2b2ea83553:~# python
Python 2.7.17 (default, Feb 26 2020, 17:18:08)
[GCC 8.3.0] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/site-packages/torch/__init__.py", line 19, in <module>
    from ._six import string_classes as _string_classes
  File "/usr/local/lib/python2.7/site-packages/torch/_six.py", line 23, in <module>
    import builtins
ImportError: No module named builtins


__version__ = read_version()


Expand Down