Skip to content

关于3.13章丢弃法的疑惑 #17

Closed
@VittorioYan

Description

@VittorioYan
def dropout(X, drop_prob):
    X = X.float()
    assert 0 <= drop_prob <= 1
    keep_prob = 1 - drop_prob
    # 这种情况下把全部元素都丢弃
    if keep_prob == 0:
        return torch.zeros_like(X)
    mask = (torch.randn(X.shape) < keep_prob).float()
    
    return mask * X / keep_prob

此处randn生成的数不是分布在0-1之间,导致drop_prob为0时也会产生丢弃。

X = torch.arange(16).view(2, 8)
dropout(X, 0)
tensor([[ 0.,  0.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10.,  0., 12., 13., 14., 15.]])

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions