From b3b06169c694d4a1c56278d1e3a445741fb36326 Mon Sep 17 00:00:00 2001 From: TFbunny Date: Mon, 21 Dec 2020 16:32:28 -0500 Subject: [PATCH] add validator check type for ssl --- mindspore/nn/loss/loss.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index 5c314877d5..2d8abb0265 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -281,6 +281,10 @@ class SoftmaxCrossEntropyWithLogits(_Loss): x = self.softmax_cross_entropy(logits, labels)[0] return self.get_loss(x) +@constexpr +def _check_label_dtype(labels_dtype, cls_name): + validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name) + class SampledSoftmaxLoss(_Loss): r""" @@ -373,8 +377,11 @@ class SampledSoftmaxLoss(_Loss): self.zeros_like = P.ZerosLike() self.mul = P.Mul() self.expand_dims = P.ExpandDims() + self.dtype = P.DType() def construct(self, weights, biases, labels, inputs): + _check_label_dtype(self.dtype(labels), self.cls_name) + logits, labels = self._compute_sampled_logits( weights=weights, biases=biases, @@ -424,6 +431,7 @@ class SampledSoftmaxLoss(_Loss): `[batch_size, num_true + num_sampled]` out_labels: A Tensor object with the same shape as `out_logits`. """ + if not labels.dtype == mstype.int32: labels = self.cast(labels, mstype.int32) labels = self.reshape(labels, (-1, num_true))