-
Notifications
You must be signed in to change notification settings - Fork 0
/
batchnorm.py
29 lines (27 loc) · 1.35 KB
/
batchnorm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum, phrase='traing'):
"""batchnorm 单个batch数据
batchnorm 是在特征维度上进行
"""
# 当前模式是训练模式还是预测模式
if not phrase == 'traing':
# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
X_hat = (X - moving_mean) / np.sqrt(moving_var + eps)
else:
assert len(X.shape) in (2, 4)
if len(X.shape) == 2:
# 使用全连接层的情况,计算特征维上的均值和方差
mean = X.mean(axis=0)
var = ((X - mean)**2).mean(axis=0)
else:
# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。这里我们需要保持
# X的形状以便后面可以做广播运算
mean = X.mean(axis=(0, 2, 3), keepdims=True)
var = ((X - mean)**2).mean(axis=(0, 2, 3), keepdims=True)
# 训练模式下用当前的均值和方差做标准化
X_hat = (X - mean) / np.sqrt(var + eps)
# 更新移动平均的均值和方差
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
moving_var = momentum * moving_var + (1.0 - momentum) * var
Y = gamma * X_hat + beta # 拉伸和偏移
return Y, moving_mean, moving_var