diff --git a/src/api.py b/src/api.py index 45ae8ac..7fe857a 100644 --- a/src/api.py +++ b/src/api.py @@ -20,10 +20,6 @@ # for algorithm predection (at least that's the plan) fml = fmlearn() -# doing this so that dataload and training of the model happens once at application start. -fml.load_data() -fml.train() - # Create a Metric @metrics_api.route('', methods=[POST]) def add_metric(): @@ -115,9 +111,13 @@ def predict_fmlearn(): data = {} - if fml._new_recs == math.inf: + if fml._new_recs == math.inf or len(utils.get_df_from_db()) <= fml.MAX_NEW_RECORDS: data['response'] = 'Model not trained!' return json.dumps(data) + + if not fml.is_model_trained(): + # doing this so that data load and training of the model happens once at application has enough data. + fml.load_data_and_train() # fetching the encoder for target type for encoding the input data tt_encoder = fml.get_encoders()[utils.TARGET_TYPE]