目前数据预处理最常见的方法就是中心化和标准化。
中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。
标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间
批标准化:BN
在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。
所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。
batch normalization 的实现非常简单,接下来写一下对应的python代码:
import sys sys.path.append('..') import torch def simple_batch_norm_1d(x, gamma, beta): eps = 1e-5 x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True) x_hat = (x - x_mean) / torch.sqrt(x_var + eps) return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean) x = torch.arange(15).view(5, 3) gamma = torch.ones(x.shape[1]) beta = torch.zeros(x.shape[1]) print('before bn: ') print(x) y = simple_batch_norm_1d(x, gamma, beta) print('after bn: ') print(y)
测试的时候该使用批标准化吗?
答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替
下面我们实现以下能够区分训练状态和测试状态的批标准化方法
def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1): eps = 1e-5 x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True) if is_training: x_hat = (x - x_mean) / torch.sqrt(x_var + eps) moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var else: x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps) return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
下面我们在卷积网络下试用一下批标准化看看效果
def data_tf(x): x = np.array(x, dtype='float32') / 255 x = (x - 0.5) / 0.5 # 数据预处理,标准化 x = torch.from_numpy(x) x = x.unsqueeze(0) return x train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换 test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True) train_data = DataLoader(train_set, batch_size=64, shuffle=True) test_data = DataLoader(test_set, batch_size=128, shuffle=False) # 使用批标准化 class conv_bn_net(nn.Module): def __init__(self): super(conv_bn_net, self).__init__() self.stage1 = nn.Sequential( nn.Conv2d(1, 6, 3, padding=1), nn.BatchNorm2d(6), nn.ReLU(True), nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5), nn.BatchNorm2d(16), nn.ReLU(True), nn.MaxPool2d(2, 2) ) self.classfy = nn.Linear(400, 10) def forward(self, x): x = self.stage1(x) x = x.view(x.shape[0], -1) x = self.classfy(x) return x net = conv_bn_net() optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1 train(net, train_data, test_data, 5, optimizer, criterion)
以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
广告合作:本站广告合作请联系QQ:858582 申请时备注:广告合作(否则不回)
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
暂无评论...
更新日志
2024年11月25日
2024年11月25日
- 凤飞飞《我们的主题曲》飞跃制作[正版原抓WAV+CUE]
- 刘嘉亮《亮情歌2》[WAV+CUE][1G]
- 红馆40·谭咏麟《歌者恋歌浓情30年演唱会》3CD[低速原抓WAV+CUE][1.8G]
- 刘纬武《睡眠宝宝竖琴童谣 吉卜力工作室 白噪音安抚》[320K/MP3][193.25MB]
- 【轻音乐】曼托凡尼乐团《精选辑》2CD.1998[FLAC+CUE整轨]
- 邝美云《心中有爱》1989年香港DMIJP版1MTO东芝首版[WAV+CUE]
- 群星《情叹-发烧女声DSD》天籁女声发烧碟[WAV+CUE]
- 刘纬武《睡眠宝宝竖琴童谣 吉卜力工作室 白噪音安抚》[FLAC/分轨][748.03MB]
- 理想混蛋《Origin Sessions》[320K/MP3][37.47MB]
- 公馆青少年《我其实一点都不酷》[320K/MP3][78.78MB]
- 群星《情叹-发烧男声DSD》最值得珍藏的完美男声[WAV+CUE]
- 群星《国韵飘香·贵妃醉酒HQCD黑胶王》2CD[WAV]
- 卫兰《DAUGHTER》【低速原抓WAV+CUE】
- 公馆青少年《我其实一点都不酷》[FLAC/分轨][398.22MB]
- ZWEI《迟暮的花 (Explicit)》[320K/MP3][57.16MB]