Skip to content

Commit

Permalink
add sklearn base
Browse files Browse the repository at this point in the history
  • Loading branch information
ynchuang committed Dec 17, 2024
1 parent 94dbfc2 commit c38f6b0
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions ltsm/common/sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import pandas as pd

def get_default_hyperparameter(primitive, hyperparameter):
    """Return the default hyperparameters of ``primitive`` with overrides applied.

    Args:
        primitive: Primitive class exposing ``metadata.get_hyperparams()``;
            the object it returns must provide ``defaults()``.
        hyperparameter: Mapping of hyperparameter names to override values.
            May be empty.

    Returns:
        The object produced by ``defaults()``, or the result of calling its
        ``replace(hyperparameter)`` when at least one override was given.

    Raises:
        TypeError: If ``hyperparameter`` contains a name that is not one of
            the primitive's default hyperparameters.
    """
    hyperparams = primitive.metadata.get_hyperparams().defaults()

    # Reject unknown names only. A strict-superset comparison (``>``) would
    # also reject a call that overrides *every* hyperparameter, reporting an
    # empty list of offending names.
    known_names = set(hyperparams.keys())
    invalid_hyperparam = [name for name in hyperparameter if name not in known_names]
    if invalid_hyperparam:
        raise TypeError(primitive.__name__ + ' got unexpected keyword argument ' + str(invalid_hyperparam))

    if hyperparameter:
        hyperparams = hyperparams.replace(hyperparameter)

    return hyperparams

class BaseSKI:
    """Scikit-learn-style ``fit``/``predict`` wrapper around a primitive class.

    One primitive instance is created per system, and each instance is trained
    on the matching entry of the input data.

    NOTE(review): ``predict`` delegates to ``self._forward``, which is not
    defined in this class. It presumably has to be supplied by a subclass or
    added later -- confirm before relying on ``predict``.
    """

    def __init__(self, primitive, *, system_num=1, **hyperparameter):
        """Instantiate the wrapped primitive(s).

        Args:
            primitive: Primitive class; it is called as
                ``primitive(hyperparams=...)`` once per system.
            system_num: Number of independent systems (and therefore primitive
                instances). Defaults to 1.
            **hyperparameter: Overrides for the primitive's default
                hyperparameters.

        Raises:
            ValueError: If ``system_num`` is smaller than 1.
            TypeError: Propagated from ``get_default_hyperparameter`` when an
                unknown hyperparameter name is given.
        """
        if system_num < 1:
            raise ValueError('system_num must be a positive integer, got ' + str(system_num))

        # Membership in the class's own namespace: only methods defined
        # directly on ``primitive`` count here, inherited ones do not.
        own_attrs = vars(primitive)
        self.fit_available = 'fit' in own_attrs
        self.predict_available = 'produce' in own_attrs
        # ``dir`` also sees inherited attributes, unlike the checks above.
        self.predict_score_available = 'produce_score' in dir(primitive)
        self.produce_available = 'produce' in own_attrs

        hyperparams = get_default_hyperparameter(primitive, hyperparameter)

        # ``_sys_data_check`` reads ``system_num`` and ``fit`` iterates over
        # ``primitives``, so both must exist and ``primitives`` must be a list.
        self.system_num = system_num
        self.primitives = [primitive(hyperparams=hyperparams) for _ in range(system_num)]

    def _sys_data_check(self, data):
        """Validate ``data`` against ``system_num`` and return it as a list.

        Args:
            data: A 2D ``numpy.ndarray`` when ``system_num == 1``; otherwise a
                list of exactly ``system_num`` 2D arrays.

        Returns:
            A list with one 2D array per system.

        Raises:
            AttributeError: If ``data`` does not have the expected structure.
        """
        if self.system_num == 1:
            if not (isinstance(data, np.ndarray) and data.ndim == 2):
                raise AttributeError('For system_num = 1, input data should be 2D numpy array.')
            # Wrap the single array so callers can always index per system.
            return [data]

        if self.system_num > 1:
            if not (isinstance(data, list) and len(data) == self.system_num):
                raise AttributeError('For system_num > 1, input data should be the list of `system_num` 2D numpy arrays.')
            for ts_data in data:
                if not (isinstance(ts_data, np.ndarray) and ts_data.ndim == 2):
                    raise AttributeError('For system_num > 1, each element of input list should be 2D numpy arrays.')

        return data

    def fit(self, data):
        """Train every primitive on its system's data.

        Args:
            data: Input accepted by ``_sys_data_check``.

        Raises:
            AttributeError: If the primitive class does not define ``fit``, or
                if ``data`` has the wrong structure.
        """
        if not self.fit_available:
            raise AttributeError('type object ' + self.__class__.__name__ + ' has no attribute \'fit\'')

        data = self._sys_data_check(data)

        for primitive, sys_data in zip(self.primitives, data):
            primitive.set_training_data(inputs=self._transform(sys_data))
            primitive.fit()

    def predict(self, data):
        """Run the primitives on ``data`` and return the result of ``_forward``.

        Args:
            data: Input accepted by ``_sys_data_check``.

        Returns:
            Whatever ``self._forward(data, '_produce')`` returns.

        Raises:
            AttributeError: If the primitive class does not define
                ``produce``, or if ``data`` has the wrong structure.
        """
        if not self.predict_available:
            raise AttributeError('type object ' + self.__class__.__name__ + ' has no attribute \'predict\'')

        data = self._sys_data_check(data)
        # NOTE(review): ``_forward`` is not defined in this class -- see the
        # class docstring.
        return self._forward(data, '_produce')

    def _transform(self, X):
        """Convert a 2D array into a DataFrame with columns named '0', '1', ...

        Args:
            X: 2D array; ``X.shape[1]`` determines the number of columns.

        Returns:
            A ``pandas.DataFrame`` holding ``X`` with stringified column
            indices as column names.
        """
        column_name = [str(col_index) for col_index in range(X.shape[1])]
        # ``pandas.DataFrame`` has no ``generate_metadata`` argument; passing
        # it raises TypeError, so it is intentionally not forwarded.
        return pd.DataFrame(X, columns=column_name)

0 comments on commit c38f6b0

Please sign in to comment.