Skip to content

Commit

Permalink
type checking additions
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Aug 22, 2024
1 parent 9453b25 commit 9058109
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 22 deletions.
12 changes: 6 additions & 6 deletions sup3r/preprocessing/batch_handlers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
samplers."""

import logging
from typing import Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union

from sup3r.preprocessing.base import (
Container,
)
from sup3r.preprocessing.batch_queues.base import SingleBatchQueue
from sup3r.preprocessing.batch_queues.conditional import (
QueueMom1,
Expand All @@ -27,6 +24,9 @@
log_args,
)

if TYPE_CHECKING:
from sup3r.preprocessing.base import Container

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -86,8 +86,8 @@ class BatchHandler(MainQueueClass):
@log_args
def __init__(
self,
train_containers: List[Container],
val_containers: Optional[List[Container]] = None,
train_containers: List['Container'],
val_containers: Optional[List['Container']] = None,
sample_shape: Optional[tuple] = None,
batch_size: int = 16,
n_batches: int = 64,
Expand Down
8 changes: 5 additions & 3 deletions sup3r/preprocessing/batch_queues/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@
import time
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

import dask
import numpy as np
import tensorflow as tf

from sup3r.preprocessing.collections.base import Collection
from sup3r.preprocessing.samplers import DualSampler, Sampler
from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer

if TYPE_CHECKING:
from sup3r.preprocessing.samplers import DualSampler, Sampler

logger = logging.getLogger(__name__)


Expand All @@ -31,7 +33,7 @@ class AbstractBatchQueue(Collection, ABC):

def __init__(
self,
samplers: Union[List[Sampler], List[DualSampler]],
samplers: Union[List['Sampler'], List['DualSampler']],
batch_size: int = 16,
n_batches: int = 64,
s_enhance: int = 1,
Expand Down
8 changes: 5 additions & 3 deletions sup3r/preprocessing/batch_queues/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
import logging
from abc import abstractmethod
from collections import namedtuple
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np

from sup3r.models.conditional import Sup3rCondMom
from sup3r.preprocessing.samplers import DualSampler, Sampler
from sup3r.preprocessing.utilities import numpy_if_tensor

from .base import SingleBatchQueue
from .utilities import spatial_simple_enhancing, temporal_simple_enhancing

if TYPE_CHECKING:
from sup3r.preprocessing.samplers import DualSampler, Sampler

logger = logging.getLogger(__name__)


Expand All @@ -26,7 +28,7 @@ class ConditionalBatchQueue(SingleBatchQueue):

def __init__(
self,
samplers: Union[List[Sampler], List[DualSampler]],
samplers: Union[List['Sampler'], List['DualSampler']],
time_enhance_mode: str = 'constant',
lower_models: Optional[Dict[int, Sup3rCondMom]] = None,
s_padding: int = 0,
Expand Down
12 changes: 8 additions & 4 deletions sup3r/preprocessing/cachers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,25 @@
import itertools
import logging
import os
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, TYPE_CHECKING
import netCDF4 as nc4 # noqa
import h5py
import dask
import dask.array as da
import numpy as np

from sup3r.preprocessing.accessor import Sup3rX
from sup3r.preprocessing.base import Container, Sup3rDataset
from sup3r.preprocessing.base import Container
from sup3r.preprocessing.names import Dimension
from sup3r.preprocessing.utilities import _mem_check
from sup3r.utilities.utilities import safe_serialize

from .utilities import _check_for_cache

if TYPE_CHECKING:
from sup3r.preprocessing.accessor import Sup3rX
from sup3r.preprocessing.base import Sup3rDataset


logger = logging.getLogger(__name__)


Expand All @@ -32,7 +36,7 @@ class Cacher(Container):

def __init__(
self,
data: Union[Sup3rX, Sup3rDataset],
data: Union['Sup3rX', 'Sup3rDataset'],
cache_kwargs: Optional[Dict] = None,
):
"""
Expand Down
14 changes: 8 additions & 6 deletions sup3r/preprocessing/collections/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
integrated into xarray (in progress as of 8/8/2024)
"""

from typing import List, Union
from typing import TYPE_CHECKING, List, Union

import numpy as np

from sup3r.preprocessing.base import Container
from sup3r.preprocessing.samplers.base import Sampler
from sup3r.preprocessing.samplers.dual import DualSampler

if TYPE_CHECKING:
from sup3r.preprocessing.samplers.base import Sampler
from sup3r.preprocessing.samplers.dual import DualSampler


class Collection(Container):
Expand All @@ -26,9 +28,9 @@ class Collection(Container):
def __init__(
self,
containers: Union[
List[Container],
List[Sampler],
List[DualSampler],
List['Container'],
List['Sampler'],
List['DualSampler'],
],
):
super().__init__()
Expand Down

0 comments on commit 9058109

Please sign in to comment.