Skip to content

Commit

Permalink
Merge pull request #28 from paucablop/add-scale-by-index
Browse files Browse the repository at this point in the history
Add scale by index
  • Loading branch information
paucablop authored Apr 17, 2023
2 parents 1b6fd61 + 13a531f commit 70fab9f
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 9 deletions.
5 changes: 3 additions & 2 deletions chemotools/scale/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .min_max_normalize import MinMaxScaler
from .l_normalize import LNormalize
from .index_scaler import IndexScaler
from .min_max_scaler import MinMaxScaler
from .norm_scaler import NormScaler
41 changes: 41 additions & 0 deletions chemotools/scale/index_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted

from chemotools.utils.check_inputs import check_input


class IndexScaler(BaseEstimator, TransformerMixin):
    """Scale every row of ``X`` by the value that row holds at a fixed column.

    Each sample is divided by its own element at position ``index``, so after
    the transformation the column ``index`` equals 1 for every row whose
    reference value is non-zero.

    Parameters
    ----------
    index : int, default=0
        Column position of the reference value used as the divisor of its row.

    Attributes
    ----------
    n_features_in_ : int
        Number of columns seen during ``fit``.
    """

    def __init__(self, index: int = 0):
        self.index = index

    def fit(self, X: np.ndarray, y=None) -> "IndexScaler":
        """Record the number of columns of ``X``; nothing else is learned.

        Parameters
        ----------
        X : np.ndarray
            Input data. It is passed through ``check_input``, which is
            expected to enforce a 2D array with finite values (defined in
            ``chemotools.utils.check_inputs``; not visible here).
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        IndexScaler
            The fitted transformer (``self``).
        """
        X = check_input(X)

        self.n_features_in_ = X.shape[1]

        # Marker attribute looked up by check_is_fitted in transform.
        self._is_fitted = True

        return self

    def transform(self, X: np.ndarray, y=None) -> np.ndarray:
        """Divide every row of ``X`` by its element at ``self.index``.

        Parameters
        ----------
        X : np.ndarray
            Data with the same number of columns as the data given to ``fit``.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        np.ndarray
            A new array with the same shape as ``X``; the input is not
            modified.

        Raises
        ------
        ValueError
            If the number of columns differs from the one seen in ``fit``.

        Notes
        -----
        A reference value of zero is not special-cased: the division follows
        NumPy semantics (``inf``/``nan`` plus a runtime warning).
        """
        check_is_fitted(self, "_is_fitted")

        # NOTE(review): assumes check_input returns a 2D numpy array - confirm.
        X = check_input(X)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(f"Expected {self.n_features_in_} features but got {X.shape[1]}")

        # Indexing with a one-element list keeps the column axis, so the
        # (n_samples, 1) divisor broadcasts across each row. True division
        # allocates a new array, which avoids both an explicit copy and the
        # dtype cast that writing the quotient back into the input's dtype
        # would apply.
        return X / X[:, [self.index]]
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from chemotools.utils.check_inputs import check_input


class LNormalize(BaseEstimator, TransformerMixin):
class NormScaler(BaseEstimator, TransformerMixin):
def __init__(self, l_norm: int = 2):
self.l_norm = l_norm

def fit(self, X: np.ndarray, y=None) -> "LNormalize":
def fit(self, X: np.ndarray, y=None) -> "NormScaler":
# Check that X is a 2D array and has only finite values
X = check_input(X)

Expand Down
1 change: 1 addition & 0 deletions index.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Table of contents
* [Non-negative](#non-negative)
* [Subtract reference spectrum](#subtract-reference-spectrum)
* [Scale](#scale)
* [Index scaler](#index-scaler)
* [Min-max scaler](#minmax-scaler)
* [L-Norm scaler](#l-norm-scaler)
* [Smooth](#smooth)
Expand Down
15 changes: 12 additions & 3 deletions tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from chemotools.baseline import AirPls, LinearCorrection, NonNegative, SubtractReference
from chemotools.derivative import NorrisWilliams, SavitzkyGolay
from chemotools.scale import LNormalize, MinMaxScaler
from chemotools.scale import IndexScaler, MinMaxScaler, NormScaler
from chemotools.scatter import MultiplicativeScatterCorrection, StandardNormalVariate
from chemotools.smooth import MeanFilter, MedianFilter, WhittakerSmooth
from tests.fixtures import (
Expand All @@ -26,11 +26,20 @@ def test_air_pls(spectrum, reference_airpls):
# Assert
assert np.allclose(spectrum_corrected[0], reference_airpls[0], atol=1e-8)

def test_index_scaler(spectrum):
    """IndexScaler(index=0) must divide the first spectrum by its first value."""
    # Arrange: build the expected result by hand from the first spectrum
    first_spectrum = spectrum[0]
    divisor = first_spectrum[0]
    reference_spectrum = [value / divisor for value in first_spectrum]
    scaler = IndexScaler(index=0)

    # Act
    spectrum_corrected = scaler.fit_transform(spectrum)

    # Assert
    matches_reference = np.allclose(spectrum_corrected[0], reference_spectrum, atol=1e-8)
    assert matches_reference

def test_l1_norm(spectrum):
# Arrange
norm = 1
l1_norm = LNormalize(l_norm=norm)
l1_norm = NormScaler(l_norm=norm)
spectrum_norm = np.linalg.norm(spectrum[0], ord=norm)

# Act
Expand All @@ -43,7 +52,7 @@ def test_l1_norm(spectrum):
def test_l2_norm(spectrum):
# Arrange
norm = 2
l1_norm = LNormalize(l_norm=norm)
l1_norm = NormScaler(l_norm=norm)
spectrum_norm = np.linalg.norm(spectrum[0], ord=norm)

# Act
Expand Down
11 changes: 9 additions & 2 deletions tests/test_sklearn_compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from chemotools.baseline import AirPls, CubicSplineCorrection, LinearCorrection, NonNegative, PolynomialCorrection, SubtractReference
from chemotools.derivative import NorrisWilliams, SavitzkyGolay
from chemotools.scale import MinMaxScaler, LNormalize
from chemotools.scale import IndexScaler, MinMaxScaler, NormScaler
from chemotools.scatter import MultiplicativeScatterCorrection, StandardNormalVariate
from chemotools.smooth import MeanFilter, MedianFilter, SavitzkyGolayFilter, WhittakerSmooth

Expand All @@ -23,6 +23,13 @@ def test_compliance_cubic_spline_correction():
# Act & Assert
check_estimator(transformer)

# IndexScaler
def test_compliance_index_scaler():
    """Run scikit-learn's estimator checks on a default-constructed IndexScaler."""
    # Arrange, Act & Assert: check_estimator raises if any check fails
    check_estimator(IndexScaler())

# LinearCorrection
def test_compliance_linear_correction():
# Arrange
Expand All @@ -33,7 +40,7 @@ def test_compliance_linear_correction():
# NormScaler
def test_compliance_l_norm():
# Arrange
transformer = LNormalize()
transformer = NormScaler()
# Act & Assert
check_estimator(transformer)

Expand Down

0 comments on commit 70fab9f

Please sign in to comment.