Add support for Intel dGPU devices #3058

Merged (50 commits) on Mar 11, 2024
Changes from all commits
e99966b
POC: xpu support (#2553)
sovrasov Oct 25, 2023
83e7faa
Fix tv nms check (#2582)
sovrasov Oct 30, 2023
73e364f
Merge remote-tracking branch 'origin/develop' into develop-idev
sovrasov Nov 6, 2023
575b4ec
POC: HPU support (#2574)
sungchul2 Nov 8, 2023
e203e58
Merge develop & fix YOLOX inference in BFP16 on XPU (#2602)
sovrasov Nov 8, 2023
2cd9eda
Object Detection with Gaudi2 (#2608)
kprokofi Nov 10, 2023
ecb786f
Update XPU detection (#2623)
sovrasov Nov 13, 2023
08917e1
Add `ModuleCacher` for HPU graphs (#2624)
sungchul2 Nov 13, 2023
2239ac1
Update Intel devices branch to the latest develop state (#2666)
sovrasov Nov 23, 2023
03e87f5
Fix mixed & lower precision training (#2668)
sungchul2 Nov 29, 2023
c1beb05
Add XPU support to anomaly task (#2677)
sovrasov Nov 30, 2023
ae090ba
Disable mixed precision training on XPU (#2683)
eunwoosh Dec 6, 2023
0062179
Update anomaly XPU integration (#2697)
sovrasov Dec 8, 2023
6da38a3
Add XPU mixed precision plugin for lightning (#2714)
sovrasov Dec 13, 2023
71e9adc
Update code to support other features than 'train' on XPU (#2704)
eunwoosh Dec 13, 2023
927fab6
Merge develop to develop-idev (#2727)
eunwoosh Dec 14, 2023
dee113c
Deal with dynamic input shape on XPU (#2740)
eunwoosh Dec 28, 2023
54751ca
Merge develop to Intel dGPU branch (#2927)
kprokofi Feb 20, 2024
6914551
Merge develop 2 (#2936)
kprokofi Feb 20, 2024
b7a44a5
Merge remote-tracking branch 'origin/develop' into vs/idev_merge
sovrasov Feb 20, 2024
444ea5e
Merge pull request #2938 from sovrasov/vs/idev_merge
sovrasov Feb 20, 2024
f549b80
Remove XPUATSSAssigner (#2940)
kprokofi Feb 21, 2024
1246855
Merge branch 'develop' into develop-idev
sovrasov Feb 21, 2024
666e3bc
Merge remote-tracking branch 'origin/develop' into develop-idev
sovrasov Feb 21, 2024
8d3f0a6
Merge branch 'develop' into develop-idev
sovrasov Feb 22, 2024
113e24f
Fix linters
sovrasov Feb 22, 2024
59d78dd
Merge remote-tracking branch 'origin/develop' into develop-idev
sovrasov Feb 27, 2024
df13078
Fix linters
sovrasov Feb 27, 2024
3b547f2
align with pre-commit
eunwoosh Feb 28, 2024
ed12628
Fix uts
sovrasov Feb 29, 2024
0c9b97b
Merge pull request #3013 from openvinotoolkit/vs/fix_xpu_ut
sovrasov Mar 4, 2024
8181519
Resolve failed integration test w/ XPU (#3009)
eunwoosh Mar 6, 2024
dd56400
added unit tests. need to debug on XPU
kprokofi Mar 6, 2024
18904e7
refactor as class
kprokofi Mar 6, 2024
b913cb1
minor fix
kprokofi Mar 6, 2024
78a0d87
fix common tests
kprokofi Mar 6, 2024
7f6862c
fix common for XPU
kprokofi Mar 6, 2024
ce21edb
minor
kprokofi Mar 6, 2024
c36c259
fix anomaly tests
kprokofi Mar 6, 2024
d566b7e
add headers
kprokofi Mar 6, 2024
6a45aad
revert experiments back
kprokofi Mar 6, 2024
39f4216
fix device type
chuneuny-emily Mar 6, 2024
21225d0
add headers
kprokofi Mar 6, 2024
0d538de
Omit e2e tests, which are not supported on xpu (#3048)
chuneuny-emily Mar 7, 2024
a5dcbd5
Merge pull request #3049 from openvinotoolkit/kp/add_unit_tests_dgpus
sovrasov Mar 7, 2024
c34bf2c
Merge branch 'develop' into develop-idev
eunwoosh Mar 8, 2024
4edfd52
revert github workflow
eunwoosh Mar 8, 2024
3fab960
Merge branch 'develop' into develop-idev
eunwoosh Mar 8, 2024
a0b3bc6
remove pytest.mark.skipif
kprokofi Mar 10, 2024
76c5ce7
revert experiments back
kprokofi Mar 10, 2024
4 changes: 4 additions & 0 deletions src/otx/algorithms/anomaly/adapters/__init__.py
@@ -13,3 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.


from .anomalib.accelerators.xpu import XPUAccelerator # noqa: F401
from .anomalib.strategies import SingleXPUStrategy # noqa: F401
@@ -0,0 +1,8 @@
"""Lightning accelerator for XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from .xpu import XPUAccelerator

__all__ = ["XPUAccelerator"]
@@ -0,0 +1,60 @@
"""Lightning accelerator for XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import Any, Dict, Union

import torch
from pytorch_lightning.accelerators import AcceleratorRegistry
from pytorch_lightning.accelerators.accelerator import Accelerator

from otx.algorithms.common.utils.utils import is_xpu_available


class XPUAccelerator(Accelerator):
    """Support for an XPU, optimized for large-scale machine learning."""

    accelerator_name = "xpu"

    def setup_device(self, device: torch.device) -> None:
        """Sets up the specified device."""
        if device.type != "xpu":
            raise RuntimeError(f"Device should be xpu, got {device} instead")

        torch.xpu.set_device(device)

    @staticmethod
    def parse_devices(devices: Any) -> Any:
        """Parses devices for multi-GPU training."""
        if isinstance(devices, list):
            return devices
        return [devices]

    @staticmethod
    def get_parallel_devices(devices: Any) -> Any:
        """Generates a list of parallel devices."""
        return [torch.device("xpu", idx) for idx in devices]

    @staticmethod
    def auto_device_count() -> int:
        """Returns the number of available XPU devices."""
        return torch.xpu.device_count()

    @staticmethod
    def is_available() -> bool:
        """Checks if an XPU is available."""
        return is_xpu_available()

    def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
        """Returns XPU device stats."""
        return {}

    def teardown(self) -> None:
        """Cleans up XPU-related resources."""
        pass


AcceleratorRegistry.register(
    XPUAccelerator.accelerator_name, XPUAccelerator, description="Accelerator supports XPU devices"
)
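
A brief usage sketch (not part of this diff; it assumes an XPU-enabled PyTorch build with `torch.xpu` available). Importing the module registers the accelerator under the name "xpu", and its static helpers can be called directly:

from otx.algorithms.anomaly.adapters.anomalib.accelerators.xpu import XPUAccelerator

# Importing the module above also registers "xpu" in Lightning's AcceleratorRegistry.
if XPUAccelerator.is_available():
    print(f"XPU devices detected: {XPUAccelerator.auto_device_count()}")
    print(XPUAccelerator.get_parallel_devices([0]))  # [device(type='xpu', index=0)]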
@@ -0,0 +1,7 @@
"""Plugin for mixed-precision training on XPU."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .xpu_precision import MixedPrecisionXPUPlugin

__all__ = ["MixedPrecisionXPUPlugin"]
@@ -0,0 +1,109 @@
"""Plugin for mixed-precision training on XPU."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Optional, Union

import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.types import Optimizable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.optim import LBFGS, Optimizer


class MixedPrecisionXPUPlugin(PrecisionPlugin):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``.

    Args:
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
    """

    def __init__(self, scaler: Optional[Any] = None) -> None:
        self.scaler = scaler

    def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:
        """Apply grad scaler before backward."""
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        return super().pre_backward(tensor, module)

    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
        model: "pl.LightningModule",
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """Make an optimizer step using scaler if it was passed."""
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(
                optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
            )
        if isinstance(optimizer, LBFGS):
            raise MisconfigurationException(
                f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
            )
        closure_result = closure()

        if not _optimizer_handles_unscaling(optimizer):
            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
            self.scaler.unscale_(optimizer)

        self._after_closure(model, optimizer, optimizer_idx)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        """Handle grad clipping with scaler."""
        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
            raise RuntimeError(
                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
            )
        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Enable autocast context."""
        with torch.xpu.autocast(True):
            yield

    def state_dict(self) -> Dict[str, Any]:
        """Returns state dict of the plugin."""
        if self.scaler is not None:
            return self.scaler.state_dict()
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Loads state dict to the plugin."""
        if self.scaler is not None:
            self.scaler.load_state_dict(state_dict)


def _optimizer_handles_unscaling(optimizer: Any) -> bool:
    """Determines if a PyTorch optimizer handles unscaling gradients in the step method rather than through the scaler.

    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
    value will only be reliable for built-in PyTorch optimizers.
    """
    return getattr(optimizer, "_step_supports_amp_scaling", False)
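
A brief usage sketch (not part of this diff): the plugin is passed to a Lightning `Trainer` through `plugins`, mirroring the wiring added in `tasks/train.py` below; without a scaler it only wraps forward passes in `torch.xpu.autocast`.

import pytorch_lightning as pl

from otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision import MixedPrecisionXPUPlugin

# No scaler: optimizer_step falls through to the default path (bfloat16 needs no scaling).
# A GradScaler-like object can be passed as `scaler` to enable scaling/unscaling.
trainer = pl.Trainer(accelerator="xpu", strategy="xpu_single", plugins=[MixedPrecisionXPUPlugin()])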
@@ -0,0 +1,8 @@
"""Lightning strategy for single XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from .xpu_single import SingleXPUStrategy

__all__ = ["SingleXPUStrategy"]
@@ -0,0 +1,60 @@
"""Lightning strategy for single XPU device."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import Optional

import pytorch_lightning as pl
import torch
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies import StrategyRegistry
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from otx.algorithms.common.utils.utils import is_xpu_available


class SingleXPUStrategy(SingleDeviceStrategy):
    """Strategy for training on single XPU device."""

    strategy_name = "xpu_single"

    def __init__(
        self,
        device: _DEVICE = "xpu:0",
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        if not is_xpu_available():
            raise MisconfigurationException("`SingleXPUStrategy` requires XPU devices to run")

        super().__init__(
            accelerator=accelerator,
            device=device,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

    @property
    def is_distributed(self) -> bool:
        """Returns true if the strategy supports distributed training."""
        return False

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Sets up optimizers."""
        super().setup_optimizers(trainer)
        if len(self.optimizers) != 1:  # type: ignore
            raise RuntimeError("XPU strategy doesn't support multiple optimizers")
        model, optimizer = torch.xpu.optimize(trainer.model, optimizer=self.optimizers[0])  # type: ignore
        self.optimizers = [optimizer]
        trainer.model = model


StrategyRegistry.register(
    SingleXPUStrategy.strategy_name, SingleXPUStrategy, description="Strategy that enables training on single XPU"
)
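
A brief usage sketch (not part of this diff): the strategy can be constructed directly, or selected by its registered name "xpu_single" as `tasks/train.py` below does.

from otx.algorithms.anomaly.adapters.anomalib.strategies import SingleXPUStrategy

# Direct construction; raises MisconfigurationException when no XPU device is present.
strategy = SingleXPUStrategy(device="xpu:0")
print(strategy.is_distributed)  # False: single-device training only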
17 changes: 16 additions & 1 deletion src/otx/algorithms/anomaly/tasks/train.py
@@ -30,6 +30,8 @@

from otx.algorithms.anomaly.adapters.anomalib.callbacks import IterationTimer, ProgressCallback
from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
from otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision import MixedPrecisionXPUPlugin
from otx.algorithms.common.utils.utils import is_xpu_available
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.model import ModelEntity
from otx.api.entities.train_parameters import TrainParameters
@@ -90,7 +92,20 @@
            IterationTimer(on_step=False),
        ]

        self.trainer = Trainer(**config.trainer, logger=CSVLogger(self.project_path, name=""), callbacks=callbacks)
        plugins = []
        if config.trainer.plugins is not None:
            plugins.extend(config.trainer.plugins)
        config.trainer.pop("plugins")

        if is_xpu_available():
            config.trainer.strategy = "xpu_single"
            config.trainer.accelerator = "xpu"
            if config.trainer.precision == 16:
                plugins.append(MixedPrecisionXPUPlugin())

        self.trainer = Trainer(
            **config.trainer, logger=CSVLogger(self.project_path, name=""), callbacks=callbacks, plugins=plugins
        )
        self.trainer.fit(model=self.model, datamodule=datamodule)

        self.save_model(output_model)
@@ -0,0 +1,8 @@
"""Adapters of classification - mmcls."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .train import train_model

__all__ = ["train_model"]