-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCausalityTrainer.py
57 lines (51 loc) · 2.2 KB
/
CausalityTrainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import data_io
import CausalityFeatureFunctions as f
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
class CausalityTrainer:
    """Train a feature-extraction + random-forest pipeline and save it.

    The pipeline maps the raw training pairs to features with
    ``CausalityFeatureFunctions.FeatureMapper`` and then fits a
    ``RandomForestRegressor``.  ``directionForward`` selects which of two
    target encodings the regressor is trained on (see
    ``_directionalTargets``).
    """

    def __init__(self, directionForward=True):
        """Store which target encoding ``run`` should train on.

        :param directionForward: True selects the forward encoding
            ``x*(x+1)/2``; False selects the backward encoding
            ``-x*(x-1)/2``.
        """
        self.directionForward = directionForward

    def getFeatureExtractor(self, features):
        """Return a ``FeatureMapper`` built from the feature spec *features*."""
        return f.FeatureMapper(features)

    def getPipeline(self, feat):
        """Build the (feature extraction -> random forest) pipeline.

        :param feat: feature spec passed to ``getFeatureExtractor``.
        :returns: an unfitted ``sklearn.pipeline.Pipeline`` with steps
            ``"extract_features"`` and ``"classify"``.
        """
        features = self.getFeatureExtractor(feat)
        # NOTE(review): ``compute_importances`` was deprecated in sklearn 0.14
        # and removed later (importances are computed on demand); drop it
        # when moving to a newer scikit-learn.
        forest = RandomForestRegressor(compute_importances=True,
                                       n_estimators=500,
                                       verbose=2,
                                       n_jobs=1,
                                       min_samples_split=10,
                                       random_state=0)
        steps = [("extract_features", features),
                 ("classify", forest)]
        return Pipeline(steps)

    def getTrainingDataset(self):
        """Load the training pairs and attach the "A type"/"B type" columns.

        :returns: the frame from ``data_io.read_train_pairs()`` with the
            "A type" and "B type" columns copied from
            ``data_io.read_train_info()``.
        """
        print("Reading in the training data")
        train = data_io.read_train_pairs()
        print("Reading the information about the training data")
        train2 = data_io.read_train_info()
        train["A type"] = train2["A type"]
        train["B type"] = train2["B type"]
        return train

    def _resolvePreprocessedFeatures(self, train):
        """Copy precomputed feature columns into *train*; return the spec.

        For every name in ``f.preprocessedFeatures`` the column is read from
        ``data_io.read_intermediate_train()`` into *train*.  Each matching
        entry of ``f.features`` is then rewritten so that its input column
        is the feature's own name and its transform is an identity
        ``SimpleTransform``.

        :param train: training frame; modified in place.
        :returns: the feature spec to hand to ``getPipeline``.
        """
        preprocessed = f.preprocessedFeatures
        if not preprocessed:
            return f.features
        intermediate = data_io.read_intermediate_train()
        for name in preprocessed:
            train[name] = intermediate[name]
        features = []
        for feature in f.features:
            if feature[0] in preprocessed:
                # Copy rather than mutate: f.features is module-level state
                # that other code in the same process may still rely on.
                feature = list(feature)
                feature[1] = feature[0]
                feature[2] = f.SimpleTransform(transformer=f.ff.identity)
            features.append(feature)
        return features

    @staticmethod
    def _directionalTargets(targets, directionForward):
        """Return the regression targets for one causal direction.

        For inputs in {-1, 0, 1}, the forward encoding ``x*(x+1)/2`` maps
        1 -> 1 and the other values to 0.  The backward encoding
        ``-x*(x-1)/2`` maps -1 -> -1 and the other values to 0.
        (For integer ``x`` the product is always even, so the division is
        exact under both Python 2 and Python 3 semantics.)

        :param targets: iterable of target values.
        :param directionForward: choose the forward (True) or backward
            (False) encoding.
        :returns: list of encoded targets, one per input value.
        """
        if directionForward:
            return [x * (x + 1) / 2 for x in targets]
        return [-x * (x - 1) / 2 for x in targets]

    def run(self):
        """Fit the pipeline on the training data and save it.

        Loads the training pairs, any preprocessed feature columns and the
        targets via ``data_io``.  Fits the pipeline on the direction-specific
        targets and prints the forest's ``feature_importances_``.  Finally,
        stores the fitted pipeline with ``data_io.save_model``.
        """
        train = self.getTrainingDataset()
        print("Reading preprocessed features")
        features = self._resolvePreprocessedFeatures(train)
        print("Reading targets")
        target = data_io.read_train_target()
        print("Extracting features and training model")
        classifier = self.getPipeline(features)
        finalTarget = self._directionalTargets(target.Target,
                                               self.directionForward)
        classifier.fit(train, finalTarget)
        print(classifier.steps[-1][1].feature_importances_)
        print("Saving the classifier")
        data_io.save_model(classifier)
if __name__ == "__main__":
    # Running the script trains and saves the default (forward) model;
    # previously the trainer was constructed but run() was never called.
    ct = CausalityTrainer()
    ct.run()