-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy path8.3.py
87 lines (75 loc) · 2.36 KB
/
8.3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 26 18:08:51 2017
@author: ZQ
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_gaussian_quantiles
# Build the dataset from two Gaussian-quantile clusters.
# First cluster: 200 two-feature samples, cov=2.0, generator's default mean.
X1, y1 = make_gaussian_quantiles(
    cov=2.0,
    n_samples=200,
    n_features=2,
    n_classes=2,
    random_state=1,
)
# Second cluster: 300 two-feature samples, cov=1.5, mean at (3, 3).
X2, y2 = make_gaussian_quantiles(
    mean=(3, 3),
    cov=1.5,
    n_samples=300,
    n_features=2,
    n_classes=2,
    random_state=1,
)
# Stack both clusters into one sample matrix / label vector.
X = np.concatenate((X1, X2))
# 1 - y2 swaps the 0/1 labels of the second cluster before stacking.
y = np.concatenate((y1, 1 - y2))
# Fit the boosted model: AdaBoost (SAMME) over 200 depth-1 decision trees.
weak_learner = DecisionTreeClassifier(max_depth=1)
bdt = AdaBoostClassifier(weak_learner, algorithm="SAMME", n_estimators=200)
bdt.fit(X, y)
# Plot settings shared by both panels: one colour letter and one display
# name per class, plus the spacing of the prediction grid.
plot_colors = 'br'
class_names = 'AB'
plot_step = 0.02
plt.figure(figsize=(10, 5))

# Left panel: decision boundary.
plt.subplot(121)
# Extend the plotted range one unit beyond the data on every side.
x_min = X[:, 0].min() - 1
x_max = X[:, 0].max() + 1
y_min = X[:, 1].min() - 1
y_max = X[:, 1].max() + 1
grid_x = np.arange(x_min, x_max, plot_step)
grid_y = np.arange(y_min, y_max, plot_step)
xx, yy = np.meshgrid(grid_x, grid_y)
# Predict a class for every grid node, then fold the flat predictions
# back into the grid's shape so they can be drawn as filled contours.
Z = bdt.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
plt.axis("tight")
# Training points: overlay each class's samples in its own colour.
for i, n, c in zip(range(2), class_names, plot_colors):
    # A boolean mask selects this class's rows and yields plain 1-D
    # coordinate arrays for scatter.
    members = y == i
    # No cmap here: c is a literal colour, so a colormap would be ignored
    # (and matplotlib warns when one is supplied without numeric c data).
    plt.scatter(X[members, 0], X[members, 1],
                c=c,
                label="Class %s" % n)
# Axis limits, legend and labels are set once, after all classes are drawn.
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')
# Right panel: per-class histograms of the model's decision scores.
twoclass_output = bdt.decision_function(X)
# Use the full score range for both histograms so their bins line up.
plot_range = (twoclass_output.min(), twoclass_output.max())
plt.subplot(122)
for i, n, c in zip(range(2), class_names, plot_colors):
    plt.hist(twoclass_output[y == i],
             bins=10,
             range=plot_range,
             facecolor=c,
             label='Class %s' % n,
             alpha=0.5)
# Stretch the upper y-limit by 20% (presumably headroom for the legend).
# Dedicated names are used so the dataset label arrays y1/y2 are not
# overwritten by axis limits.
x_lo, x_hi, y_lo, y_hi = plt.axis()
plt.axis((x_lo, x_hi, y_lo, y_hi * 1.2))
plt.legend(loc='upper right')
plt.ylabel('Samples')
plt.xlabel('Score')
plt.title('Decision Scores')
plt.tight_layout()
plt.subplots_adjust(wspace=0.35)
plt.show()