diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index f597f0a769..f41268b6bf 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -250,6 +250,10 @@ def _is_equal_one(x): return False return bool(x.asnumpy().mean() == 1.0) +@constexpr +def _dtype_check(x_dtype): + if x_dtype not in [mstype.float32, mstype.float16]: + raise TypeError("The input type must be float32 or float16.") class ClipByNorm(Cell): r""" @@ -264,12 +268,11 @@ class ClipByNorm(Cell): where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`. Inputs: - - **input** (Tensor) - Tensor of shape N-D. - - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)` and of - the same type as the input Tensor. + - **input** (Tensor) - Tensor of shape N-D. The type should be float32 or float16. + - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`. Outputs: - Tensor, clipped tensor with the same shape as the input. + Tensor, clipped tensor with the same shape as the input, whose type is float32. Examples: >>> net = nn.ClipByNorm() @@ -300,10 +303,10 @@ class ClipByNorm(Cell): l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32) cond = self.greater_(l2sum, 0) ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0) - l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum))) l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum) + _dtype_check(self.dtype(x)) if _is_equal_one(clip_norm): intermediate = x else: diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index a241f6c848..1a7f891c8e 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -827,13 +827,3 @@ def get_bprop_unique(self): dx = op(dout, out) return (dx,) return bprop - - -@bprop_getters.register(P.UnsortedSegmentSum) -def get_bprop_unsorted_segment_sum(self): - """Generate bprop for UnsortedSegmentSum""" - op = G.UnsortedSegmentSumGrad() - def bprop(x, segment_ids, num_segments, out, dout): - dx = op(dout, segment_ids) - return (dx, zeros_like(segment_ids), zeros_like(num_segments)) - return bprop diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 1a10590a70..7940662f48 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -502,20 +502,6 @@ class UniqueGrad(Primitive): raise NotImplementedError -class UnsortedSegmentSumGrad(PrimitiveWithInfer): - """Gradients of UnsortedSegmentSum operation.""" - - @prim_attr_register - def __init__(self): - self.init_prim_io_names(inputs=['grads', 'ids'], outputs=['y']) - - def infer_shape(self, grads, ids): - return ids + grads[len(ids):] - - def infer_dtype(self, grads, ids): - return grads - - class BNTrainingReduceGrad(PrimitiveWithInfer): """Gradients of FusedBatchNorm operation.""" diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index d8775b7c28..371a9daacc 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -93,8 +93,12 @@ class BoundingBoxEncode(PrimitiveWithInfer): @prim_attr_register def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): - validator.check_value_type('means', means, [tuple], self.name) - validator.check_value_type('stds', stds, [tuple], self.name) + validator.check_value_type('means', means, [tuple, list], self.name) + validator.check_value_type('stds', stds, [tuple, list], self.name) + for i, value in enumerate(means): + validator.check_value_type("means[%d]" % i, value, [float], self.name) + for i, value in enumerate(stds): + validator.check_value_type("stds[%d]" % i, value, [float], self.name) validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) @@ -143,8 +147,12 @@ class BoundingBoxDecode(PrimitiveWithInfer): @prim_attr_register def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016): - validator.check_value_type('means', means, [tuple], self.name) - validator.check_value_type('stds', stds, [tuple], self.name) + validator.check_value_type('means', means, [tuple, list], self.name) + validator.check_value_type('stds', stds, [tuple, list], self.name) + for i, value in enumerate(means): + validator.check_value_type("means[%d]" % i, value, [float], self.name) + for i, value in enumerate(stds): + validator.check_value_type("stds[%d]" % i, value, [float], self.name) validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name) validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)