diff --git a/.github/install_mindspore.py b/.github/install_mindspore.py
index 8c79ecad1..7e2e264e1 100644
--- a/.github/install_mindspore.py
+++ b/.github/install_mindspore.py
@@ -4,7 +4,7 @@
 import platform
 
 def gen_url(os_name, py_version):
-    hf_url = 'https://hf-mirror.com/lvyufeng/mindspore-daily/resolve/main/'
+    hf_url = 'https://hf.co/lvyufeng/mindspore-daily/resolve/main/'
     whl_name = 'mindspore-newest-cp{}-cp{}-{}.whl'
     py_version = py_version.replace('.', '')
 
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index 133d68be4..26ed6677a 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -223,12 +223,108 @@ def pad(input, pad, mode='constant', value=0.0):
     return ops.pad(input, pad, mode, value)
 
 def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
-    # _nll_loss = _get_cache_prim(ops.NLLLoss)(reduction, ignore_index)
-    # return _nll_loss(input, target, weight)
-    return ops.nll_loss(input, target, weight, ignore_index, reduction, label_smoothing)
+    if label_smoothing != 0.0 or target.ndim != 1:
+        return _inner_nll_loss(input, target, weight, ignore_index, reduction, label_smoothing)
+    if weight is None:
+        weight = ops.ones(input.shape[-1], dtype=input.dtype)
+    _nll_loss = _get_cache_prim(ops.NLLLoss)(reduction, ignore_index)
+    return _nll_loss(input, target, weight)[0]
 
 def cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
-    return ops.cross_entropy(input, target, weight, ignore_index, reduction, label_smoothing)
+    class_dim = 0 if input.ndim == 1 else 1
+    if target.dtype in [mindspore.float32, mindspore.float16]:
+        return _cross_entropy(input, target, class_dim, weight, reduction, label_smoothing)
+    return nll_loss(log_softmax(input, class_dim), target, weight, ignore_index, reduction, label_smoothing)
+
+
+def _cross_entropy(inputs, target, target_dim, weight=None, reduction='mean', label_smoothing=0.0):
+    """cross entropy inner function"""
+    class_dim = 0 if inputs.ndim == 1 else 1
+    n_classes = inputs.shape[class_dim]
+    inputs = log_softmax(inputs, class_dim)
+    if label_smoothing > 0.0:
+        target = target * (1 - label_smoothing) + label_smoothing / n_classes
+
+    if weight is None:
+        weight = ops.ones_like(inputs)
+    elif inputs.ndim != 1:
+        broadcast_shape = [1 for _ in range(inputs.ndim)]
+        broadcast_shape[1] = weight.shape[0]
+        weight = weight.reshape(broadcast_shape)
+
+    if reduction == 'mean':
+        return -(inputs * target * weight).sum() / (inputs.size / n_classes)
+    if reduction == 'sum':
+        return -(inputs * target * weight).sum()
+    return -(inputs * target * weight).sum(class_dim)
+
+
+def _inner_nll_loss(inputs, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
+    ndim = inputs.ndim
+    if ndim == 2:
+        ret = _nll_loss(inputs, target, -1, weight, ignore_index, reduction, label_smoothing)
+    elif ndim == 4:
+        ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
+    elif ndim == 1:
+        ret = _nll_loss(inputs, target, 0, weight, ignore_index, reduction, label_smoothing)
+    else:
+        n = inputs.shape[0]
+        c = inputs.shape[1]
+        out_size = (n,) + inputs.shape[2:]
+        inputs = inputs.view((n, c, 1, -1))
+        target = target.view((n, 1, -1))
+        if reduction != 'none':
+            ret = _nll_loss(inputs, target, 1, weight, ignore_index, reduction, label_smoothing)
+        else:
+            ret = _nll_loss(inputs, target, 1, weight, ignore_index, label_smoothing=label_smoothing)
+            ret = ret.view(out_size)
+    return ret
+
+
+def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
+    """nll loss inner function"""
+    if target.ndim == inputs.ndim - 1:
+        target = target.expand_dims(target_dim)
+    if ignore_index is not None:
+        non_pad_mask = ops.equal(target, ignore_index)
+        target = target.masked_fill(non_pad_mask, ops.cast(0, target.dtype))
+    else:
+        non_pad_mask = target
+    if weight is not None:
+        loss_weights = ops.gather(weight, target, 0)
+        orig_shape = inputs.shape
+        if inputs.ndim != 2:
+            inputs = inputs.view(orig_shape[:2] + (-1,))
+            weight = weight.view(weight.shape + (1,))
+        weighted_inputs = inputs * weight
+        weighted_inputs = weighted_inputs.view(orig_shape)
+        loss = ops.neg(ops.gather_d(weighted_inputs, target_dim, target))
+        smooth_loss = ops.neg(weighted_inputs.sum(axis=target_dim, keepdims=True))
+    else:
+        loss = ops.neg(ops.gather_d(inputs, target_dim, target))
+        smooth_loss = ops.neg(inputs.sum(axis=target_dim, keepdims=True))
+        loss_weights = ops.ones_like(loss)
+
+    if ignore_index is not None:
+        loss = loss.masked_fill(non_pad_mask, ops.cast(0, loss.dtype))
+        loss_weights = loss_weights.masked_fill(non_pad_mask, ops.cast(0, loss_weights.dtype))
+        smooth_loss = smooth_loss.masked_fill(non_pad_mask, ops.cast(0, smooth_loss.dtype))
+
+    loss = loss.squeeze(target_dim)
+    smooth_loss = smooth_loss.squeeze(target_dim)
+
+    if reduction == 'sum':
+        loss = loss.sum()
+        smooth_loss = smooth_loss.sum()
+    if reduction == 'mean':
+        loss = loss.sum() / loss_weights.sum()
+        smooth_loss = smooth_loss.sum() / loss_weights.sum()
+
+    eps_i = label_smoothing / inputs.shape[target_dim]
+    if label_smoothing != 0:
+        loss = (1. - label_smoothing) * loss + eps_i * smooth_loss
+
+    return loss
 
 def mse_loss(input, target, reduction='mean'):
     return ops.mse_loss(input, target, reduction)
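Note on the float-target branch above: when `target` carries class probabilities, `_cross_entropy` takes the log-softmax of the logits, mixes the targets toward the uniform distribution by `label_smoothing`, and reduces the weighted product; integer targets still route through `nll_loss`, which dispatches to `_inner_nll_loss` whenever label smoothing or multi-dimensional targets are involved. A minimal NumPy sketch of the soft-target math for the 2-D case (the helper name and the sample values are illustrative, not part of the patch):

import numpy as np

def soft_target_cross_entropy(logits, target, label_smoothing=0.0, reduction='mean'):
    # logits: (N, C) raw scores; target: (N, C) class probabilities.
    n_classes = logits.shape[1]
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))  # log-softmax
    if label_smoothing > 0.0:
        # same smoothing rule as _cross_entropy above
        target = target * (1 - label_smoothing) + label_smoothing / n_classes
    loss = -(log_probs * target).sum(axis=1)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
target = np.eye(3)[[0, 2]]    # one-hot probabilities
print(soft_target_cross_entropy(logits, target, label_smoothing=0.1))
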
diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py
index 5e0431b49..ca87dbe35 100644
--- a/mindnlp/core/ops/reduction.py
+++ b/mindnlp/core/ops/reduction.py
@@ -137,6 +137,7 @@ def sum(input, dim=None, keepdim=False, *, dtype=None):
 def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
     if USE_PYBOOST:
         return mindspore.mint.unique(input, sorted, return_inverse, return_counts, dim)
+
     out, inverse = ops.unique(input)
     outs = (out,)
     if return_inverse:
diff --git a/mindnlp/transformers/models/wav2vec2/modeling_wav2vec2.py b/mindnlp/transformers/models/wav2vec2/modeling_wav2vec2.py
index 73b795581..ab03d970b 100644
--- a/mindnlp/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/mindnlp/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -1683,7 +1683,7 @@ def forward(
             neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
 
             if neg_is_pos.any():
-                logits[1:][neg_is_pos] = float("-inf")
+                logits[1:][neg_is_pos] = float(ops.finfo(logits.dtype).min)
 
             # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
             # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
@@ -1694,7 +1694,6 @@ def forward(
             # 7. compute diversity loss: \mathbf{L}_d
             num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
             diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
-
             # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
             loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
 
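Note on the masking change above: writing the dtype's most negative finite value instead of `float("-inf")` still pushes the masked positions to effectively zero probability in the contrastive cross-entropy, but avoids `inf` arithmetic, which is safer in low-precision dtypes. A small NumPy illustration of the idea (float16 chosen arbitrarily; values are illustrative):

import numpy as np

def softmax(x):
    x = x.astype(np.float32)
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 0.5, 1.0], dtype=np.float16)
masked = logits.copy()
masked[1] = np.finfo(np.float16).min    # analogue of ops.finfo(logits.dtype).min
print(softmax(masked))                  # position 1 gets ~0 probability, no inf/NaN involved
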
diff --git a/tests/ut/transformers/models/wav2vec2/test_modeling_wav2vec2.py b/tests/ut/transformers/models/wav2vec2/test_modeling_wav2vec2.py
index ea6268be1..48ff4eca2 100644
--- a/tests/ut/transformers/models/wav2vec2/test_modeling_wav2vec2.py
+++ b/tests/ut/transformers/models/wav2vec2/test_modeling_wav2vec2.py
@@ -579,17 +579,17 @@ def test_initialization(self):
     # overwrite from test_modeling_common
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
-            module.weight.fill_(3)
+            nn.init.constant_(module.weight, 3)
         if hasattr(module, "weight_g") and module.weight_g is not None:
-            module.weight_g.fill_(3)
+            nn.init.constant_(module.weight_g, 3)
         if hasattr(module, "weight_v") and module.weight_v is not None:
-            module.weight_v.fill_(3)
+            nn.init.constant_(module.weight_v, 3)
         if hasattr(module, "bias") and module.bias is not None:
-            module.bias.fill_(3)
+            nn.init.constant_(module.bias, 3)
         if hasattr(module, "codevectors") and module.codevectors is not None:
-            module.codevectors.fill_(3)
+            nn.init.constant_(module.codevectors, 3)
         if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
-            module.masked_spec_embed.fill_(3)
+            nn.init.constant_(module.masked_spec_embed, 3)
 
     def test_mask_feature_prob_ctc(self):
         model = Wav2Vec2ForCTC.from_pretrained(
@@ -771,17 +771,17 @@ def test_initialization(self):
     # overwrite from test_modeling_common
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
-            module.weight.fill_(3)
+            nn.init.constant_(module.weight, 3)
         if hasattr(module, "weight_g") and module.weight_g is not None:
-            module.weight_g.fill_(3)
+            nn.init.constant_(module.weight_g, 3)
         if hasattr(module, "weight_v") and module.weight_v is not None:
-            module.weight_v.fill_(3)
+            nn.init.constant_(module.weight_v, 3)
         if hasattr(module, "bias") and module.bias is not None:
-            module.bias.fill_(3)
+            nn.init.constant_(module.bias, 3)
         if hasattr(module, "codevectors") and module.codevectors is not None:
-            module.codevectors.fill_(3)
+            nn.init.constant_(module.codevectors, 3)
         if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
-            module.masked_spec_embed.fill_(3)
+            nn.init.constant_(module.masked_spec_embed, 3)
 
     def test_model_for_pretraining(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -822,6 +822,7 @@ def test_model_for_pretraining(self):
             sampled_negative_indices=sampled_negative_indices,
         ).loss
 
+        print(loss, loss_more_masked)
         # loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
        self.assertTrue(loss.item() <= loss_more_masked.item())
 
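Note on the mocked initialization above: the tests now use the functional initializer rather than mutating parameters in place with `fill_`. A minimal sketch of the same pattern outside the test (assuming `mindnlp.core.nn` exposes the torch-like `Linear` alongside the `nn.init.constant_` used in the hunks; the layer and sizes are illustrative):

# Assumed import path; mindnlp.core.nn is treated here as torch-like.
from mindnlp.core import nn

layer = nn.Linear(4, 2)
for name in ("weight", "bias"):
    param = getattr(layer, name, None)
    if param is not None:
        nn.init.constant_(param, 3)    # functional init instead of param.fill_(3)
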
@@ -1166,7 +1167,7 @@ def test_sample_negatives(self):
             self.assertTrue(((negative - features) == 0).sum() == 0.0)
 
         # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
-        self.assertEqual(ops.unique(negatives, dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+        self.assertEqual(ops.unique_consecutive(negatives, dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
 
     def test_sample_negatives_with_mask(self):
         batch_size = 2
@@ -1204,7 +1205,7 @@ def test_sample_negatives_with_mask(self):
             self.assertTrue(((negative - features) == 0).sum() == 0.0)
 
         # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
-        self.assertEqual(ops.unique(negatives, dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
+        self.assertEqual(ops.unique_consecutive(negatives, dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
 
 
 @require_mindspore
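Note on the `unique_consecutive` assertions: they rely on the invariant stated in the comment, namely that every sampled negative is a whole feature vector whose entries along `hidden_size` are identical, so collapsing duplicates over that dim leaves a single value. A NumPy illustration with made-up shapes (`np.unique` with `axis` stands in for the `dim=-1` reduction):

import numpy as np

sequence_length, hidden_size = 5, 4
# build vectors whose entries are identical along hidden_size, as the assertion presupposes
features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(sequence_length, hidden_size)
negatives = features[[3, 0, 2]]               # pretend these rows were sampled as negatives
collapsed = np.unique(negatives, axis=-1)     # stand-in for unique/unique_consecutive over dim=-1
print(collapsed.shape)                        # (3, 1): one distinct value per sampled vector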