Skip to content

Commit

Permalink
Extended unit tests to classifier and fixed pooling (#17)
Browse files Browse the repository at this point in the history
* Extended unit tests to classifier and fixed pooling

* Changed trigger of doctest workflow

* Fixing issue #18

* fixed linters

* Add pre-commit hooks

* Doctest only on PRs

* Fixed network conversion from GPU

Also tested on Windows machine.
  • Loading branch information
fpaissan authored Mar 7, 2023
1 parent f2af0d9 commit e82bcc5
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 21 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/doctest.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
name: Verify docs generation

# Runs on pushes to master and all pull requests
on: # yamllint disable-line rule:truthy
push:
pull_request:
on: pull_request

jobs:
docs:
Expand All @@ -24,4 +22,3 @@ jobs:
cd docs
make doctest
make html
15 changes: 15 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: '6.0.0' # pick a git hash / tag to point to
hooks:
- id: flake8
11 changes: 8 additions & 3 deletions micromind/conversion/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
def convert_to_onnx(net: nn.Module, save_path: Path, simplify: bool = True):
"""Converts nn.Module to onnx and saves it to save_path.
Optionally simplifies it."""
x = torch.zeros([1] + net.input_shape)
x = torch.zeros([1] + list(net.input_shape))

torch.onnx.export(
net,
net.cpu(),
x,
save_path,
verbose=False,
Expand Down Expand Up @@ -88,13 +88,18 @@ def convert_to_tflite(
if not isinstance(save_path, Path):
save_path = Path(save_path)

if not (batch_quant is None):
batch_quant = batch_quant.cpu()

vino_sub = save_path.joinpath("vino")
os.makedirs(vino_sub, exist_ok=True)
vino_path = convert_to_openvino(net, vino_sub)
if os.name == "nt":
openvino2tensorflow_exe_cmd = [
sys.executable,
os.path.join(os.path.dirname(sys.executable), "openvino2tensorflow"),
os.path.join(
os.path.dirname(sys.executable), "Scripts", "openvino2tensorflow"
),
]
else:
openvino2tensorflow_exe_cmd = ["openvino2tensorflow"]
Expand Down
12 changes: 5 additions & 7 deletions micromind/networks/phinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,10 +829,10 @@ def __init__(

if include_top:
# Includes classification head if required
self.glob_pooling = lambda x: nn.functional.avg_pool2d(x, x.size()[2:])

self.new_convolution = nn.Conv2d(
int(block_filters * alpha), num_classes, kernel_size=1, bias=True
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(int(block_filters * alpha), num_classes, bias=True),
)

def forward(self, x):
Expand All @@ -851,8 +851,6 @@ def forward(self, x):
x = layers(x)

if self.classify:
x = self.glob_pooling(x)
x = self.new_convolution(x)
x = x.view(-1, x.shape[1])
x = self.classifier(x)

return x
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dependencies = [
requires-python = ">=3.9"

[project.optional-dependencies]
dev = ["black", "bumpver", "flake8", "isort", "pip-tools", "pytest"]
dev = ["black", "bumpver", "flake8", "isort", "pip-tools", "pytest", "pre-commit"]

[project.urls]
Homepage = "https://github.com/fpaissan/micromind"
Expand Down
12 changes: 6 additions & 6 deletions tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def test_onnx():

save_path = "temp.onnx"

in_shape = list((3, 224, 224))
net = PhiNet(in_shape, compatibility=False)
in_shape = (3, 224, 224)
net = PhiNet(in_shape, compatibility=False, include_top=True)

convert_to_onnx(net, save_path, simplify=True)
import os
Expand All @@ -34,8 +34,8 @@ def test_openvino():

save_dir = "vino"

in_shape = list((3, 224, 224))
net = PhiNet(in_shape, compatibility=False)
in_shape = (3, 224, 224)
net = PhiNet(in_shape, compatibility=False, include_top=True)

convert_to_openvino(net, save_dir)

Expand All @@ -50,8 +50,8 @@ def test_tflite():

save_path = "tflite"

in_shape = list((3, 224, 224))
net = PhiNet(in_shape, compatibility=False)
in_shape = (3, 224, 224)
net = PhiNet(in_shape, compatibility=False, include_top=True)

convert_to_tflite(net, save_path)

Expand Down

0 comments on commit e82bcc5

Please sign in to comment.