From e3d1e2f55b43d74e313ae5810e2945018414b2f3 Mon Sep 17 00:00:00 2001 From: yanzhenxiang2020 Date: Tue, 26 May 2020 09:57:48 +0800 Subject: [PATCH] add RNNTLoss and RandomCategorical op for aicpu --- mindspore/ops/_grad/grad_nn_ops.py | 12 ++++ mindspore/ops/_op_impl/aicpu/__init__.py | 2 + .../ops/_op_impl/aicpu/random_categorical.py | 48 +++++++++++++++ mindspore/ops/_op_impl/aicpu/rnnt_loss.py | 37 ++++++++++++ mindspore/ops/operations/__init__.py | 5 +- mindspore/ops/operations/nn_ops.py | 55 ++++++++++++++++++ mindspore/ops/operations/random_ops.py | 58 +++++++++++++++++++ .../test_aicpu_ops/test_random_categorical.py | 38 ++++++++++++ .../ascend/test_aicpu_ops/test_rnnt_loss.py | 43 ++++++++++++++ 9 files changed, 297 insertions(+), 1 deletion(-) create mode 100644 mindspore/ops/_op_impl/aicpu/random_categorical.py create mode 100644 mindspore/ops/_op_impl/aicpu/rnnt_loss.py create mode 100644 tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py create mode 100644 tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index c557301285..9f543c63cd 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -518,6 +518,18 @@ def get_bprop_l2_loss(self): return bprop +@bprop_getters.register(P.RNNTLoss) +def get_bprop_rnnt_loss(self): + """Grad definition for `RNNTLoss` operation.""" + expand = P.ExpandDims() + + def bprop(acts, labels, act_lens, label_lens, out, dout): + grad_loss = out[1] + grad = grad_loss * expand(expand(expand(dout[0], -1), -1), -1) + return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens) + return bprop + + @bprop_getters.register(P.PReLU) def get_bprop_prelu(self): """Grad definition for `PReLU` operation.""" diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index 5138d0f28c..bb490d050b 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -25,3 +25,5 @@ from .squeeze import _squeeze_aicpu from .expand_dims import _expand_dims_aicpu from .random_choice_with_mask import _random_choice_with_mask_aicpu from .ctcloss import _ctcloss_aicpu +from .rnnt_loss import _rnnt_loss_aicpu +from .random_categorical import _random_categorical_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/random_categorical.py b/mindspore/ops/_op_impl/aicpu/random_categorical.py new file mode 100644 index 0000000000..a0c6f64c97 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/random_categorical.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RandomCategorical op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +random_categorical_op_info = AiCPURegOp("RandomCategorical") \ + .fusion_type("OPAQUE") \ + .input(0, "logits", "required") \ + .input(1, "num_sample", "required") \ + .input(2, "seed", "required") \ + .output(0, "output", "required") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .get_op_info() + +@op_info_register(random_categorical_op_info) +def _random_categorical_aicpu(): + """RandomCategorical AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/rnnt_loss.py b/mindspore/ops/_op_impl/aicpu/rnnt_loss.py new file mode 100644 index 0000000000..d35d102048 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/rnnt_loss.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RNNTLoss op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +rnnt_loss_op_info = AiCPURegOp("RNNTLoss") \ + .fusion_type("OPAQUE") \ + .input(0, "acts", "required") \ + .input(1, "labels", "required") \ + .input(2, "input_lengths", "required") \ + .input(3, "label_lengths", "required") \ + .output(0, "costs", "required") \ + .output(1, "grads", "required") \ + .attr("blank_label", "int") \ + .dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + +@op_info_register(rnnt_loss_op_info) +def _rnnt_loss_aicpu(): + """RNNTLoss AiCPU register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index af9e84685a..87601c5592 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul Reciprocal, CumSum, Sin, Sqrt, Rsqrt, Square, Sub, TensorAdd, Sign, Round, SquareSumAll) -from .random_ops import (RandomChoiceWithMask) +from .random_ops import (RandomChoiceWithMask, RandomCategorical) from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative, @@ -63,6 +63,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, ResizeBilinear, Sigmoid, SigmoidCrossEntropyWithLogits, SmoothL1Loss, Softmax, Softplus, + RNNTLoss, SoftmaxCrossEntropyWithLogits, ROIAlign, SparseSoftmaxCrossEntropyWithLogits, Tanh, TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, @@ -147,6 +148,7 @@ __all__ = [ 'HSigmoid', 'Tanh', 'RandomChoiceWithMask', + 'RandomCategorical', 'ResizeBilinear', 'ScalarSummary', 'ImageSummary', @@ -174,6 +176,7 @@ __all__ = [ 'SmoothL1Loss', 'L2Loss', 'CTCLoss', + 'RNNTLoss', 'ReduceAll', 'ScalarToArray', 'ScalarToTensor', diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index dcc5810105..4a19d0e113 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1599,6 +1599,61 @@ class L2Loss(PrimitiveWithInfer): return x_type +class RNNTLoss(PrimitiveWithInfer): + """ + Computes the RNNTLoss and its gradient with respect to the softmax outputs. + + Args: + blank_label (int): blank label. Default: 0. + + Inputs: + - **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`. + - **labels** (Tensor[int32]) - Tensor of shape :math:`(B, N)`. + - **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. + - **label_lebgths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. + + Outputs: + - **costs** (Tensor[int32]) - Tensor of shape :math:`(B,)`. + - **grads** (Tensor[int32]) - Has the same shape as `acts`. + + Examples: + >>> B, T, U, V = 1, 2, 3, 5 + >>> acts = np.random.random((B, T, U, V)).astype(np.float32) + >>> labels = np.array([[1, 2]]).astype(np.int32) + >>> input_length = np.array([T] * B).astype(np.int32) + >>> label_length = np.array([len(l) for l in labels]).astype(np.int32) + >>> rnnt_loss = P.RNNTLoss(blank_label=blank) + >>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) + """ + @prim_attr_register + def __init__(self, blank_label=0): + validator.check_value_type('blank_label', blank_label, [int], self.name) + self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'], + outputs=['costs', 'grads']) + + def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape): + validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name) + validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name) + validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name) + validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name) + validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) + validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) + validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) + costs_shape = (acts_shape[0],) + return (costs_shape, acts_shape) + + def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type): + validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name) + validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name) + validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name) + validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name) + validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32], self.name) + validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name) + validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name) + validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name) + return (acts_type, acts_type) + + class SGD(PrimitiveWithInfer): """ Computes stochastic gradient descent (optionally with momentum). diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 2692b43b46..77201c25f9 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -64,3 +64,61 @@ class RandomChoiceWithMask(PrimitiveWithInfer): def infer_dtype(self, x_dtype): validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) return (mstype.int32, mstype.bool_) + + +class RandomCategorical(PrimitiveWithInfer): + """ + Generates random samples from a given categorical distribution tensor. + + Args: + dtype (mindspore.dtype): The type of output. Its value should be one of [mindspore.int16, + mindspore.int32, mindspore.int64]. Default: mindspore.int64. + + Inputs: + - **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes]. + - **num_sample** (int) - Number of sample to be drawn. Only constant values is allowed. + - **seed** (int) - Random seed. Default: 0. + + Outputs: + - **output** (Tensor) - The output Tensor with shape [batch_size, num_samples]. + + Examples: + >>> class Net(nn.Cell): + >>> def __init__(self, num_sample): + >>> super(Net, self).__init__() + >>> self.random_categorical = P.RandomCategorical(mindspore.int64) + >>> self.num_sample = num_sample + >>> def construct(self, logits, seed=0): + >>> return self.random_categorical(logits, self.num_sample, seed) + >>> + >>> x = np.random.random((10, 5)).astype(np.float32) + >>> net = Net(8) + >>> output = net(Tensor(x)) + """ + @prim_attr_register + def __init__(self, dtype=mstype.int64): + """Init RandomCategorical""" + self.dtype = dtype + + valid_values = (mstype.int32, mstype.int16, mstype.int64) + validator.check_type_name("dtype", dtype, valid_values, self.name) + self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'], + outputs=['output']) + + def __infer__(self, logits, num_samples, seed): + logits_dtype = logits['dtype'] + valid_types = (mstype.float32, mstype.float16, mstype.float64) + validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) + num_samples_v = num_samples['value'] + seed_v = seed['value'] + validator.check_value_type('num_samples', num_samples_v, (int,), self.name) + validator.check_value_type('seed', seed_v, (int,), self.name) + validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name) + x_shape = list(logits['shape']) + if len(x_shape) != 2: + raise ValueError("RandomCategorical shape should be 2-dimension.") + ndim = len(x_shape) - 1 + x_shape[ndim] = num_samples_v + return {'shape': (x_shape), + 'dtype': (self.dtype), + 'value': None} diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py b/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py new file mode 100644 index 0000000000..6304e8b111 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +class Net(nn.Cell): + def __init__(self, num_sample): + super(Net, self).__init__() + self.random_categorical = P.RandomCategorical(mindspore.int64) + self.num_sample = num_sample + + def construct(self, logits, seed=0): + return self.random_categorical(logits, self.num_sample, seed) + +def test_net(): + x = np.random.random((10, 5)).astype(np.float32) + net = Net(8) + output = net(Tensor(x)) + print(x) + print(output.asnumpy()) + print(output.dtype()) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py b/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py new file mode 100644 index 0000000000..c7e2df07f8 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore as ms +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.rnnt_loss = P.RNNTLoss(blank_label=0) + + def construct(self, acts, labels, act_lens, label_lens): + return self.rnnt_loss(acts, labels, act_lens, label_lens) + + +def test_net(): + B, T, U, V = 1, 2, 3, 5 + acts = np.random.random((B, T, U, V)).astype(np.float32) + labels = np.array([[np.random.randint(1, V-1) for _ in range(U-1)]]).astype(np.int32) + input_length = np.array([T] * B).astype(np.int32) + label_length = np.array([len(l) for l in labels]).astype(np.int32) + + rnnt_loss = Net() + costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) + print(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) + print(costs.asnumpy()) + print(grads.asnumpy())