
Support latest version of BaaL (1.5.2) and add necessary utilities #1315

Merged: 13 commits, merged on May 6, 2022
6 changes: 6 additions & 0 deletions flash/core/utilities/imports.py
@@ -60,6 +60,11 @@ def _compare_version(package: str, op, version) -> bool:
return False
try:
pkg_version = Version(pkg.__version__)
except AttributeError:
# In case the module doesn't have a __version__ attribute (e.g. baal)
import pkg_resources

pkg_version = Version(pkg_resources.get_distribution("baal").version)
except TypeError:
# this is mocked by sphinx, so it shall return True to generate all summaries
return True
@@ -128,6 +133,7 @@ class Image:
_PANDAS_GREATER_EQUAL_1_3_0 = _compare_version("pandas", operator.ge, "1.3.0")
_ICEVISION_GREATER_EQUAL_0_11_0 = _compare_version("icevision", operator.ge, "0.11.0")
_TM_GREATER_EQUAL_0_7_0 = _compare_version("torchmetrics", operator.ge, "0.7.0")
_BAAL_GREATER_EQUAL_1_5_2 = _compare_version("baal", operator.ge, "1.5.2")

_TEXT_AVAILABLE = all(
[
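For context, here is a minimal standalone sketch of the version-check fallback introduced above. It mirrors the diff but is not the exact Flash implementation; in particular, the real helper hardcodes "baal" in the `pkg_resources` fallback, whereas this sketch generalizes it to the `package` argument.

```python
import operator

import pkg_resources
from packaging.version import Version


def _compare_version(package: str, op, version: str) -> bool:
    """Return True if the installed `package` satisfies `op` against `version`.

    A sketch of the logic in the diff above: fall back to pkg_resources when
    the imported module has no __version__ attribute (e.g. baal).
    """
    try:
        pkg = __import__(package)
    except ImportError:
        return False
    try:
        pkg_version = Version(pkg.__version__)
    except AttributeError:
        # Some packages (e.g. baal) do not expose __version__ on the module.
        pkg_version = Version(pkg_resources.get_distribution(package).version)
    except TypeError:
        # __version__ is mocked by sphinx; return True so all summaries are generated.
        return True
    return op(pkg_version, Version(version))


# Usage mirroring the flag added in this PR:
# _BAAL_GREATER_EQUAL_1_5_2 = _compare_version("baal", operator.ge, "1.5.2")
```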
14 changes: 12 additions & 2 deletions flash/image/classification/integrations/baal/dropout.py
@@ -15,10 +15,20 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException

import flash
from flash.core.utilities.imports import _BAAL_AVAILABLE
from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2

if _BAAL_AVAILABLE:
from baal.bayesian.dropout import _patch_dropout_layers
# The _patch_dropout_layers function was replaced with the replace_layers_in_module
# helper in v1.5.2 (see https://github.com/ElementAI/baal/pull/194 for details)
if _BAAL_GREATER_EQUAL_1_5_2:
from baal.bayesian.common import replace_layers_in_module
from baal.bayesian.consistent_dropout import _consistent_dropout_mapping_fn

def _patch_dropout_layers(module: torch.nn.Module):
return replace_layers_in_module(module, _consistent_dropout_mapping_fn)

else:
from baal.bayesian.consistent_dropout import _patch_dropout_layers


class InferenceMCDropoutTask(flash.Task):
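A hypothetical usage sketch of the shim defined above (not part of the PR): it assumes `_patch_dropout_layers` is in scope and, as in earlier baal releases, returns a truthy value when at least one layer was replaced.

```python
import torch
from torch import nn

# Build a toy model containing a dropout layer to patch.
model = nn.Sequential(nn.Linear(16, 32), nn.Dropout(0.5), nn.Linear(32, 4))

# Swap nn.Dropout layers for baal's consistent-dropout variants, which
# InferenceMCDropoutTask relies on for Monte Carlo sampling.
changed = _patch_dropout_layers(model)
if not changed:
    raise RuntimeError("The model contains no dropout layers to patch.")

model.eval()
out = model(torch.randn(8, 16))
```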
1 change: 0 additions & 1 deletion requirements/datatype_image_extras.txt
@@ -8,7 +8,6 @@ effdet
albumentations
learn2learn
structlog==21.1.0 # remove when baal resolved its dependency.
baal
fastface
fairscale

2 changes: 2 additions & 0 deletions requirements/datatype_image_extras_baal.txt
@@ -0,0 +1,2 @@
# This is a separate file, as the baal integration conflicts with the vissl installation
baal>=1.5.2
4 changes: 3 additions & 1 deletion tests/image/classification/test_active_learning.py
@@ -22,7 +22,7 @@
from torch.utils.data import SequentialSampler

import flash
from flash.core.utilities.imports import _BAAL_AVAILABLE, _IMAGE_AVAILABLE
from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2, _IMAGE_AVAILABLE
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
from tests.image.classification.test_data import _rand_image
@@ -110,6 +110,8 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s
assert len(active_learning_dm._dataset) == 20
assert active_learning_loop.progress.total.completed == 3
labelled = active_learning_loop.state_dict()["state_dict"]["datamodule_state_dict"]["labelled"]
if _BAAL_GREATER_EQUAL_1_5_2:
labelled = labelled > 0
assert isinstance(labelled, np.ndarray)

# Check that we iterate over the actual pool and that shuffle is disabled.
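An illustrative note on the `labelled = labelled > 0` conversion in the test above (an assumption, not stated in the PR): with baal >= 1.5.2 the stored array appears to hold the labelling step per sample (0 meaning "not yet labelled") rather than booleans, so comparing against zero recovers the boolean mask the rest of the test expects.

```python
import numpy as np

# Hypothetical values as they might appear in the datamodule state dict:
labelled = np.array([0, 1, 1, 0, 2])  # per-sample labelling step, 0 = unlabelled

mask = labelled > 0  # -> array([False, True, True, False, True])
assert mask.dtype == np.bool_ and mask.sum() == 3
```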