Skip to content

Commit

Permalink
added TODOs and test function to fmlearn class
Browse files Browse the repository at this point in the history
  • Loading branch information
mukeshmk committed Jun 17, 2020
1 parent 6239317 commit 446d426
Showing 1 changed file with 39 additions and 19 deletions.
58 changes: 39 additions & 19 deletions src/fmlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(self):
self._accuracy = None

def load_data(self):
# TODO: force new model to be trained once data has been reloaded?

# loads data from the SQL database and pre-processes the data.
self._df = utils.get_df_from_db()
self._shape = self._df.shape
Expand All @@ -37,6 +39,8 @@ def load_data(self):
self._y, _ = utils.label_encode_feature(self._y, utils.DATASET_HASH)

def train(self):
# TODO: throw error if the data is not loaded before training.
# remove the below work around.
if self._X is None:
print('data not loaded! \nloading data and then training model')
self.load_data()
Expand All @@ -51,37 +55,53 @@ def train(self):
y_pred = self._model.predict(X_test)
self._accuracy = accuracy_score(y_test, y_pred)

def load_data_and_train(self):
self.load_data()
self.train()

def predict(self, X_pred):
# TODO: check if the shape of the input df matches that used to train the model.
# TODO: force retrain of model if the model is older than a set time frame?

y_pred = self._model.predict(X_pred)
return y_pred


# TODO: return either the proper dataset_hash or
# the algorithm used by that dataset which is predicted.

return y_pred

def kmc():
df = utils.get_df_from_db()
df.fillna(-1, inplace=True)
def _test(self, print_details=False):
# this function tests the entire functionality without affecting the class variables
# loads data from db, pre-processes it and trains a model and displays the

df = utils.get_df_from_db()
df.fillna(-1, inplace=True)

X, y = utils.get_Xy(df)
X, y = utils.get_Xy(df)

# pre processing of data
X, _ = utils.ohe_feature(X, utils.TARGET_TYPE)
# pre processing of data
X, _ = utils.ohe_feature(X, utils.TARGET_TYPE)

y, _ = utils.label_encode_feature(y, utils.DATASET_HASH)
y, _ = utils.label_encode_feature(y, utils.DATASET_HASH)

# train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)
# train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)

model = KNeighborsClassifier()
model = KNeighborsClassifier()

model.fit(X_train, y_train)

y_pred = model.predict(X_test)

model.fit(X_train, y_train)

y_pred = model.predict(X_test)
print('accuracy: ' + str(accuracy_score(y_test, y_pred)))

print(model.kneighbors(X_test)[1])
print(y_test.to_string(header=False))
y_pred = pd.DataFrame(y_pred)
return y_pred.to_string(header=False)
if print_details:
# kneighbors()[0] contains distances to points
# kneighbors()[1] contains indcies of nearest points
print(model.kneighbors(X_test)[1])

print(y_test.to_string(header=False))
print(pd.DataFrame(y_pred).to_string(header=False))

if __name__ == "__main__":
kmc()

0 comments on commit 446d426

Please sign in to comment.