-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support torch tensors #7
Changes from 2 commits
cc6902d
91d7144
6391de4
dc1ac17
d424d22
a9c39c4
4d1e603
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,17 +105,21 @@ def to_content_buf(data): | |
else: | ||
raise ValueError("Expected a 3D ndarray (RGB/RGBA image) or 2D (grayscale image), " | ||
"but given shape: {}".format(im.shape)) | ||
return _get_bytes_from_numpy(im, mode) | ||
|
||
try: | ||
from PIL import Image | ||
except ImportError as e: | ||
raise ImportError(e.msg + | ||
"\nTo draw numpy arrays, we require Pillow. " + | ||
"(pip install Pillow)") # TODO; reraise | ||
|
||
with io.BytesIO() as buf: | ||
Image.fromarray(im, mode=mode).save(buf, format='png') | ||
return buf.getvalue() | ||
elif 'torch' in sys.modules and isinstance(data, sys.modules['torch'].Tensor): | ||
# numpy ndarray: convert to png | ||
im = data | ||
if im.shape[0] == 1: | ||
mode = 'L' # 8-bit pixels, grayscale | ||
im = im.mul(255).byte().squeeze().numpy() | ||
elif im.shape[0] == 3: | ||
mode = None # RGB/RGBA | ||
im = im.mul(255).byte().permute(1, 2, 0).numpy() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One should not assume anything about RGB/BGR order and datatype (uint8 or float). But why don't we just do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm, You are right. |
||
else: | ||
raise ValueError("Expected a 3D ndarray (RGB/RGBA image) or 2D (grayscale image), " | ||
"but given shape: {}".format(im.shape)) | ||
return _get_bytes_from_numpy(im, mode) | ||
|
||
elif 'PIL.Image' in sys.modules and isinstance(data, sys.modules['PIL.Image'].Image): | ||
# PIL/Pillow images | ||
|
@@ -288,5 +292,18 @@ def main(): | |
return 0 | ||
|
||
|
||
def _get_bytes_from_numpy(im, mode): | ||
try: | ||
from PIL import Image | ||
except ImportError as e: | ||
raise ImportError(e.msg + | ||
"\nTo draw numpy arrays, we require Pillow. " + | ||
"(pip install Pillow)") # TODO; reraise | ||
|
||
with io.BytesIO() as buf: | ||
Image.fromarray(im, mode=mode).save(buf, format='png') | ||
return buf.getvalue() | ||
|
||
|
||
if __name__ == '__main__': | ||
sys.exit(main()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit