Skip to content

Commit

Permalink
Add XPU support to anomaly task (#2677)
Browse files Browse the repository at this point in the history
* Update base.txt

updated dependency version of datumaro

* Update __init__.py

update version string

* Update requirements.txt

* Temporarily skip visual prompting openvino integration test (#2323)

* Fix import dm.DatasetSubset (#2324)

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>

* Fix semantic segmentation soft prediction dtype (#2322)

* Fix semantic segmentation soft prediction dtype

* relax ref sal vals check

---------

Co-authored-by: Songki Choi <songki.choi@intel.com>

* Contrain yapf verison lesser than 0.40.0 (#2328)

contrain_yapf_version

* Fix detection e2e tests (#2327)

Fix for detection

* Mergeback: Label addtion/deletion 1.2.4 --> 1.4.0 (#2326)

* Make black happy

* Fix conflicts

* Merge-back: add test datasets and edit the test code

* Make black happy

* Fix mis-merge

* Make balck happy

* Fix typo

* Fix typoi

---------

Co-authored-by: Songki Choi <songki.choi@intel.com>

* Bump datumaro up to 1.4.0rc2 (#2332)

bump datumaro up to 1.4.0rc2

* Tiling Doc for releases 1.4.0 (#2333)

* Add tiling documentation

* Bump otx version to 1.4.0rc2 (#2341)

* OTX deploy for visual prompting task  (#2311)

* Enable `otx deploy`

* (WIP) integration test

* Docstring

* Update args for create_model

* Manually set image embedding layout

* Enable to use model api for preprocessing
- `fit_to_window` doesn't work expectedly, so newly implemented `VisualPromptingOpenvinoAdapter` to use new resize function

* Remove skipped test

* Updated

* Update unit tests on model wrappers

* Update

* Update configuration

* Fix not to patch pretrained path

* pylint & update model api version in docstring

---------

Co-authored-by: Wonju Lee <wonju.lee@intel.com>

* Bump albumentations version in anomaly requirements (#2350)

increment albumentations version

* Update action detection (#2346)

* Remove skip mark for PTQ test of action detection

* Update action detection documentation

* Fix e2e (#2348)

* Change classification dataset from dummy to toy

* Revert test changes

* Change label name for multilabel dataset

* Revert e2e test changes

* Change ov test cases' threshold

* Add parent's label

* Update ModelAPI in 1.4 release (#2347)

* Upgrade model API

* Update otx in exportable code

* Fix unit tests

* Fix black

* Fix detection inference

* Fix det tiling

* Fix mypy

* Fix demo

* Fix visualizer in demo

* Fix black

* Add OTX optimize for visual prompting task (#2318)

* Initial commit

* Update block

* (WIP) otx optimize

* Fix

* WIP

* Update configs & exported outputs

* Remove unused modules for torch

* Add unit tests

* pre-commit

* Update CHANGELOG

* Update detection docs (#2335)

* Update detection docs

* Revert template id changes

* Fix wrong template id

* Update docs/source/guide/explanation/algorithms/object_detection/object_detection.rst

Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>

* Update docs/source/guide/explanation/algorithms/object_detection/object_detection.rst

Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>

---------

Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>

* Add visual prompting documentation (#2354)

* (WIP) write docs

* Add visual prompting documentation

* Update CHANGELOG

---------

Co-authored-by: sungchul.kim <sungchul@ikvensx010>

* Remove custom modelapi patch in visual prompting (#2359)

* Remove custom modelapi patch

* Update test

* Fix graph metric order and label issues (#2356)

* Fix graph metric going backward issue
* Add license notice
* Fix pre-commit issue
* Add rename items & logic for metric
---------
Signed-off-by: Songki Choi <songki.choi@intel.com>

* Update multi-label document and conversion script (#2358)

Update docs, label convert script

* Update third party programs (#2365)

* Make anomaly task compatible with older albumentations versions (#2363)

* fix transforms export in metadata

* wrap transform dict

* add todo for updating to_dict call

* Fixing detection saliency map for one class case (#2368)

* fix softmax

* fix validity tests

* Add e2e test for visual prompting (#2360)

* (WIP) otx optimize

* pre-commit

* (WIP) set e2e

* Remove nncf config

* Add visual prompting requirement

* Add visual prompting in tox

* Add visual prompting in setup.py

* Fix typo

* Delete unused configuration.yaml

* Edit test_name

* Add to limit activation range

* Update from `vp` to `visprompt`

* Fix about no returning the first label

* pre-commit

* (WIP) otx optimize

* pre-commit

* (WIP) set e2e

* Remove nncf config

* Add visual prompting requirement

* Add visual prompting in tox

* Add visual prompting in setup.py

* Fix typo

* pre-commit

* Add actions

* Update tests/e2e/cli/visual_prompting/test_visual_prompting.py

Co-authored-by: Jaeguk Hyun <jaeguk.hyun@intel.com>

* Skip PTQ e2e test

* Change task name

* Remove skipped tc

---------

Co-authored-by: Jaeguk Hyun <jaeguk.hyun@intel.com>

* Fix e2e (#2366)

* Change e2e reference name

* Update openvino eval threshold for multiclass classification

* Change comment message

* Fix tiling e2e tests

---------

Co-authored-by: GalyaZalesskaya <galina.zalesskaya@intel.com>

* Add Dino head unit tests (#2344)

Recover DINO head unit tests

* Update for release 1.4.0rc2 (#2370)

* update for release 1.4.0rc2

* Add skip mark for unstable unit tests

---------

Co-authored-by: jaegukhyun <jaeguk.hyun@intel.com>

* Fix NNCF training on CPU (#2373)

* Align label order between Geti and OTX (#2369)

* align label order

* align with pre-commit

* update CHANGELOG.md

* deal with edge case

* update type hint

* Remove CenterCrop from Classification test pipeline and editing missing docs link (#2375)

* Fix missing link for docs and removing centercrop for classification data pipeline

* Revert the test threshold

* Fix H-label classification (#2377)

* Fix h-labelissue

* Update unit tests

* Make black happy

* Fix unittests

* Make black happy

* Fix update heades information func

* Update the logic: consider the loss per batch

* Update for release 1.4 (#2380)

* updated for 1.4.0rc3

* update changelog & release note

* bump datumaro version up

---------

Co-authored-by: Songki Choi <songki.choi@intel.com>

* Switch to PTQ for sseg (#2374)

* Switch to PTQ for sseg

* Update log messages

* Fix invalid import structures in otx.api (#2383)

Update tiler.py

* Update for 1.4.0rc4 (#2385)

update for release 1.4.0rc4

* [release 1.4.0] XAI: Return saliency maps for Mask RCNN IR async infer (#2395)

* Return saliency maps for openvino async infer

* add workaround to fix yapf importing error

---------

Co-authored-by: eunwoosh <eunwoo.shin@intel.com>

* Update for release 1.4.0 (#2399)

update version string

Co-authored-by: Sungman Cho <sungman.cho@intel.com>

* Fix broken links in documentation (#2405)

* fix docs links to datumaro's docs
* fix docs links to otx's docs
* bump version to 1.4.1

* Update exportable code README (#2411)

* Updated for release 1.4.1 (#2412)

updated for release 1.4.1

* Add workaround for the incorrect meta info M-RCNN (used for XAI) (#2437)

Add workaround for the incorrect mata info

* Add model category attributes to model template (#2439)

Add model category attributes to model template

* Add model category & status fields in model template

* Add is_default_for_task attr to model template

* Update model templates with category attrs

* Add integration tests for model templates consistency

* Fix license & doc string

* Fix typo

* Refactor test cases

* Refactor common tests by generator

---------
Signed-off-by: Songki Choi <songki.choi@intel.com>

* Update for 1.4.2rc1 (#2441)

update for release 1.4.2rc1

* Fix label list order for h-label classification (#2440)

* Fix label list for h-label cls
* Fix unit tests

* Modified fq numbers for lite HRNET (#2445)

modified fq numbers for lite HRNET

* Update PTQ ignored scope for hrnet 18  mod2 (#2449)

Update ptq ignored scope for hrnet 18  mod2

* Fix OpenVINO inference for legacy models (#2450)

* bug fix for legacy openvino models

* Add tests

* Specific exceptions

---------

* Update for 1.4.2rc2 (#2455)

update for release 1.4.2rc2

* Prevent zero-sized saliency map in tiling if tile size is too big (#2452)

* Prevent zero-sized saliency map in tiling if tile size is too big

* Prevent zero-sized saliency in tiling (PyTorch)

* Add unit tests for Tiler merge features methods

---------

Co-authored-by: Galina <galina.zalesskaya@intel.com>

* Update pot fq reference number (#2456)

update pot fq reference number to 15

* Bump datumaro version to 1.5.0rc0 (#2470)

bump datumaro version to 1.5.0rc0

* Set tox version constraint (#2472)

set tox version constraint - tox-dev/tox#3110

* Bug fix for albumentations (#2467)

* bug fix for legacy openvino models

* Address albumentation issue

---------

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>

* update for release 1.4.2rc3

* Add a dummy hierarchical config required by MAPI (#2483)

* bump version to 1.4.2rc4

* Bump datumaro version (#2502)

* bump datumaro version

* remove deprecated/reomved attribute usage of the datumaro

* Upgrade nncf version for 1.4 release (#2459)

* Upgrade nncf version

* Fix nncf interface warning

* Set the exact nncf version

* Update FQ refs after NNCF upgrade

* Use NNCF from pypi

* Update version for release 1.4.2rc5 (#2507)

update version for release 1.4.2rc5

* Update for 1.4.2 (#2514)

update for release 1.4.2

* create branch release/1.5.0

* Delete mem cache handler after training is done (#2535)

release mem cache handler after training is done

* Fix bug that auto batch size doesn't consider distributed training (#2533)

* consider distributed training while searching batch size

* update unit test

* reveret gpu memory upper bound

* fix typo

* change allocated to reserved

* add unit test for distributed training

* align with pre-commit

* Apply fix progress hook to release 1.5.0 (#2539)

* Fix hook's ordering issue. AdaptiveRepeatHook changes the runner.max_iters before the ProgressHook

* Change the expression

* Fix typo

* Fix multi-label, h-label issue

* Fix auto_bs issue

* Apply suggestions from code review

Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>

* Reflecting reviews

* Refactor the name of get_data_cfg

* Revert adaptive hook sampler init

* Refactor the function name: get_data_cfg -> get_subset_data_cfg

* Fix unit test errors

* Remove adding AdaptiveRepeatDataHook for autobs

* Remove unused import

* Fix detection and segmentation case in Geti scenario

---------

Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>

* Re introduce adaptive scheduling for training (#2541)

* Re-introduce adaptive patience for training

* Revert unit tests

* Update for release 1.4.3rc1 (#2542)

* Mirror Anomaly ModelAPI changes (#2531)

* Migrate anomaly exportable code to modelAPI (#2432)

* Fix license in PR template

* Migrate to modelAPI

* Remove color conversion in streamer

* Remove reverse_input_channels

* Add float

* Remove test as metadata is no longer used

* Remove metadata from load method

* remove anomalib openvino inferencer

* fix signature

* Support logacy OpenVINO model

* Transform image

* add configs

* Re-introduce adaptive training (#2543)

* Re-introduce adaptive patience for training

* Revert unit tests

* Fix auto input size mismatch in eval & export (#2530)

* Fix auto input size mismatch in eval & export

* Re-enable E2E tests for Issue#2518

* Add input size check in export testing

* Format float numbers in log

* Fix NNCF export shape mismatch

* Fix saliency map issue

* Disable auto input size if tiling enabled

---------

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Update ref. fq number for anomaly e2e2 (#2547)

* Skip e2e det tests by issue2548 (#2550)

* Add skip to chained TC for issue #2548 (#2552)

* Update for release 1.4.3 (#2551)

* Update MAPI for 1.5 release (#2555)

Upgrade MAPI to v 0.1.6 (#2529)

* Upgrade MAPI

* Update exp code demo commit

* Fix MAPI imports

* Update ModelAPI configuration (#2564)

* Update MAPI rt infor for detection

* Upadte export info for cls, det and seg

* Update unit tests

* Disable QAT for SegNexts (#2565)

* Disable NNCF QAT for SegNext

* Del obsolete pot configs

* Move NNCF skip marks to test commands to avoid duplication

* Add Anomaly modelAPI changes to releases/1.4.0 (#2563)

* bug fix for legacy openvino models

* Apply otx anomaly 1.5 changes

* Fix tests

* Fix compression config

* fix modelAPI imports

* update integration tests

* Edit config types

* Update keys in deployed model

---------

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
Co-authored-by: Kim, Sungchul <sungchul.kim@intel.com>

* Fix the CustomNonLinearClsHead when the batch_size is set to 1 (#2571)

Fix bn1d issue

Co-authored-by: sungmanc <sungmanc@intel.com>

* Update ModelAPI configuration (#2564 from 1.4) (#2568)

Update ModelAPI configuration (#2564)

* Update MAPI rt infor for detection

* Upadte export info for cls, det and seg

* Update unit tests

* Update for 1.4.4rc1 (#2572)

* Hotfix DatasetEntity.get_combined_subset function loop (#2577)

Fix get_combined_subset function

* Revert default input size to `Default` due to YOLOX perf regression (#2580)

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Fix for the degradation issue of the classification task (#2585)

* Revert to sync with 1.4.0

* Remove repeat data

* Convert to the RGB value

* Fix color conversion logic

* Fix precommit

* Bump datumaro version to 1.5.1rc3 (#2587)

* Add label ids to anomaly OpenVINO model xml (#2590)

* Add label ids to model xml

---------

* Fix DeiT-Tiny model regression during class incremental training (#2594)

* enable IBloss for DeiT-Tiny

* update changelog

* add docstring

* Add label ids to model xml in release 1.5 (#2591)

Add label ids to model xml

* Fix DeiT-Tiny regression test for release/1.4.0 (#2595)

* Fix DeiT regression test

* update changelog

* temp

* Fix mmcls bug not wrapping model in DataParallel on CPUs (#2601)

Wrap multi-label and h-label classification models by MMDataParallel in case of CPU training.
---------
Signed-off-by: Songki Choi <songki.choi@intel.com>

* Fix h-label loss normalization issue w/ exclusive label group of singe label (#2604)

* Fix h-label loss normalization issue w/ exclusive label group with signle label

* Fix non-linear version

---------
Signed-off-by: Songki Choi <songki.choi@intel.com>

* Boost up Image numpy accessing speed through PIL (#2586)

* boost up numpy accessing speed through PIL

* update CHANGELOG

* resolve precommit error

* resolve precommit error

* add fallback logic with PIL open

* use convert instead of draft

* Add missing import pathlib for cls e2e testing (#2610)

* Fix division by zero in class incremental learning for classification (#2606)

* Add empty label to reproduce zero-division error

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Fix minor typo

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Fix empty label 4 -> 3

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Prevent division by zero

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Update license

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Update CHANGELOG.md

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Fix inefficient sampling

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Revert indexing

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Fix minor typo

Signed-off-by: Songki Choi <songki.choi@intel.com>

---------

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Unify logger usage (#2612)

* unify logger

* align with pre-commit

* unify anomaly logger to otx

* change logger file path

* align with pre-commit

* change logger file path in missing file

* configure logger after ConfigManager is initialized

* configure logger when ConfigManager instance is initialized

* update unit test code

* move config_logger to each cli file

* align with pre-commit

* change part still using mmcv logger

* Fix XAI algorithm for Detection (#2609)

* Impove saliency maps algorithm for Detection

* Remove extra changes

* Update unit tests

* Changes for 1 class

* Fix pre-commit

* Update CHANGELOG

* Tighten dependency constraint only adapting latest patches (#2607)

* tighten dependency constratint only adapting latest patches

* adjust scikit-image version w.r.t python version

* adjust tensorboard version w.r.t python version

* remove version specifier for scikit-image

* Add metadata to optimized model (#2618)

* bug fix for legacy openvino models

* Add metadata to optimized model

* Revert formatting changes

---------

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>

* modify omegaconf version constraint

* [release 1.5.0] Fix XAI algorithm for Detection (#2617)

Update detection XAI algorithm

* Update dependency constraint (#2622)

* Update tpp (#2621)

* Fix h-label bug of missing parent labels in output (#2626)

* Fix h-label bug of missing parent labels in output

* Fix h-label test data label schema

* Update CHANGELOG.md

---------
Signed-off-by: Songki Choi <songki.choi@intel.com>

* Update publish workflow (#2625)

update publish workflow to push whl to internal pypi

* bump datumaro version to ~=1.5.0

* fixed mistake while mergeing back 1.4.4

* modifiy readme

* remove openvino model wrapper class

* remove openvino model wrapper tests

* [release 1.5.0] DeiT: enable tests + add ViTFeatureVectorHook (#2630)

Add ViT feature vector hook

* Fix docs broken link to datatumaro_h-label

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Fix wrong label settings for non-anomaly task ModelAPIs

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Update publish workflow for tag checking (#2632)

* Update e2e tests for XAI Detection (#2634)

Fix e2e XAI ref value

* Disable QAT for newly added models (#2636)

* Update release note and readme (#2637)

* update release note and readme

* remove package upload step on internal publish wf

* update release note and, changelog, and readme

* update version string to 1.6.0dev

* fix datumaro version to 1.6.0rc0

* Mergeback 1.5.0 to develop (#2642)

* Update publish workflow for tag checking (#2632)

* Update e2e tests for XAI Detection (#2634)

* Disable QAT for newly added models (#2636)

* Update release note and readme (#2637)

* remove package upload step on internal publish wf

* update release note and, changelog, and readme

* update version string to 1.6.0dev

---------

Co-authored-by: Galina Zalesskaya <galina.zalesskaya@intel.com>
Co-authored-by: Jaeguk Hyun <jaeguk.hyun@intel.com>

* Revert "Mergeback 1.5.0 to develop" (#2645)

Revert "Mergeback 1.5.0 to develop (#2642)"

This reverts commit 2f67686.

* Add a tool to help conduct experiments (#2651)

* implement run and experiment

* implement experiment result aggregator

* refactor experiment.py

* refactor run.py

* get export model speed

* add var collumn

* refactor experiment.py

* refine a way to update argument in cmd

* refine resource tracker

* support anomaly on research framework

* refine code aggregating exp result

* bugfix

* make other task available

* eval task save avg_time_per_images as result

* Add new argument to track CPU&GPU utilization and memory usage (#2500)

* add argument to track resource usage

* fix bug

* fix a bug in a multi gpu case

* use total cpu usage

* add unit test

* add mark to unit test

* cover edge case

* add pynvml in requirement

* align with pre-commit

* add license comment

* update changelog

* refine argument help

* align with pre-commit

* add version to requirement and raise an error if not supported values are given

* apply new resource tracker format

* refactor run.py

* support optimize in research framework

* cover edge case

* Handle a case where fail cases exist

* make argparse raise error rather than exit if problem exist

* revert tensorboard aggregator

* bugfix

* save failed cases as yaml file

* deal with integer in variables

* add epoch to metric

* use latest log.json file

* align with otx logging method

* move experiment.py from cli to tools

* refactor experiment.py

* merge otx run feature into experiment.py

* move set_arguments_to_cmd definition into experiment.py

* refactor experiment.py

* bugfix

* minor bugfix

* use otx.cli instead of each otx entry

* add feature to parse single workspace

* add comments

* fix bugs

* align with pre-commit

* revert parser argument

* align with pre-commit

* Revert inference batch size to 1 for instance segmentation (#2648)

Signed-off-by: Songki Choi <songki.choi@intel.com>

* Remove unnecessary log while building a model (#2658)

* revert logger in otx/algorithms/detection/adapters/mmdet/utils/builder.py

* revert logger in otx/algorithms/classification/adapters/mmcls/utils/builder.py

* make change more readable

* Fix a minor bug of experiment.py (#2662)

fix bug

* Not check avg_time_per_image during test (#2665)

* ignore avg_time_per_image during test

* do not call stdev when length of array is less than 2

* ignore avg_time_per_image during regerssion test

* Update device selection logic in classificaiton

* Add xpu accelerator

* Tmp patch for anomaly trainer

* Use XPU callback for anomaly training

* Update xpu accelerator

* Fix for anomaly xpu callback

* Fix validation batch logic

* Cleanup, add docstrings

* Refine xpu callback

---------

Signed-off-by: Kim, Vinnam <vinnam.kim@intel.com>
Signed-off-by: Songki Choi <songki.choi@intel.com>
Co-authored-by: Yunchu Lee <yunchu.lee@intel.com>
Co-authored-by: Kim, Sungchul <sungchul.kim@intel.com>
Co-authored-by: Vinnam Kim <vinnam.kim@intel.com>
Co-authored-by: Evgeny Tsykunov <evgeny.tsykunov@intel.com>
Co-authored-by: Songki Choi <songki.choi@intel.com>
Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>
Co-authored-by: Jaeguk Hyun <jaeguk.hyun@intel.com>
Co-authored-by: Sungman Cho <sungman.cho@intel.com>
Co-authored-by: Eugene Liu <eugene.liu@intel.com>
Co-authored-by: Wonju Lee <wonju.lee@intel.com>
Co-authored-by: Dick Ameln <dick.ameln@intel.com>
Co-authored-by: sungchul.kim <sungchul@ikvensx010>
Co-authored-by: GalyaZalesskaya <galina.zalesskaya@intel.com>
Co-authored-by: Harim Kang <harim.kang@intel.com>
Co-authored-by: Ashwin Vaidya <ashwin.vaidya@intel.com>
Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
Co-authored-by: sungmanc <sungmanc@intel.com>
  • Loading branch information
18 people authored Nov 30, 2023
1 parent 03e87f5 commit c1beb05
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@

from .inference import AnomalyInferenceCallback
from .progress import ProgressCallback
from .xpu import XPUCallback

__all__ = ["AnomalyInferenceCallback", "ProgressCallback"]
__all__ = ["AnomalyInferenceCallback", "ProgressCallback", "XPUCallback"]
36 changes: 36 additions & 0 deletions src/otx/algorithms/anomaly/adapters/anomalib/callbacks/xpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Anomaly XPU device callback."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import torch
from pytorch_lightning import Callback


class XPUCallback(Callback):
"""XPU device callback.
Applies IPEX optimization before training, moves data to XPU.
"""

def __init__(self, device_idx=0):
self.device = torch.device(f"xpu:{device_idx}")

def on_fit_start(self, trainer, pl_module):
"""Applies IPEX optimization before training."""
pl_module.to(self.device)
model, optimizer = torch.xpu.optimize(trainer.model, optimizer=trainer.optimizers[0])
trainer.optimizers = [optimizer]
trainer.model = model

def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
"""Moves train batch tensors to XPU."""
for k in batch:
if not isinstance(batch[k], list):
batch[k] = batch[k].to(self.device)

def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
"""Moves validation batch tensors to XPU."""
for k in batch:
if not isinstance(batch[k], list):
batch[k] = batch[k].to(self.device)
5 changes: 5 additions & 0 deletions src/otx/algorithms/anomaly/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from pytorch_lightning import Trainer, seed_everything

from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
from otx.algorithms.anomaly.adapters.anomalib.callbacks.xpu import XPUCallback
from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
from otx.algorithms.common.utils.utils import is_xpu_available
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.model import ModelEntity
from otx.api.entities.train_parameters import TrainParameters
Expand Down Expand Up @@ -88,6 +90,9 @@ def train(
),
]

if is_xpu_available():
callbacks.append(XPUCallback())

self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks)
self.trainer.fit(model=self.model, datamodule=datamodule)

Expand Down

0 comments on commit c1beb05

Please sign in to comment.