有大佬对计算数据的均值和方差了解么

查看 35|回复 4
作者:bler   
def get_mean_std_value(loader):
    '''
    求数据集的均值和标准差
    :param loader:
    :return:
    '''
    data_sum,data_squared_sum,num_batches = 0,0,0
    for data,_ in loader:
        # data: [batch_size,channels,height,width]
        # 计算 dim=0,2,3 维度的均值和,dim=1 为通道数量,不用参与计算
        data_sum += torch.mean(data,dim=[0,2,3])    # [batch_size,channels,height,width]
        # 计算 dim=0,2,3 维度的平方均值和,dim=1 为通道数量,不用参与计算
        data_squared_sum += torch.mean(data**2,dim=[0,2,3])  # [batch_size,channels,height,width]
        # 统计 batch 的数量
        num_batches += 1
    # 计算均值
    mean = data_sum/num_batches
    # 计算标准差
    std = (data_squared_sum/num_batches - mean**2)**0.5
    return mean,std
为什么可以这样计算均值,从这个代码中我的到一个结论:"每个样本均值的和/样本数=整体数据的均值"
有点不太理解这个东西,有大佬能用数学公式证明一下吗
简单说明一下数据情况:这是 CIFAR10 数据集,每个样本的结构是( batch_size,channels,height,width),
即(样本数量,RGB 通道,图片高度,图片宽度)

均值, 方差, 标准差

dji38838c   
这个很显然呀。
比如:假如一共有 300 个数据(a1, a2,... a300),分成 100 组,每组 3 个。
那么 [(a1+a2+a3)/3 + (a4+a5+a6)/3 + .... (a298+a299+a300) / 3] / 100
可以整理成 [(a1+a2+a3+...+a300)/3] / 100 = (a1+a2+...+a300)/300
NessajCN   
这里能这么算的前提是每个样本的采样数量,也就是计算始终用来计算 torch.mean() 的分母,都是一样的才成立
bler
OP
  
@dji38838c 我发现这个方法还是存在很大的问题的,这种方法只适用于"总数据量/batch_size=整数"这种情况下的计算出的结果才能成立。假设最后的数据恰好是一个异常数据,那么通过这个计算方式计算出来的均值就是有极大异常的均值
Eureka0   
这个只对每个样本的样本容量都一样的情况成立,其实就是
均值=求和(样本均值 i*样本容量 i)/求和(样本容量 i)
样本容量都一样就可以约掉了
您需要登录后才可以回帖 登录 | 立即注册

返回顶部