diff --git a/flash/image/embedding/strategies/default.py b/flash/image/embedding/strategies/default.py index 357223cd828..32650f3695d 100644 --- a/flash/image/embedding/strategies/default.py +++ b/flash/image/embedding/strategies/default.py @@ -28,11 +28,10 @@ class DefaultAdapter(Adapter): required_extras: str = "image" - def __init__(self, backbone: torch.nn.Module, head: torch.nn.Module): + def __init__(self, backbone: torch.nn.Module): super().__init__() self.backbone = backbone - self.head = head @classmethod @catch_url_error @@ -41,10 +40,9 @@ def from_task( *args, task: AdapterTask, backbone: torch.nn.Module, - head: torch.nn.Module, **kwargs, ) -> Adapter: - adapter = cls(backbone, head) + adapter = cls(backbone) adapter.__dict__["_task"] = task return adapter