Skip to content

Commit

Permalink
Merge pull request #74 from pycroscopy/cls_reg
Browse files Browse the repository at this point in the history
Add 'Classifier' class to atomai models
  • Loading branch information
ziatdinovmax authored Apr 3, 2023
2 parents e781591 + d736a94 commit 10d98c3
Show file tree
Hide file tree
Showing 14 changed files with 562 additions and 46 deletions.
4 changes: 3 additions & 1 deletion atomai/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .segmentor import Segmentor
from .imspec import ImSpec
from .regressor import Regressor
from .classifier import Classifier
from .dgm import BaseVAE, VAE, rVAE, jVAE, jrVAE
from .dklgp import dklGPR
from .loaders import load_model, load_ensemble, load_pretrained_model

__all__ = ["Segmentor", "ImSpec", "BaseVAE", "VAE", "rVAE",
"jVAE", "jrVAE", "load_model", "load_ensemble",
"load_pretrained_model", "dklGPR", "Regressor"]
"load_pretrained_model", "dklGPR", "Regressor",
"Classifier"]
134 changes: 134 additions & 0 deletions atomai/models/classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from typing import Type, Union, Optional
import torch
import numpy as np
from ..trainers import clsTrainer
from ..predictors import clsPredictor
from ..transforms import reg_augmentor


class Classifier(clsTrainer):
"""
Model for classification tasks
Args:
model:
The backbone regressor model (defaults to 'mobilenet')
nb_classes:
Number of classes
Example:
>>> # Initialize and train a classification model
>>> model = aoi.models.Classifier(nb_classes=4)
>>> model.fit(train_images, train_targets, test_images, test_targets,
>>> full_epoch=True, training_cycles=30, swa=True)
>>> # Make a prediction with the trained model
>>> prediction = model.predict(imgs_new, norm=True)
"""
def __init__(self,
model: str = 'mobilenet',
nb_classes: int = None,
**kwargs) -> None:
if nb_classes is None:
raise AssertionError(
"You must specify a number of classes (nb_classes) for your classification model")
super(Classifier, self).__init__(nb_classes, model, **kwargs)

def fit(self,
X_train: Union[np.ndarray, torch.Tensor],
y_train: Union[np.ndarray, torch.Tensor],
X_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
y_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
loss: str = 'ce',
optimizer: Optional[Type[torch.optim.Optimizer]] = None,
training_cycles: int = 1000,
batch_size: int = 64,
compute_accuracy: bool = False,
full_epoch: bool = False,
swa: bool = False,
perturb_weights: bool = False,
**kwargs):
"""
Compiles a trainer and performs model training
Args:
X_train:
4D numpy array with image data (n_samples x 1 x height x width).
It is also possible to pass 3D by ignoring the channel dim,
which will be added automatically.
y_train:
1D numpy array of integers with target classes
X_test:
4D numpy array with image data (n_samples x 1 x height x width).
It is also possible to pass 3D by ignoring the channel dim,
which will be added automatically.
y_test:
1D numpy array of integers with target classes.
loss:
Loss function (defaults to 'ce')
optimizer:
weights optimizer (defaults to Adam optimizer with lr=1e-3)
training_cycles: Number of training 'epochs'.
If full_epoch argument is set to False, 1 epoch == 1 mini-batch.
Otherwise, each cycle corresponds to all mini-batches of data
passing through a NN.
batch_size:
Size of training and test mini-batches
full_epoch:
If True, passes all mini-batches of training/test data
at each training cycle and computes the average loss. If False,
passes a single (randomly chosen) mini-batch at each cycle.
swa:
Saves the recent stochastic weights and averages
them at the end of training
perturb_weights:
Time-dependent weight perturbation, :math:`w\\leftarrow w + a / (1 + e)^\\gamma`,
where parameters *a* and *gamma* can be passed as a dictionary,
together with parameter *e_p* determining every *n*-th epoch at
which a perturbation is applied
**print_loss (int):
Prints loss every *n*-th epoch
**filename (str):
Filename for model weights
(appended with "_test_weights_best.pt" and "_weights_final.pt")
**plot_training_history (bool):
Plots training and test curves vs. training cycles
at the end of training
**kwargs:
One can also pass kwargs for utils.datatransform class
to perform the augmentation "on-the-fly"
(e.g. gauss_noise=[20, 60], etc.)
"""
self.compile_trainer(
(X_train, y_train, X_test, y_test),
loss, optimizer, training_cycles, batch_size,
compute_accuracy, full_epoch, swa, perturb_weights,
**kwargs)

self.augment_fn = reg_augmentor(**kwargs) # use the regression model's augmentor
_ = self.run()

def predict(self,
data: np.ndarray,
**kwargs) -> np.ndarray:
"""
Apply (trained model) to new data
Args:
data: Input image or batch of images
**num_batches (int): number of batches (Default: 10)
**norm (bool): Normalize data to (0, 1) during pre-processing
**verbose (bool): verbosity (Default: True)
"""
use_gpu = self.device == 'cuda'
nn_output = clsPredictor(
self.net, self.nb_classes, use_gpu,
**kwargs).run(data, **kwargs)
return nn_output

def load_weights(self, filepath: str) -> None:
"""
Loads saved weights dictionary
"""
weight_dict = torch.load(filepath, map_location=self.device)
self.net.load_state_dict(weight_dict)
27 changes: 27 additions & 0 deletions atomai/models/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .segmentor import Segmentor
from .imspec import ImSpec
from .regressor import Regressor
from .classifier import Classifier
from .dgm import BaseVAE, VAE, rVAE, jrVAE, jVAE
from ..utils import average_weights

Expand Down Expand Up @@ -44,6 +45,8 @@ def load_model(filepath: str) -> Union[Segmentor, Union[VAE, rVAE, jrVAE, jVAE],
model = load_imspec_model(loaded_dict)
elif model_type == "reg":
model = load_reg_model(loaded_dict)
elif model_type == "cls":
model = load_cls_model(loaded_dict)
elif model_type == "vae":
model = load_vae_model(loaded_dict)
else:
Expand Down Expand Up @@ -130,6 +133,30 @@ def load_reg_model(meta_dict: Dict[str, torch.Tensor]) -> Type[Regressor]:
return model


def load_cls_model(meta_dict: Dict[str, torch.Tensor]) -> Type[Regressor]:
"""
Loads trained AtomAI classification models
Args:
meta_dict (str):
dictionary with trained weights and key information
about model's structure
Returns:
Classifier object with NN in evaluation state
"""
backbone = meta_dict.pop("backbone")
nb_classes = meta_dict.pop("nb_classes")
weights = meta_dict.pop("weights")
model = Classifier(backbone, nb_classes, **meta_dict)
model.net.load_state_dict(weights)
if "optimizer" in meta_dict.keys():
optimizer = meta_dict.pop("optimizer")
model.optimizer = optimizer
model.net.eval()
return model


def load_vae_model(meta_dict: Dict[str, torch.Tensor]) -> Type[BaseVAE]:
"""
Loads trained AtomAI ImSpec models
Expand Down
2 changes: 1 addition & 1 deletion atomai/models/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Regressor(RegTrainer):
Args:
model:
The ackbone regressor model (defaults to 'mobilenet')
The backbone regressor model (defaults to 'mobilenet')
out_dim:
Output dimensions (Defaults to 1)
Expand Down
4 changes: 2 additions & 2 deletions atomai/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
rDecoderNet, init_imspec_model, init_VAE_nets)
from .fcnn import Unet, dilnet, SegResNet, ResHedNet, init_fcnn_model
from .gp import fcFeatureExtractor, GPRegressionModel
from .reg_cls import RegressorNet, ClassifierNet, init_reg_model
from .reg_cls import RegressorNet, ClassifierNet, init_reg_model, init_cls_model

__all__ = ['ConvBlock', 'ResBlock', 'ResModule', 'UpsampleBlock', 'DilatedBlock',
'init_fcnn_model', 'SegResNet', 'Unet', 'ResHedNet', 'dilnet', 'fcEncoderNet',
'fcDecoderNet', 'convEncoderNet', 'convDecoderNet', 'rDecoderNet',
'coord_latent', 'load_model', 'load_ensemble', 'init_imspec_model',
'init_VAE_nets', 'SignalEncoder', 'SignalDecoder', 'SignalED',
'fcFeatureExtractor', 'GPRegressionModel', 'CustomBackbone', 'RegressorNet',
'ClassifierNet']
'ClassifierNet', 'init_reg_model', 'init_cls_model']
12 changes: 12 additions & 0 deletions atomai/nets/reg_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,15 @@ def init_reg_model(out_dim, backbone_type, input_channels=1, **kwargs):
"out_dim": out_dim
}
return net, meta_state_dict


def init_cls_model(num_classes, backbone_type, input_channels=1, **kwargs):
"""Initializes a regression model with a specified backbone type"""
net = ClassifierNet(input_channels, num_classes, backbone_type)
meta_state_dict = {
"model_type": "cls",
"backbone": backbone_type,
"in_channels": input_channels,
"nb_classes": num_classes
}
return net, meta_state_dict
5 changes: 3 additions & 2 deletions atomai/predictors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .predictor import BasePredictor, SegPredictor, ImSpecPredictor, RegPredictor, Locator
from .epredictor import EnsemblePredictor, ensemble_locate
from .predictor import (BasePredictor, ImSpecPredictor, Locator, RegPredictor,
SegPredictor, clsPredictor)

__all__ = ["BasePredictor", "SegPredictor", "ImSpecPredictor", "RegPredictor",
"EnsemblePredictor", "ensemble_locate", "Locator"]
"clsPredictor", "EnsemblePredictor", "ensemble_locate", "Locator"]
49 changes: 49 additions & 0 deletions atomai/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,55 @@ def run(self,
+ str(np.around(time.time() - start_time, decimals=4))
+ ' seconds')
return prediction


class clsPredictor(RegPredictor):
"""
Prediction with a trained classifier
Args:
trained_model:
Pre-trained neural network
nb_classes:
number of classes in a classification scheme
use_gpu:
Use GPU accelration for prediction
verbose:
Verbosity
Example:
>>> # Make predictions with trained regression model
>>> nb_classes = 10
>>> prediction = clsPredictor(trained_model, nb_classes).run(data)
"""
def __init__(self,
trained_model: Type[torch.nn.Module],
nb_classes: int,
use_gpu: bool = False,
**kwargs: str) -> None:
"""
Initialize predictor
"""
super(clsPredictor, self).__init__(trained_model, nb_classes, use_gpu, **kwargs)

def predict(self,
image_data: np.ndarray,
**kwargs: int) -> np.ndarray:
"""
Categorizes an input image or a batch of input images
Args:
image_data (numpy array): Input image or batch of images
**num_batches (int): number of batches (Default: 10)
**norm (bool): Normalize data to (0, 1) during pre-processing
"""
num_batches = kwargs.get("num_batches", 10)
image_data = self.preprocess(image_data, kwargs.get("norm", True))
output = self.batch_predict(
image_data, (len(image_data), self.output_dim), num_batches)
output = torch.argmax(output, 1)
return output.squeeze().numpy()


class Locator:
Expand Down
4 changes: 2 additions & 2 deletions atomai/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .trainer import SegTrainer, ImSpecTrainer, RegTrainer, BaseTrainer
from .trainer import SegTrainer, ImSpecTrainer, RegTrainer, clsTrainer, BaseTrainer
from .etrainer import BaseEnsembleTrainer, EnsembleTrainer
from .vitrainer import viBaseTrainer
from .gptrainer import dklGPTrainer

__all__ = ["SegTrainer", "ImSpecTrainer", "BaseTrainer", "BaseEnsembleTrainer",
"EnsembleTrainer", "viBaseTrainer", "dklGPTrainer"]
"EnsembleTrainer", "viBaseTrainer", "dklGPTrainer", "RegTrainer", "clsTrainer"]
Loading

0 comments on commit 10d98c3

Please sign in to comment.