diff --git a/src/api.py b/src/api.py index 6e4e03d..45ae8ac 100644 --- a/src/api.py +++ b/src/api.py @@ -115,7 +115,7 @@ def predict_fmlearn(): data = {} - if not fml.is_model_trained(): + if fml._new_recs == math.inf: data['response'] = 'Model not trained!' return json.dumps(data) diff --git a/src/fmlearn.py b/src/fmlearn.py index a4c7a20..52d2d08 100644 --- a/src/fmlearn.py +++ b/src/fmlearn.py @@ -123,6 +123,9 @@ def predict(self, X_pred): # or if the retrain flag is set to true because new data has been loaded if self._retain == True or self._new_recs >= self.MAX_NEW_RECORDS: self.train() + + if not self.is_model_trained(): + self.load_data_and_train() y_pred = self._model.predict(X_pred)