-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·72 lines (55 loc) · 1.97 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python3
import argparse
import pickle
import pandas as pd
from sklearn import pipeline
# Names of the CSV columns that loadData() selects by default as the
# model's input features.
featureList = [
    "R_odds",
    "B_odds",
    "R_wins",
    "B_wins",
    "lose_streak_dif",
    "win_streak_dif",
    "age_dif",
    "height_dif",
    "reach_dif",
]
def parseArgs(argv=None):
    """Parse the command-line options of the predictor.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        An ``argparse.Namespace`` with three attributes:
        ``action`` ("predict" or "load", defaulting to "predict"),
        ``model_name`` (path of the pickled model) and
        ``data_file`` (path of the CSV input).
    """
    parser = argparse.ArgumentParser(
        prog="UFC Winner predictor",
        description="A predictor for UFC fights",
    )
    # The action is an optional positional: omitting it means "predict".
    parser.add_argument(
        'action',
        default="predict",
        choices=["predict", "load"],
        nargs="?",
    )
    parser.add_argument(
        "--model-name", '-m',
        default="./models/model1.pkl",
        type=str,
    )
    parser.add_argument(
        "--data-file", '-d',
        default="./data/example.csv",
        type=str,
    )
    return parser.parse_args(argv)
def loadData(fileName: str, label: str = "Winner", features: list[str] = featureList):
    """Read a CSV file and split it into feature columns and a label column.

    Args:
        fileName: Path of the CSV file to read.
        label: Name of the column returned as the target.
        features: Names of the columns returned as the inputs.

    Returns:
        A ``(X, Y)`` tuple: ``X`` is the frame restricted to ``features``
        and ``Y`` is the ``label`` column.
    """
    frame = pd.read_csv(fileName)
    return frame[features], frame[label]
def loadDataNoDrop(fileName: str):
    """Read ``fileName`` as CSV and return the whole frame, all columns kept."""
    return pd.read_csv(fileName)
def predict(modelName, dataFile):
    """Print the predicted winner of every fight listed in a CSV file.

    The CSV is read once. The ``featureList`` columns are fed to the model,
    and the first two columns of the file are printed as the fighters.
    A label column is not required, so unlabeled fights can be predicted.

    Args:
        modelName: Path of a pickled object exposing ``predict``.
        dataFile: Path of the CSV file describing the fights.
    """
    # Positional indices of the fighter columns in the CSV. Presumably
    # column 0 is R_fighter and column 1 is B_fighter, going by the original
    # comments -- verify against the data file's header.
    redCol, blueCol = 0, 1

    fights = loadDataNoDrop(dataFile)
    features = fights[featureList]

    # NOTE(security): pickle.load executes arbitrary code from the file;
    # only load model files that come from a trusted source.
    with open(modelName, 'rb') as modelFile:
        model = pickle.load(modelFile)

    winners = model.predict(features)
    for row, predicted in enumerate(winners):
        # .iat is explicit positional access; integer [] on a label-indexed
        # row Series relies on a deprecated positional fallback.
        red = fights.iat[row, redCol]
        blue = fights.iat[row, blueCol]
        # Any prediction other than "Blue" is reported as a win for red.
        winner = blue if predicted == "Blue" else red
        print("\n", blue, "Vs", red, "Winner:", winner)
def showModel(modelName):
    """Load a pickled model and print its search summary.

    Prints, one per line, the object's ``best_score_``, ``best_params_``,
    ``feature_names_in_`` and ``best_estimator_`` attributes. The object is
    presumably a fitted scikit-learn search such as GridSearchCV -- confirm
    against the training code.

    Args:
        modelName: Path of the pickled model file.

    Raises:
        AttributeError: If the unpickled object lacks one of the attributes.
    """
    # NOTE(security): pickle.load executes arbitrary code from the file;
    # only load model files that come from a trusted source.
    # The context manager closes the file handle, which the bare
    # ``open(...)`` call used previously never did.
    with open(modelName, 'rb') as modelFile:
        model = pickle.load(modelFile)

    print(model.best_score_)
    print(model.best_params_)
    print(model.feature_names_in_)
    print(model.best_estimator_)
def main(args):
    """Run the action selected on the command line.

    Args:
        args: Parsed options carrying ``action``, ``model_name`` and
            ``data_file``. "predict" prints predictions for the data file;
            "load" prints details of the stored model. Any other action
            does nothing.
    """
    action = args.action
    if action == "predict":
        predict(args.model_name, args.data_file)
    elif action == "load":
        showModel(args.model_name)
# Script entry point: parse the command line and dispatch, but only when the
# file is executed directly rather than imported.
if __name__ == "__main__":
    cliArgs = parseArgs()
    main(cliArgs)