-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_modules.py
34 lines (31 loc) · 1.15 KB
/
custom_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""handle all modules using importlib"""
import importlib
from pathlib import Path
import toml
# Values of the "module" entry collected from the model config files;
# filled in by import_all_custom_modules_needed().
cmods = []
# Maps a custom module's MODEL_NAME to the first element of its MODEL
# attribute (the model class, per the docstrings in this file).
mapping = {}
#to whoever has to improve this code, i am so sorry
def import_all_custom_modules_needed():
    """Import every custom module referenced by the configs in ``./models``.

    Each ``*.toml`` file directly inside ``./models`` is parsed and its
    ``"module"`` entry is collected into the module-level ``cmods`` list.
    Every distinct entry is then imported as ``modules.<entry>`` and
    registered in ``mapping`` under the module's ``MODEL_NAME``, with the
    first element of the module's ``MODEL`` attribute as the value.

    Config files are visited in sorted path order and duplicates are dropped
    while keeping first-seen order, so the import order (and therefore which
    module wins when two share a ``MODEL_NAME``) is the same on every run.

    Raises:
        KeyError: if a config file has no ``"module"`` entry.
        ImportError: if a referenced module cannot be imported.
        AttributeError: if a module lacks ``MODEL_NAME`` or ``MODEL``.
    """
    # Parse every config before touching the shared list, so a broken file
    # does not leave ``cmods`` half-updated.
    found = []
    for config_path in sorted(Path("./models").glob("*.toml")):
        with open(config_path, "r", encoding="utf-8") as handle:
            found.append(toml.load(handle)["module"])

    # Deduplicate in place: dict keys keep insertion order (unlike a set), and
    # slice assignment updates the existing list object, so code that already
    # holds a reference to ``cmods`` sees the cleaned-up contents.
    cmods[:] = dict.fromkeys([*cmods, *found])

    for cmod in cmods:
        module = importlib.import_module(f"modules.{cmod}")
        mapping[module.MODEL_NAME] = module.MODEL[0]
def import_custom_module(cmod):
    """Return the model class registered for the custom module ``cmod``.

    The module is imported as ``modules.<cmod>`` and its ``MODEL_NAME`` is
    looked up in ``mapping``.  If that name has not been registered yet (for
    example because ``import_all_custom_modules_needed`` has not run, or the
    module's config was added afterwards), it is registered on the spot using
    the same rule, ``MODEL[0]``, instead of failing with a ``KeyError``.
    An existing registration is never overwritten.

    Raises:
        ImportError: if ``modules.<cmod>`` cannot be imported.
        AttributeError: if the module lacks ``MODEL_NAME`` (or ``MODEL`` when
            it has to be registered).
    """
    module = importlib.import_module(f"modules.{cmod}")
    alias = module.MODEL_NAME
    if alias not in mapping:
        # Lazy registration keeps the lookup usable without a prior bulk import.
        mapping[alias] = module.MODEL[0]
    return mapping[alias]
def get_required_model_class(modpath:Path):
    """Return the model class for the model stored at `modpath`.

    The model's config is read from the TOML file whose path is `modpath`
    with ".toml" added after its existing suffix; the config's 'module'
    entry is then handed to `import_custom_module`.
    """
    # e.g. a suffix of ".bin" becomes ".bin.toml"
    config_path = modpath.with_suffix(f"{modpath.suffix}.toml")
    with open(config_path, "r", encoding="utf-8") as handle:
        settings = toml.load(handle)
    module_name = settings['module']
    return import_custom_module(module_name)