Skip to content

Commit

Permalink
Format with isort
Browse files Browse the repository at this point in the history
  • Loading branch information
RonMcKay committed Sep 29, 2022
1 parent a83924e commit cc615c5
Show file tree
Hide file tree
Showing 39 changed files with 112 additions and 74 deletions.
5 changes: 3 additions & 2 deletions bnn_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from copy import deepcopy

from sacred import Ingredient
import torch.nn as nn

from bnn_models.lenet import BLeNet
from bnn_models.resnet import BClassifier
from datasets import datasets
from sacred import Ingredient
import torch.nn as nn
from utils import load_config_from_checkpoint

bnn_models = Ingredient("bnn_model", ingredients=(datasets,))
Expand Down
3 changes: 2 additions & 1 deletion bnn_models/lenet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Tuple

import bnn
from cls_models.base import BaseClassifier
import torch
import torch.nn as nn

from cls_models.base import BaseClassifier


class BLeNet(BaseClassifier):
def __init__(
Expand Down
3 changes: 2 additions & 1 deletion bnn_models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from math import ceil, floor, log2

import bnn
from cls_models.base import BaseClassifier
import torch
import torch.nn as nn

from cls_models.base import BaseClassifier
from utils import init_weights

N_FEATUREMAPS = 32
Expand Down
3 changes: 2 additions & 1 deletion cae_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datasets import datasets
from sacred import Ingredient
import torch.nn as nn

from datasets import datasets
from utils import load_config_from_checkpoint

from .medium import MediumCAE
Expand Down
1 change: 1 addition & 0 deletions cae_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch.nn.functional as tf
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

from utils import save_sample_images


Expand Down
1 change: 1 addition & 0 deletions cae_models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from torch import Tensor
import torch.nn as nn

from utils import init_weights

from .base import BaseCAE
Expand Down
3 changes: 2 additions & 1 deletion cls_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from copy import deepcopy
import re

from sacred import Ingredient

from bnn_models.lenet import BLeNet
from bnn_models.resnet import BClassifier
from cls_models.lenet import LeNet
Expand All @@ -9,7 +11,6 @@
from cls_models.small import SmallClassifier
from cls_models.toy import ToyClassifier
from datasets import datasets
from sacred import Ingredient
from utils import load_config_from_checkpoint

from .base import BaseClassifier
Expand Down
1 change: 1 addition & 0 deletions cls_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch import Tensor, optim
import torch.nn.functional as tf
from torch.utils.data import DataLoader

from utils import entropy

from .utils import set_model_to_mode
Expand Down
3 changes: 2 additions & 1 deletion cls_models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from math import floor, log2

from gan_models.resnet import ResidualBlock
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.conv import _ConvNd

from gan_models.resnet import ResidualBlock

from .base import BaseClassifier

N_FEATUREMAPS = 32
Expand Down
3 changes: 2 additions & 1 deletion cls_models/toy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datasets.toy import ToyDataset
import torch.nn as nn

from datasets.toy import ToyDataset

from .base import BaseClassifier


Expand Down
9 changes: 5 additions & 4 deletions confident_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
import logging
from typing import Any, Dict

from cls_models import set_model_to_mode
from cls_models.base import BaseClassifier
from datasets import load_data
from eval_ood_detection import eval_classifier
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as tf

from cls_models import set_model_to_mode
from cls_models.base import BaseClassifier
from datasets import load_data
from eval_ood_detection import eval_classifier


class ConfidentClassifier(LightningModule):
def __init__(
Expand Down
5 changes: 3 additions & 2 deletions cvae_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from sacred import Ingredient
import torch.nn as nn

from cvae_models.medium import MediumCVAE
from cvae_models.small import SmallCVAE
from datasets import datasets
from sacred import Ingredient
import torch.nn as nn
from utils import load_config_from_checkpoint

cvae_models = Ingredient("cvae_model", ingredients=(datasets,))
Expand Down
15 changes: 8 additions & 7 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
import logging
from typing import Dict, Optional, Tuple

import numpy as np
from sacred import Ingredient
import torch
from torch.utils.data import Dataset, WeightedRandomSampler
import torchvision.transforms as trans
import torchvision.transforms.functional as ftrans
from typing_extensions import TypedDict

from datasets.celeba import CelebA
from datasets.cifar10 import CIFAR10
from datasets.cifar100 import CIFAR100
Expand All @@ -18,13 +26,6 @@
from datasets.svhn import SVHN
from datasets.tinyimagenet import TinyImageNet
from datasets.toy import ToyDataset, ToyDataset2, ToyDataset3, ToyDataset4, ToyDataset5
import numpy as np
from sacred import Ingredient
import torch
from torch.utils.data import Dataset, WeightedRandomSampler
import torchvision.transforms as trans
import torchvision.transforms.functional as ftrans
from typing_extensions import TypedDict
from utils import IncompatibleRange, get_range

datasets = Ingredient("dataset")
Expand Down
3 changes: 2 additions & 1 deletion datasets/clawa2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from typing import Callable, Optional

from datasets.datahandlers.cl import AwA2
import numpy as np
import torch
from torch.utils.data import Dataset, random_split

from datasets.datahandlers.cl import AwA2

EVAL_RATIO = 0.2


Expand Down
11 changes: 6 additions & 5 deletions datasets/datahandlers/cl/awa2.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# PyTorch Dataloader based on 'https://github.com/dfan/awa2-zero-shot-learning'

import numpy as np
import torch
import logging
import os
from os.path import exists, join, splitext, isfile
from os.path import exists, isfile, join, splitext
from typing import Callable, Optional, Sequence, Union

from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import logging
from typing import Callable, Optional, Union, Sequence


class AwA2(Dataset):
Expand Down
4 changes: 2 additions & 2 deletions datasets/datahandlers/cl/cub.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def __init__(
Args:
root: Root path of the dataset.
split: Split to use. Valid options: ('train', 'test', 'all')
minimum_attribute_certainty: Minimum certainty of the annotated attribute as provided by the human
annotator. Valid options: (1, 2, 3, 4)
minimum_attribute_certainty: Minimum certainty of the annotated attribute
as provided by the human annotator. Valid options: (1, 2, 3, 4)
transform: Image transforms.
target_transform: Target transforms.
target_type: Target types to return. Valid options: ('attr', 'class')
Expand Down
3 changes: 2 additions & 1 deletion datasets/tinyimagenet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
from typing import Callable, Optional

from datasets.datahandlers.cl import TinyImageNet as myTinyImageNet
import torch
from torch.utils.data import Dataset, random_split

from datasets.datahandlers.cl import TinyImageNet as myTinyImageNet

EVAL_RATIO = 0.2


Expand Down
2 changes: 1 addition & 1 deletion datasets/toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def __len__(self) -> int:


if __name__ == "__main__":
import matplotlib
import matplotlib # noqa: F401

# matplotlib.use("Qt5Agg")
import matplotlib.pyplot as plt
Expand Down
7 changes: 4 additions & 3 deletions eval_ood_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import os
from os.path import abspath, exists, expanduser, join

from cls_models import cls_models, load_cls_model, set_model_to_mode
from cls_models.base import BaseClassifier
from datasets import datasets, load_data
from eval.binary import aupr, auroc, ece, fprxtpr
from logging_utils import log_config
import numpy as np
Expand All @@ -15,6 +12,10 @@
from tabulate import tabulate
import torch
from torch.utils.data import ConcatDataset, DataLoader

from cls_models import cls_models, load_cls_model, set_model_to_mode
from cls_models.base import BaseClassifier
from datasets import datasets, load_data
from utils import extract_exp_id_from_path, format_int_list, get_range, init_experiment

ex = Experiment("evaluate OOD detection", ingredients=[datasets, cls_models])
Expand Down
5 changes: 3 additions & 2 deletions eval_ood_detection_deep_ensembles.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os
from os.path import abspath, exists, expanduser, join

from cls_models import cls_models, load_cls_model
from datasets import datasets, load_data
from eval.binary import aupr, auroc, ece, fprxtpr
from logging_utils import log_config
import numpy as np
Expand All @@ -12,6 +10,9 @@
from tabulate import tabulate
import torch
from torch.utils.data import ConcatDataset, DataLoader

from cls_models import cls_models, load_cls_model
from datasets import datasets, load_data
from utils import (
entropy,
extract_exp_id_from_path,
Expand Down
3 changes: 2 additions & 1 deletion gan_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from copy import deepcopy
from typing import Tuple

from datasets import datasets
from sacred import Ingredient
import torch.nn as nn

from datasets import datasets
from utils import load_config_from_checkpoint

from .dcgan import Discriminator as DCGANDiscriminator
Expand Down
1 change: 1 addition & 0 deletions gan_models/dcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from torch.distributions.normal import Normal
import torch.nn as nn

from utils import init_weights

LATENT_DIM = 128
Expand Down
1 change: 1 addition & 0 deletions gan_models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from torch.distributions.uniform import Uniform
import torch.nn as nn

from utils import init_weights

LATENT_DIM = 128
Expand Down
1 change: 1 addition & 0 deletions gan_models/small.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from torch.distributions.uniform import Uniform
import torch.nn as nn

from utils import init_weights

LATENT_DIM = 128
Expand Down
3 changes: 2 additions & 1 deletion gan_models/toy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from datasets.toy import ToyDataset
from pytorch_lightning import LightningModule
import torch
from torch.distributions.uniform import Uniform
import torch.nn as nn

from datasets.toy import ToyDataset
from utils import init_gan_weights


Expand Down
7 changes: 4 additions & 3 deletions gen.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from copy import deepcopy
from typing import Any, Dict

from cls_models.base import BaseClassifier, set_model_to_mode
from datasets import load_data
from eval_ood_detection import eval_classifier
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.accelerators.registry import AcceleratorRegistry
import torch
from torch import Tensor, optim
import torch.nn as nn
import torch.nn.functional as tf

from cls_models.base import BaseClassifier, set_model_to_mode
from datasets import load_data
from eval_ood_detection import eval_classifier


class GEN(LightningModule):
def __init__(
Expand Down
1 change: 1 addition & 0 deletions meta_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sacred import Ingredient
import torch.nn as nn

from utils import load_config_from_checkpoint

from .fc import MetaClassifier
Expand Down
1 change: 1 addition & 0 deletions meta_models/fc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch.nn as nn

from utils import init_weights


Expand Down
5 changes: 3 additions & 2 deletions train_cae.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

matplotlib.use("Agg")

from cae_models import cae_models, load_cae_model
from datasets import datasets, load_data
from logging_utils import log_config
from logging_utils.lightning_sacred import SacredLogger

from cae_models import cae_models, load_cae_model
from datasets import datasets, load_data
from utils import TimeEstimator, get_experiment_folder, init_experiment

ex = Experiment("train_cae", ingredients=[datasets, cae_models])
Expand Down
5 changes: 3 additions & 2 deletions train_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
from os.path import exists
import shutil

from cls_models import cls_models, load_cls_model
from datasets import datasets, load_data
from logging_utils import log_config
from logging_utils.lightning_sacred import SacredLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred import Experiment
import torch
from torch.utils.data import DataLoader

from cls_models import cls_models, load_cls_model
from datasets import datasets, load_data
from utils import TimeEstimator, get_experiment_folder, init_experiment

ex = Experiment("train_classifier", ingredients=(cls_models, datasets))
Expand Down
Loading

0 comments on commit cc615c5

Please sign in to comment.