Skip to content

Commit

Permalink
fixes #12: force new model to be trained once data has been reloaded
Browse files Browse the repository at this point in the history
  • Loading branch information
mukeshmk committed Jul 16, 2020
1 parent 9ecec12 commit 5165f6b
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/fmlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(self):
self._accuracy = None
# to store feature encoders
self._encoders = {}
# boolean to track retraining of the model
self._retain = False

def get_encoders(self):
return self._encoders
Expand All @@ -30,7 +32,10 @@ def get_X_cols(self):
return self._X.columns

def load_data(self):
# TODO: force new model to be trained once data has been reloaded?
# force new model to be trained once data has been reloaded
# retraining of model occurs only when the predict() function is being called
# this is due to the absence of a background-processing framework in the application.
self._retain = True

# loads data from the SQL database and pre-processes the data.
self._df = utils.get_df_from_db()
Expand Down Expand Up @@ -68,12 +73,17 @@ def load_data_and_train(self):
self.train()

def predict(self, X_pred):
# TODO: force retrain of model if the model is older than a set time frame?

# check if the shape of the input df matches that used to train the model.
if X_pred.shape[1] != self._X.shape[1]:
raise RuntimeError('Input Shape miss match! aborting!')

# TODO: force retrain of model if the model is older than a set time frame?
# or if a set of new data records has been added to the model.
# at this point reload the data and train the model.
if self._retain == True:
self.train()
self._retain = False

y_pred = self._model.predict(X_pred)

# decodes the predicted value using the inverse_transform method of the encoder.
Expand Down

0 comments on commit 5165f6b

Please sign in to comment.