diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py
index 3a8ff1fc4..60c8a8413 100644
--- a/sup3r/preprocessing/batch_handlers/factory.py
+++ b/sup3r/preprocessing/batch_handlers/factory.py
@@ -2,11 +2,8 @@
 samplers."""
 
 import logging
-from typing import Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
 
-from sup3r.preprocessing.base import (
-    Container,
-)
 from sup3r.preprocessing.batch_queues.base import SingleBatchQueue
 from sup3r.preprocessing.batch_queues.conditional import (
     QueueMom1,
@@ -27,6 +24,9 @@
     log_args,
 )
 
+if TYPE_CHECKING:
+    from sup3r.preprocessing.base import Container
+
 logger = logging.getLogger(__name__)
 
 
@@ -86,8 +86,8 @@ class BatchHandler(MainQueueClass):
         @log_args
         def __init__(
             self,
-            train_containers: List[Container],
-            val_containers: Optional[List[Container]] = None,
+            train_containers: List['Container'],
+            val_containers: Optional[List['Container']] = None,
             sample_shape: Optional[tuple] = None,
             batch_size: int = 16,
             n_batches: int = 64,
diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py
index 56d0934d1..28494af6c 100644
--- a/sup3r/preprocessing/batch_queues/abstract.py
+++ b/sup3r/preprocessing/batch_queues/abstract.py
@@ -9,16 +9,18 @@
 import time
 from abc import ABC, abstractmethod
 from collections import namedtuple
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import dask
 import numpy as np
 import tensorflow as tf
 
 from sup3r.preprocessing.collections.base import Collection
-from sup3r.preprocessing.samplers import DualSampler, Sampler
 from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer
 
+if TYPE_CHECKING:
+    from sup3r.preprocessing.samplers import DualSampler, Sampler
+
 logger = logging.getLogger(__name__)
 
 
@@ -31,7 +33,7 @@ class AbstractBatchQueue(Collection, ABC):
 
     def __init__(
         self,
-        samplers: Union[List[Sampler], List[DualSampler]],
+        samplers: Union[List['Sampler'], List['DualSampler']],
         batch_size: int = 16,
         n_batches: int = 64,
         s_enhance: int = 1,
diff --git a/sup3r/preprocessing/batch_queues/conditional.py b/sup3r/preprocessing/batch_queues/conditional.py
index e940a511a..43d479b5f 100644
--- a/sup3r/preprocessing/batch_queues/conditional.py
+++ b/sup3r/preprocessing/batch_queues/conditional.py
@@ -3,17 +3,19 @@
 import logging
 from abc import abstractmethod
 from collections import namedtuple
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import numpy as np
 
 from sup3r.models.conditional import Sup3rCondMom
-from sup3r.preprocessing.samplers import DualSampler, Sampler
 from sup3r.preprocessing.utilities import numpy_if_tensor
 
 from .base import SingleBatchQueue
 from .utilities import spatial_simple_enhancing, temporal_simple_enhancing
 
+if TYPE_CHECKING:
+    from sup3r.preprocessing.samplers import DualSampler, Sampler
+
 logger = logging.getLogger(__name__)
 
 
@@ -26,7 +28,7 @@ class ConditionalBatchQueue(SingleBatchQueue):
 
     def __init__(
         self,
-        samplers: Union[List[Sampler], List[DualSampler]],
+        samplers: Union[List['Sampler'], List['DualSampler']],
         time_enhance_mode: str = 'constant',
         lower_models: Optional[Dict[int, Sup3rCondMom]] = None,
         s_padding: int = 0,
diff --git a/sup3r/preprocessing/cachers/base.py b/sup3r/preprocessing/cachers/base.py
index a202a0e0f..4a6fe707c 100644
--- a/sup3r/preprocessing/cachers/base.py
+++ b/sup3r/preprocessing/cachers/base.py
@@ -6,21 +6,25 @@
 import itertools
 import logging
 import os
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Union, TYPE_CHECKING
 import netCDF4 as nc4  # noqa
 import h5py
 import dask
 import dask.array as da
 import numpy as np
 
-from sup3r.preprocessing.accessor import Sup3rX
-from sup3r.preprocessing.base import Container, Sup3rDataset
+from sup3r.preprocessing.base import Container
 from sup3r.preprocessing.names import Dimension
 from sup3r.preprocessing.utilities import _mem_check
 from sup3r.utilities.utilities import safe_serialize
 
 from .utilities import _check_for_cache
 
+if TYPE_CHECKING:
+    from sup3r.preprocessing.accessor import Sup3rX
+    from sup3r.preprocessing.base import Sup3rDataset
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -32,7 +36,7 @@ class Cacher(Container):
 
     def __init__(
        self,
-        data: Union[Sup3rX, Sup3rDataset],
+        data: Union['Sup3rX', 'Sup3rDataset'],
         cache_kwargs: Optional[Dict] = None,
     ):
         """
diff --git a/sup3r/preprocessing/collections/base.py b/sup3r/preprocessing/collections/base.py
index 64a6a4b3f..f9f373bc9 100644
--- a/sup3r/preprocessing/collections/base.py
+++ b/sup3r/preprocessing/collections/base.py
@@ -7,13 +7,15 @@
 integrated into xarray (in progress as of 8/8/2024)
 """
 
-from typing import List, Union
+from typing import TYPE_CHECKING, List, Union
 
 import numpy as np
 
 from sup3r.preprocessing.base import Container
-from sup3r.preprocessing.samplers.base import Sampler
-from sup3r.preprocessing.samplers.dual import DualSampler
+
+if TYPE_CHECKING:
+    from sup3r.preprocessing.samplers.base import Sampler
+    from sup3r.preprocessing.samplers.dual import DualSampler
 
 
 class Collection(Container):
@@ -26,9 +28,9 @@ class Collection(Container):
     def __init__(
         self,
         containers: Union[
-            List[Container],
-            List[Sampler],
-            List[DualSampler],
+            List['Container'],
+            List['Sampler'],
+            List['DualSampler'],
         ],
     ):
         super().__init__()
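Every file in this diff applies the same pattern: imports that are needed only for type annotations move behind typing.TYPE_CHECKING, and the annotations reference them as quoted forward references, so those modules are never imported at runtime (presumably to break import cycles or reduce import-time cost; the diff itself does not state the motivation). A minimal sketch of the pattern follows; the package path mypkg.samplers and the class names are hypothetical placeholders, not the real sup3r containers.

# Sketch of the TYPE_CHECKING + forward-reference pattern used above.
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers (e.g. mypy, pyright); this block
    # is skipped at runtime, so it cannot trigger a circular import.
    from mypkg.samplers import Sampler  # hypothetical module


class BatchQueue:
    def __init__(
        self,
        samplers: List['Sampler'],
        batch_size: Optional[int] = None,
    ):
        # 'Sampler' is stored as a plain string annotation and is only
        # resolved if something introspects the type hints later.
        self.samplers = samplers
        self.batch_size = batch_size

At runtime, instantiating BatchQueue works even though mypkg is never importable; a type checker still resolves the quoted 'Sampler' to the class imported under TYPE_CHECKING.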