Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Showing function in leaf nodes when using linear regression #10

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
4 changes: 3 additions & 1 deletion advanced_ML/model_tree/models/DT_sklearn_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ def predict(self, X):

def loss(self, X, y, y_pred):
    """Node impurity for classification: the Gini impurity of the labels y.

    X and y_pred are ignored; the signature mirrors the shared
    loss(X, y, y_pred) interface used by the other model wrappers.
    """
    impurity = gini_impurity(y)
    return impurity
def get_params(self):
    """Leaf-annotation hook: this wrapper exposes no printable coefficients.

    Returning None makes the tree visualiser fall back to an empty leaf
    label instead of printing a fitted equation.
    """
    return None

def gini_impurity(y):
    """Return the Gini impurity 1 - sum_c p_c**2 of the label array y.

    y: 1-D numpy array of class labels.  The per-class counts rely on
    elementwise comparison (np.sum(y == c)), so y should be an ndarray
    rather than a plain list — TODO confirm callers always pass arrays.

    Fixes the duplicated, unreachable trailing return present in the
    flattened source.
    """
    n = len(y)
    # Sum of squared class proportions over the distinct labels.
    p2 = sum((np.sum(y == c) / n) ** 2 for c in set(y))
    return 1.0 - p2
4 changes: 3 additions & 1 deletion advanced_ML/model_tree/models/DT_sklearn_regr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ def predict(self, X):
return self.model.predict(X)

def loss(self, X, y, y_pred):
    """Fit quality at a node: mean squared error of y_pred against y.

    X is ignored; the signature mirrors the shared loss(X, y, y_pred)
    interface used by the other model wrappers.  The flattened source
    carried a duplicated, unreachable return; only one is kept.
    """
    return mean_squared_error(y, y_pred)
def get_params(self):
    """Leaf-annotation hook: this wrapper exposes no printable coefficients.

    Returning None makes the tree visualiser fall back to an empty leaf
    label instead of printing a fitted equation.
    """
    return None
5 changes: 4 additions & 1 deletion advanced_ML/model_tree/models/NN_regr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ def predict(self, X):
return self.model.predict(X)

def loss(self, X, y, y_pred):
    """Fit quality at a node: mean squared error of y_pred against y.

    X is ignored; the signature mirrors the shared loss(X, y, y_pred)
    interface used by the other model wrappers.  The flattened source
    carried a duplicated, unreachable return; only one is kept.
    """
    return mean_squared_error(y, y_pred)

def get_params(self):
    """Leaf-annotation hook: this wrapper exposes no printable coefficients.

    Returning None makes the tree visualiser fall back to an empty leaf
    label instead of printing a fitted equation.
    """
    return None
3 changes: 2 additions & 1 deletion advanced_ML/model_tree/models/linear_regr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ def predict(self, X):

def loss(self, X, y, y_pred):
    """Fit quality at a node: mean squared error of y_pred against y.

    X is ignored; the signature mirrors the shared loss(X, y, y_pred)
    interface used by the other model wrappers.
    """
    mse = mean_squared_error(y, y_pred)
    return mse

def get_params(self):
    """Expose the fitted linear coefficients for leaf annotation.

    The tree visualiser prints these as "y = c0*X0 + c1*X1 ..." inside
    leaf nodes, so the underlying model must already be fitted
    (self.model.coef_ must exist) when this is called.
    """
    coefficients = self.model.coef_
    return coefficients
5 changes: 4 additions & 1 deletion advanced_ML/model_tree/models/logistic_regr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""
from sklearn.metrics import mean_squared_error
import numpy as np

class logistic_regr:

Expand All @@ -26,6 +27,8 @@ def predict(self, X):
return self.flag_y_pred * np.ones((len(X),), dtype=int)
else:
return self.model.predict(X)
def get_params(self):
    """Leaf-annotation hook: this wrapper exposes no printable coefficients.

    Returning None makes the tree visualiser fall back to an empty leaf
    label instead of printing a fitted equation.
    """
    return None

def loss(self, X, y, y_pred):
    """Fit quality at a node: mean squared error of y_pred against y.

    X is ignored; the signature mirrors the shared loss(X, y, y_pred)
    interface used by the other model wrappers.  The flattened source
    carried a duplicated, unreachable return; only one is kept.
    """
    return mean_squared_error(y, y_pred)
5 changes: 4 additions & 1 deletion advanced_ML/model_tree/models/mean_regr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ def predict(self, X):
return self.y_mean * np.ones(len(X))

def loss(self, X, y, y_pred):
    """Fit quality at a node: mean squared error of y_pred against y.

    X is ignored; the signature mirrors the shared loss(X, y, y_pred)
    interface used by the other model wrappers.  The flattened source
    carried a duplicated, unreachable return; only one is kept.
    """
    return mean_squared_error(y, y_pred)

def get_params(self):
    """Leaf-annotation hook: this wrapper exposes no printable coefficients.

    Returning None makes the tree visualiser fall back to an empty leaf
    label instead of printing a fitted equation.
    """
    return None
5 changes: 4 additions & 1 deletion advanced_ML/model_tree/models/modal_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ def predict(self, X):

def loss(self, X, y, y_pred):
    """Node impurity for classification: the Gini impurity of the labels y.

    X and y_pred are ignored; the signature mirrors the shared
    loss(X, y, y_pred) interface used by the other model wrappers.
    """
    impurity = gini_impurity(y)
    return impurity

def get_params(self):
    """Leaf-annotation hook: this wrapper exposes no printable coefficients.

    Returning None makes the tree visualiser fall back to an empty leaf
    label instead of printing a fitted equation.
    """
    return None

def gini_impurity(y):
    """Return the Gini impurity 1 - sum_c p_c**2 of the label array y.

    y: 1-D numpy array of class labels.  The per-class counts rely on
    elementwise comparison (np.sum(y == c)), so y should be an ndarray
    rather than a plain list — TODO confirm callers always pass arrays.

    Fixes the duplicated, unreachable trailing return present in the
    flattened source.
    """
    n = len(y)
    # Sum of squared class proportions over the distinct labels.
    p2 = sum((np.sum(y == c) / n) ** 2 for c in set(y))
    return 1.0 - p2
5 changes: 4 additions & 1 deletion advanced_ML/model_tree/models/svm_regr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,7 @@ def predict(self, X):
return self.model.predict(X)

def loss(self, X, y, y_pred):
    """Fit quality at a node: mean squared error of y_pred against y.

    X is ignored; the signature mirrors the shared loss(X, y, y_pred)
    interface used by the other model wrappers.  The flattened source
    carried a duplicated, unreachable return; only one is kept.
    """
    return mean_squared_error(y, y_pred)

def get_params(self):
    """Leaf-annotation hook: this wrapper exposes no printable coefficients.

    Returning None makes the tree visualiser fall back to an empty leaf
    label instead of printing a fitted equation.
    """
    return None
14 changes: 13 additions & 1 deletion advanced_ML/model_tree/src/ModelTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def _explain(node, x, explanation):
no_children = node["children"]["left"] is None and \
node["children"]["right"] is None
if no_children:


return explanation
else:
if x[node["j_feature"]] <= node["threshold"]: # x[j] < threshold
Expand Down Expand Up @@ -196,7 +198,17 @@ def build_graphviz_recurse(node, parent_node_index=0, parent_depth=0, edge_label
# Create node
node_index = node["index"]
if node["children"]["left"] is None and node["children"]["right"] is None:
threshold_str = ""

params = node['model'].get_params()
if params is not None :
threshold_str="y = "
for i in range(len(params)):
threshold_str += str(round(params[i],2))
threshold_str += "*X"+str(i)
threshold_str+="\n"
else:
threshold_str=""

else:
threshold_str = "{} <= {:.1f}\\n".format(feature_names[node['j_feature']], node["threshold"])

Expand Down