Skip to content

Commit

Permalink
Fix torch version parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Apr 14, 2023
1 parent 404741b commit f9d4ca0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
12 changes: 6 additions & 6 deletions docs/api/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from sphinx.ext.autodoc import mock

import nncf

sys.path.insert(0, os.path.abspath('../../..'))

Expand Down Expand Up @@ -41,14 +40,15 @@ def collect_api_entities() -> List[str]:
:return: A list of fully qualified names of API symbols.
"""
modules = {}
skipped_modules = []
skipped_modules = {} # type: Dict[str, str]
import nncf
for importer, modname, ispkg in pkgutil.walk_packages(path=nncf.__path__,
prefix=nncf.__name__+'.',
onerror=lambda x: None):
try:
modules[modname] = importlib.import_module(modname)
except:
skipped_modules.append(modname)
except Exception as e:
skipped_modules[modname] = str(e)

api_fqns = []
for modname, module in modules.items():
Expand All @@ -62,7 +62,7 @@ def collect_api_entities() -> List[str]:
api_fqns.append(f"{modname}.{obj_name}")

print()
skipped_str = '\n'.join(skipped_modules)
skipped_str = '\n'.join([f"{k}: {v}" for k, v in skipped_modules.items()])
print(f"Skipped: {skipped_str}\n")

print("API entities:")
Expand All @@ -71,7 +71,7 @@ def collect_api_entities() -> List[str]:
return api_fqns


with mock(['torch', 'onnx', 'openvino', 'tensorflow', 'tensorflow_addons']):
with mock(['torch', 'torchvision', 'onnx', 'onnxruntime', 'openvino', 'tensorflow', 'tensorflow_addons']):
api_fqns = collect_api_entities()

module_fqns = set()
Expand Down
10 changes: 9 additions & 1 deletion nncf/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@

import torch
from pkg_resources import parse_version
torch_version = parse_version(torch.__version__).base_version

try:
_torch_version = torch.__version__
torch_version = parse_version(_torch_version).base_version
except:
nncf_logger.debug("Could not parse torch version")
_torch_version = '0.0.0'
torch_version = parse_version(_torch_version).base_version

if parse_version(BKC_TORCH_VERSION).base_version != torch_version:
warn_bkc_version_mismatch("torch", BKC_TORCH_VERSION, torch.__version__)

Expand Down

0 comments on commit f9d4ca0

Please sign in to comment.