Skip to content

Commit

Permalink
refactor(encoder): refactoring torch encoders
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed Apr 5, 2020
1 parent ee2f658 commit 01a830b
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 56 deletions.
2 changes: 0 additions & 2 deletions jina/executors/encoders/image/paddlehub.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def __init__(self,
``densenet264_imagenet``, ``densenet201_imagenet``, ``densenet169_imagenet``, ``densenet161_imagenet``,
``densenet121_imagenet``, ``darknet53_imagenet``,
``alexnet_imagenet``,
# ``pnasnet_imagenet``,
# ``nasnet_imagenet``
"""
Expand Down
30 changes: 12 additions & 18 deletions jina/executors/encoders/image/torchvision.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import numpy as np

from .. import BaseImageEncoder
from ..torchvision import TorchEncoder


class TorchImageEncoder(BaseImageEncoder):
class ImageTorchEncoder(TorchEncoder):
"""
:class:`TorchImageEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
:class:`ImageTorchEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Internally, :class:`TorchImageEncoder` wraps the models from `torchvision.models`.
Internally, :class:`ImageTorchEncoder` wraps the models from `torchvision.models`.
https://pytorch.org/docs/stable/torchvision/models.html
"""

def __init__(self, model_name: str = 'mobilenet_v2', pool_strategy: str = 'mean', *args, **kwargs):
def __init__(self,
model_name: str = 'mobilenet_v2',
pool_strategy: str = 'mean', *args, **kwargs):
"""
:param model_name: the name of the model. Supported models include
Expand All @@ -33,28 +35,20 @@ def __init__(self, model_name: str = 'mobilenet_v2', pool_strategy: str = 'mean'
thus the output of the model will be a 2D tensor.
- `max` means that global max pooling will be applied.
"""
super().__init__(*args, **kwargs)
self.model_name = model_name
super().__init__(model_name, *args, **kwargs)
self.pool_strategy = pool_strategy
if pool_strategy not in ('mean', 'max', None):
raise NotImplementedError('unknown pool_strategy: {}'.format(self.pool_strategy))

def post_init(self):
def _build_model(self):
import torchvision.models as models
import torch
model = getattr(models, self.model_name)(pretrained=True)
self.model = model.features.eval()
device = 'cuda:0' if self.on_gpu else 'cpu'
self.model.to(torch.device(device))

def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
def _get_features(self, data):
return self.model(data)

:param data: a `B x (Channel x Height x Width)` numpy ``ndarray``, `B` is the size of the batch
:return: a `B x D` numpy ``ndarray``, `D` is the output dimension
"""
import torch
feature_map = self.model(torch.from_numpy(data.astype('float32'))).detach().numpy()
def _get_pooling(self, feature_map: 'np.ndarray') -> 'np.ndarray':
if feature_map.ndim == 2 or self.pool_strategy is None:
return feature_map
return getattr(np, self.pool_strategy)(feature_map, axis=(2, 3))
39 changes: 39 additions & 0 deletions jina/executors/encoders/torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np

from . import BaseNumericEncoder
from ..decorators import batching, as_ndarray


class TorchEncoder(BaseNumericEncoder):
    """
    Base class for encoders that wrap a PyTorch model.

    Subclasses implement :meth:`_build_model` (which must set ``self.model``) and
    :meth:`_get_features`; they may override :meth:`_get_pooling` to reduce the
    feature map returned by the model.
    """

    def __init__(self,
                 model_name: str,
                 channel_axis: int = 1,
                 *args, **kwargs):
        """
        :param model_name: the name of the model; stored for use by :meth:`_build_model`
        :param channel_axis: the axis of the input ``ndarray`` that holds the channels
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name
        self.channel_axis = channel_axis
        # The axis on which the wrapped model expects the channels. Subclasses may
        # overwrite it after calling ``super().__init__``.
        self._default_channel_axis = 1

    def _get_device(self):
        """Return the ``torch.device`` that the model and its inputs live on."""
        import torch
        # NOTE(review): ``on_gpu`` is not defined in this class; presumably it is
        # provided by the executor base class -- confirm.
        return torch.device('cuda:0' if self.on_gpu else 'cpu')

    def post_init(self):
        """Build the model via :meth:`_build_model` and move it to the target device."""
        self._build_model()
        self.model.to(self._get_device())

    @batching
    @as_ndarray
    def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
        """
        Run the model on ``data`` and return the (optionally pooled) features.

        :param data: a numpy ``ndarray`` whose channels are on axis ``channel_axis``
        :return: the output of :meth:`_get_pooling` applied to the model features
        """
        import torch
        if self.channel_axis != self._default_channel_axis:
            data = np.moveaxis(data, self.channel_axis, self._default_channel_axis)
        # The input has to live on the same device as the model, otherwise the
        # forward pass fails when the model was moved to the GPU in ``post_init``.
        inputs = torch.from_numpy(data.astype('float32')).to(self._get_device())
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            features = self._get_features(inputs)
        # ``.numpy()`` is only supported for CPU tensors, hence the ``.cpu()``.
        feature_map = features.detach().cpu().numpy()
        return self._get_pooling(feature_map)

    def _build_model(self):
        """Create the model and assign it to ``self.model``; implemented by subclasses."""
        raise NotImplementedError

    def _get_features(self, data):
        """Map the input tensor ``data`` to a feature tensor; implemented by subclasses."""
        raise NotImplementedError

    def _get_pooling(self, feature_map):
        """Reduce ``feature_map``; the default returns it unchanged."""
        return feature_map
34 changes: 9 additions & 25 deletions jina/executors/encoders/video/torchvision.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,26 @@
import numpy as np

from .. import BaseVideoEncoder
from ..torchvision import TorchEncoder


class TorchVideoEncoder(BaseVideoEncoder):
class VideoTorchEncoder(TorchEncoder):
"""
:class:`TorchVideoEncoder` encodes data from a ndarray, potentially B x T x (Channel x Height x Width) into an
:class:`VideoTorchEncoder` encodes data from a ndarray, potentially B x T x (Channel x Height x Width) into an
ndarray of `B x D`.
Internally, :class:`TorchVideoEncoder` wraps the models from `torchvision.models`.
Internally, :class:`VideoTorchEncoder` wraps the models from `torchvision.models`.
https://pytorch.org/docs/stable/torchvision/models.html
"""
def __init__(self,
model_name: str = 'r3d_18',
*args, **kwargs):
def __init__(self, model_name: str = 'r3d_18', *args, **kwargs):
"""
:param model_name: the name of the model. Supported models include ``r3d_18``, ``mc3_18``, ``r2plus1d_18``
"""
super().__init__(*args, **kwargs)
self.model_name = model_name
super().__init__(model_name, *args, **kwargs)
self._default_channel_axis = 2

def post_init(self):
def _build_model(self):
import torchvision.models.video as models
import torch
model = getattr(models, self.model_name)(pretrained=True)
self.model = model.eval()
device = 'cuda:0' if self.on_gpu else 'cpu'
self.model.to(torch.device(device))

def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
:param data: a `B x T x (Channel x Height x Width)` numpy ``ndarray``, `B` is the size of the batch
:return: a `B x D` numpy ``ndarray``, `D` is the output dimension
"""
import torch
return self._get_features(
torch.from_numpy(np.moveaxis(data.astype('float32'), 1, 2))).detach().numpy()
self.model = getattr(models, self.model_name)(pretrained=True).eval()

def _get_features(self, x):
x = self.model.stem(x)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_exec_encoder_cv_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
import numpy as np

from jina.executors import BaseExecutor
from jina.executors.encoders.image.torchvision import TorchImageEncoder
from jina.executors.encoders.image.torchvision import ImageTorchEncoder
from tests import JinaTestCase


class MyTestCase(JinaTestCase):
@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_encoding_results(self):
encoder = TorchImageEncoder()
encoder = ImageTorchEncoder()
test_data = np.random.rand(2, 3, 224, 224)
encoded_data = encoder.encode(test_data)
self.assertEqual(encoded_data.shape, (2, 1280))

@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load(self):
encoder = TorchImageEncoder()
encoder = ImageTorchEncoder()
test_data = np.random.rand(2, 3, 224, 224)
encoded_data_control = encoder.encode(test_data)
encoder.touch()
Expand All @@ -33,7 +33,7 @@ def test_save_and_load(self):

@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load_config(self):
encoder = TorchImageEncoder()
encoder = ImageTorchEncoder()
encoder.save_config()
self.assertTrue(os.path.exists(encoder.config_abspath))
encoder_loaded = BaseExecutor.load_config(encoder.config_abspath)
Expand Down
14 changes: 7 additions & 7 deletions tests/test_exec_encoder_video_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
import numpy as np

from jina.executors import BaseExecutor
from jina.executors.encoders.video.torchvision import TorchVideoEncoder
from jina.executors.encoders.video.torchvision import VideoTorchEncoder
from tests import JinaTestCase


class MyTestCase(JinaTestCase):
@unittest.skipIf(os.getenv('JINA_SKIP_TEST_PRETRAINED', True), 'skip the pretrained test if not set')
@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_encoding_results(self):
encoder = TorchVideoEncoder()
encoder = VideoTorchEncoder()
test_data = np.random.rand(2, 3, 3, 112, 112)
encoded_data = encoder.encode(test_data)
self.assertEqual(encoded_data.shape, (2, 512))

@unittest.skipIf(os.getenv('JINA_SKIP_TEST_PRETRAINED', True), 'skip the pretrained test if not set')
@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load(self):
encoder = TorchVideoEncoder()
encoder = VideoTorchEncoder()
test_data = np.random.rand(2, 3, 3, 112, 112)
encoded_data_control = encoder.encode(test_data)
encoder.touch()
Expand All @@ -31,9 +31,9 @@ def test_save_and_load(self):
self.add_tmpfile(
encoder.config_abspath, encoder.save_abspath, encoder_loaded.config_abspath, encoder_loaded.save_abspath)

@unittest.skipIf(os.getenv('JINA_SKIP_TEST_PRETRAINED', True), 'skip the pretrained test if not set')
@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load_config(self):
encoder = TorchVideoEncoder()
encoder = VideoTorchEncoder()
encoder.save_config()
self.assertTrue(os.path.exists(encoder.config_abspath))
encoder_loaded = BaseExecutor.load_config(encoder.config_abspath)
Expand Down

0 comments on commit 01a830b

Please sign in to comment.