Closed
Description
def dropout(X, drop_prob):
X = X.float()
assert 0 <= drop_prob <= 1
keep_prob = 1 - drop_prob
# 这种情况下把全部元素都丢弃
if keep_prob == 0:
return torch.zeros_like(X)
mask = (torch.randn(X.shape) < keep_prob).float()
return mask * X / keep_prob
此处randn生成的数不是分布在0-1之间,导致drop_prob为0时也会产生丢弃。
X = torch.arange(16).view(2, 8)
dropout(X, 0)
tensor([[ 0., 0., 2., 3., 4., 5., 6., 7.],
[ 8., 9., 10., 0., 12., 13., 14., 15.]])
Metadata
Metadata
Assignees
Labels
No labels