'''
求数据集的均值和标准差
:param loader:
:return:
'''
data_sum,data_squared_sum,num_batches = 0,0,0
for data,_ in loader:
# data: [batch_size,channels,height,width]
# 计算 dim=0,2,3 维度的均值和,dim=1 为通道数量,不用参与计算
data_sum += torch.mean(data,dim=[0,2,3]) # [batch_size,channels,height,width]
# 计算 dim=0,2,3 维度的平方均值和,dim=1 为通道数量,不用参与计算
data_squared_sum += torch.mean(data**2,dim=[0,2,3]) # [batch_size,channels,height,width]
# 统计 batch 的数量
num_batches += 1
# 计算均值
mean = data_sum/num_batches
# 计算标准差
std = (data_squared_sum/num_batches - mean**2)**0.5
return mean,std
为什么可以这样计算均值,从这个代码中我的到一个结论:"每个样本均值的和/样本数=整体数据的均值"
有点不太理解这个东西,有大佬能用数学公式证明一下吗
简单说明一下数据情况:这是 CIFAR10 数据集,每个样本的结构是( batch_size,channels,height,width),
即(样本数量,RGB 通道,图片高度,图片宽度)