diff --git a/.github/workflows/doctest.yaml b/.github/workflows/doctest.yaml index 62b07b0..faae6bb 100644 --- a/.github/workflows/doctest.yaml +++ b/.github/workflows/doctest.yaml @@ -1,9 +1,7 @@ name: Verify docs generation # Runs on pushes to master and all pull requests -on: # yamllint disable-line rule:truthy - push: - pull_request: +on: pull_request jobs: docs: @@ -24,4 +22,3 @@ jobs: cd docs make doctest make html - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2696a36 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/psf/black + rev: 22.10.0 + hooks: + - id: black +- repo: https://github.com/pycqa/flake8 + rev: '6.0.0' # pick a git hash / tag to point to + hooks: + - id: flake8 diff --git a/micromind/conversion/convert.py b/micromind/conversion/convert.py index 197d005..f7ea1ce 100644 --- a/micromind/conversion/convert.py +++ b/micromind/conversion/convert.py @@ -25,10 +25,10 @@ def convert_to_onnx(net: nn.Module, save_path: Path, simplify: bool = True): """Converts nn.Module to onnx and saves it to save_path. Optionally simplifies it.""" - x = torch.zeros([1] + net.input_shape) + x = torch.zeros([1] + list(net.input_shape)) torch.onnx.export( - net, + net.cpu(), x, save_path, verbose=False, @@ -88,13 +88,18 @@ def convert_to_tflite( if not isinstance(save_path, Path): save_path = Path(save_path) + if not (batch_quant is None): + batch_quant = batch_quant.cpu() + vino_sub = save_path.joinpath("vino") os.makedirs(vino_sub, exist_ok=True) vino_path = convert_to_openvino(net, vino_sub) if os.name == "nt": openvino2tensorflow_exe_cmd = [ sys.executable, - os.path.join(os.path.dirname(sys.executable), "openvino2tensorflow"), + os.path.join( + os.path.dirname(sys.executable), "Scripts", "openvino2tensorflow" + ), ] else: openvino2tensorflow_exe_cmd = ["openvino2tensorflow"] diff --git a/micromind/networks/phinet.py b/micromind/networks/phinet.py index 916b5b0..22b334c 100644 --- a/micromind/networks/phinet.py +++ b/micromind/networks/phinet.py @@ -829,10 +829,10 @@ def __init__( if include_top: # Includes classification head if required - self.glob_pooling = lambda x: nn.functional.avg_pool2d(x, x.size()[2:]) - - self.new_convolution = nn.Conv2d( - int(block_filters * alpha), num_classes, kernel_size=1, bias=True + self.classifier = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(int(block_filters * alpha), num_classes, bias=True), ) def forward(self, x): @@ -851,8 +851,6 @@ def forward(self, x): x = layers(x) if self.classify: - x = self.glob_pooling(x) - x = self.new_convolution(x) - x = x.view(-1, x.shape[1]) + x = self.classifier(x) return x diff --git a/pyproject.toml b/pyproject.toml index 3821327..c046898 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ requires-python = ">=3.9" [project.optional-dependencies] -dev = ["black", "bumpver", "flake8", "isort", "pip-tools", "pytest"] +dev = ["black", "bumpver", "flake8", "isort", "pip-tools", "pytest", "pre-commit"] [project.urls] Homepage = "https://github.com/fpaissan/micromind" diff --git a/tests/test_networks.py b/tests/test_networks.py index f6599bf..09c87f5 100644 --- a/tests/test_networks.py +++ b/tests/test_networks.py @@ -13,8 +13,8 @@ def test_onnx(): save_path = "temp.onnx" - in_shape = list((3, 224, 224)) - net = PhiNet(in_shape, compatibility=False) + in_shape = (3, 224, 224) + net = PhiNet(in_shape, compatibility=False, include_top=True) convert_to_onnx(net, save_path, simplify=True) import os @@ -34,8 +34,8 @@ def test_openvino(): save_dir = "vino" - in_shape = list((3, 224, 224)) - net = PhiNet(in_shape, compatibility=False) + in_shape = (3, 224, 224) + net = PhiNet(in_shape, compatibility=False, include_top=True) convert_to_openvino(net, save_dir) @@ -50,8 +50,8 @@ def test_tflite(): save_path = "tflite" - in_shape = list((3, 224, 224)) - net = PhiNet(in_shape, compatibility=False) + in_shape = (3, 224, 224) + net = PhiNet(in_shape, compatibility=False, include_top=True) convert_to_tflite(net, save_path)