diff --git a/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.cc new file mode 100644 index 0000000000..355d238ab4 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/nn/ctcloss_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(CTCLossV2, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CtcLossGpuKernel, float) + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.h new file mode 100644 index 0000000000..2bd83b3176 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.h @@ -0,0 +1,166 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ + +#include +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" +#include "device/gpu/gpu_memory_allocator.h" + +namespace mindspore { +namespace kernel { +template +class CtcLossGpuKernel : public GpuKernel { + public: + CtcLossGpuKernel() + : cudnn_handle_(nullptr), + probs_desc_(nullptr), + ctcloss_desc_(nullptr), + label_size_(0), + input_lengths_size_(0), + label_lengths_size_(0) {} + ~CtcLossGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + float *probs = GetDeviceAddress(inputs, 0); + int *labels = GetDeviceAddress(inputs, 1); + int *input_lengths = GetDeviceAddress(inputs, 2); + int *label_lengths = GetDeviceAddress(inputs, 3); + float *costs = GetDeviceAddress(outputs, 0); + float *grads = GetDeviceAddress(outputs, 1); + + // Copy labels/input_lengths/label_length to host as cudnn7.x.x requires + void *labels_host = nullptr; + void *input_lengths_host = nullptr; + void *label_lengths_host = nullptr; + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed."); + cudaStream_t stream = reinterpret_cast(stream_ptr); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(labels_host, labels, inputs[1]->size, cudaMemcpyDeviceToHost, stream), + "cudaMemcpyAsync failed."); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(input_lengths_host, input_lengths, inputs[2]->size, cudaMemcpyDeviceToHost, stream), + "cudaMemcpyAsync failed."); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(label_lengths_host, label_lengths, inputs[3]->size, cudaMemcpyDeviceToHost, stream), + "cudaMemcpyAsync failed."); + + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); + size_t workspace_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast(labels_host), + reinterpret_cast(label_lengths_host), + reinterpret_cast(input_lengths_host), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, + ctcloss_desc_, &workspace_size), + "cudnnGetCTCLossWorkspaceSize failed."); + void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size); + if (workspace == nullptr) { + MS_LOG(EXCEPTION) << "Failed to alloc workspace, size: " << workspace_size; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast(labels_host), + reinterpret_cast(label_lengths_host), reinterpret_cast(input_lengths_host), costs, + probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size), + "cudnnCtcLoss failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); + + device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed."); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (probs_shape.size() != 3) { + MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support."; + } + probs_dims_[0] = probs_shape[0]; + probs_dims_[1] = probs_shape[1]; + probs_dims_[2] = probs_shape[2]; + + auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + if (labels_dims.size() != 1 && labels_dims.size() != 2) { + MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support."; + } + label_size_ = sizeof(int); + for (auto i : labels_dims) { + label_size_ *= i; + } + + auto input_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + input_lengths_size_ = input_length_dims[0] * sizeof(int); + auto label_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + label_lengths_size_ = label_length_dims[0] * sizeof(int); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(probs_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, probs_dims_), + "cudnnSetTensorNdDescriptorEx failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetCTCLossDescriptorEx(ctcloss_desc_, CUDNN_DATA_FLOAT, + CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN), + "cudnnSetCTCLossDescriptorEx failed."); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&probs_desc_), "cudnnCreateTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateCTCLossDescriptor(&ctcloss_desc_), "cudnnCreateCTCLossDescriptor failed."); + } + + void InitSizeLists() override { + input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float)); + input_size_list_.push_back(label_size_); + input_size_list_.push_back(input_lengths_size_); + input_size_list_.push_back(label_lengths_size_); + + output_size_list_.push_back(probs_dims_[1] * sizeof(float)); + output_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float)); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed."); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed."); + } + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t probs_desc_; + cudnnCTCLossDescriptor_t ctcloss_desc_; + int probs_dims_[3] = {0}; + int label_size_; + int input_lengths_size_; + int label_lengths_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc index 77cb7f8608..99af1add46 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc @@ -47,6 +47,5 @@ MS_REG_GPU_KERNEL_ONE(FusedAdam, .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), FusedAdamWeightDecayGpuKernel, float) - } // namespace kernel } // namespace mindspore diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 107de1768c..dff925bb86 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -701,6 +701,19 @@ def get_bprop_ctc_loss(self): return bprop +@bprop_getters.register(P.CTCLossV2) +def get_bprop_ctc_loss_v2(self): + """Grad definition for `CTCLossV2` operation""" + expand = P.ExpandDims() + + def bprop(inputs, labels, input_lengths, labels_lengths, out, dout): + grad_loss = out[1] + grad = grad_loss * expand(dout[0], -1) + return grad, zeros_like(labels), zeros_like(input_lengths), zeros_like(labels_lengths) + + return bprop + + @bprop_getters.register(P.BasicLSTMCell) def get_bprop_basic_lstm_cell(self): """Grad definition for `BasicLSTMCell` operation.""" diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index bc4edce193..487ca3dce7 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -61,7 +61,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl DropoutDoMask, DropoutGrad, Dropout, DropoutGenMask, Flatten, FusedBatchNorm, BNTrainingReduce, BNTrainingUpdate, Gelu, Elu, - GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, + GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, LogSoftmax, MaxPool, DataFormatDimMap, AvgPool, Conv2DBackpropInput, ConfusionMulGrad, diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index c07f072f38..70cfe45978 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -4765,3 +4765,56 @@ class LRN(PrimitiveWithInfer): def infer_shape(self, x_shape): return x_shape + +class CTCLossV2(PrimitiveWithInfer): + r""" + Calculates the CTC(Connectionist Temporal Classification) loss. Also calculates the gradient. + Note: + - Cudnn Uses label value of for the `blank` + + Inputs: + - **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is + :math:`(max_time, batch_size, num_class)`. `num_class` should be `num_labels + 1` classes, `num_labels` + indicates the number of actual labels. Blank labels are reserved. + - **labels** (Tensor) - The labels Tensor should be a `1-D` tensor whose shape is + :math:`(\sigma{label_lengths})` + or `2-D` tensor whose shape is + :math:`(max_time, max{label_lengths})` + The type must be int32. + - **input_lengths** (Tensor) - A `1-D` input tensor whose shape is + :math:`(batch_size,)`. The values should be batch. The type must be int32. + - **label_lengths** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`. + The type must be int32. Each value in the tensor should not greater than `max_time`. + + Outputs: + - **loss** (Tensor) - A tensor containing log-probabilities, the shape is :math:`(batch_size)`. Has the same + type with `inputs`. + - **gradient** (Tensor) - The gradient of `loss`. Has the same type and shape with `inputs`. + + Examples: + >>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32) + >>> labels = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32) + >>> input_lengths = Tensor(np.array([3, 3, 3]), mindspore.int32) + >>> label_lengths = Tensor(np.array([3, 3, 3]), mindspore.int32) + >>> ctc_loss = P.CTCLossV2() + >>> output = ctc_loss(inputs, labels, input_lengths, label_lengths) + """ + @prim_attr_register + def __init__(self): + pass + + def infer_dtype(self, input_dtype, labels_dtype, input_lengths_dtype, label_lengths_dtype): + validator.check_tensor_type_same({"input": input_dtype}, (mstype.float32,), self.name) + validator.check_tensor_type_same({"labels": labels_dtype}, (mstype.int32,), self.name) + validator.check_tensor_type_same({"input_lengths": input_lengths_dtype}, (mstype.int32,), self.name) + validator.check_tensor_type_same({"target_lengths": label_lengths_dtype}, (mstype.int32,), self.name) + return mstype.float32, mstype.float32 + + def infer_shape(self, input_shape, labels_shape, input_lengths_shape, label_lengths_shape): + validator.check_integer("input shape", len(input_shape), 3, Rel.EQ, self.name) + validator.check_number_range("labels shape", len(labels_shape), 1, 2, Rel.INC_BOTH, self.name) + validator.check_integer("input lengths shape", len(input_lengths_shape), 1, Rel.EQ, self.name) + validator.check_integer("label lengths shape", len(label_lengths_shape), 1, Rel.EQ, self.name) + validator.check_integer("input[1]", input_shape[1], input_lengths_shape[0], Rel.EQ, self.name) + validator.check_integer("input[1]", input_shape[1], label_lengths_shape[0], Rel.EQ, self.name) + return (input_shape[1],), input_shape diff --git a/tests/st/ops/gpu/test_ctcloss_op.py b/tests/st/ops/gpu/test_ctcloss_op.py new file mode 100644 index 0000000000..b9a88e7e70 --- /dev/null +++ b/tests/st/ops/gpu/test_ctcloss_op.py @@ -0,0 +1,119 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from mindspore.ops.composite import GradOperation + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.loss = P.CTCLossV2() + self.div = P.RealDiv() + self.cast = P.Cast() + self.mean = P.ReduceMean() + + def construct(self, probs, label, input_length, label_length): + x, _ = self.loss(probs, label, input_length, label_length) + x = self.div(x, self.cast(label_length, mstype.float32)) + x = self.mean(x) + return x + +class GradData(nn.Cell): + def __init__(self, network): + super(GradData, self).__init__() + self.grad = GradOperation(name="get_all", get_all=True, sens_param=False) + self.network = network + + def construct(self, probs, labels, input_lengths, label_lengths): + return self.grad(self.network)(probs, labels, input_lengths, label_lengths) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ctcloss(): + probs = Tensor([[[-4.4131, -4.6093, -3.4333, -3.9268, -2.8917, -3.4093, -4.2243, -1.1379, -7.1046, -0.6902], + [-2.5109, -3.3397, -4.9384, -1.2723, -1.1443, -2.4683, -2.6768, -4.1282, -2.7062, -3.1906], + [-2.5092, -1.6392, -2.0864, -4.0059, -1.5610, -2.3223, -2.4816, -2.9922, -3.1412, -2.3311]], + + [[-2.1243, -3.5773, -3.1108, -4.4253, -2.7080, -1.9653, -2.0499, -2.4418, -1.8620, -1.5229], + [-2.2479, -3.5128, -1.4189, -2.8701, -1.8562, -2.2752, -2.7019, -2.1865, -2.5634, -2.9869], + [-3.2144, -1.3986, -3.1083, -3.9634, -3.5131, -3.2317, -2.6200, -1.7938, -1.8159, -1.7255]], + + [[-3.1301, -2.1649, -0.9286, -2.9452, -2.5992, -2.0263, -2.9201, -3.2155, -2.8302, -3.3636], + [-1.4661, -3.6311, -2.4781, -4.6180, -2.7308, -1.7019, -1.5570, -2.6012, -4.0788, -2.3073], + [-2.6833, -1.5033, -3.6922, -2.6360, -2.6974, -2.6847, -2.7579, -2.1396, -1.4093, -2.9630]], + + [[-2.0094, -2.3024, -3.3673, -1.0220, -2.8326, -2.2613, -3.0535, -2.9879, -3.7015, -2.4510], + [-1.9071, -3.2603, -2.3229, -2.0572, -4.3450, -2.1284, -2.6306, -1.3824, -2.9815, -2.5061], + [-2.7931, -3.7631, -3.2440, -4.3887, -1.0271, -3.8851, -1.2418, -4.5123, -2.2993, -2.4607]], + + [[-1.5763, -2.7539, -3.6941, -3.8166, -1.2599, -2.6903, -2.5826, -4.8208, -2.9562, -1.6321], + [-3.3031, -3.0087, -1.9982, -1.9081, -3.8731, -2.8764, -2.2485, -2.3808, -1.4283, -2.1625], + [-2.4516, -3.2394, -4.2053, -4.3541, -2.5229, -4.0717, -1.4894, -2.3151, -1.1098, -2.3465]]], + dtype=mstype.float32) + labels = Tensor([9, 4, 6, 4, 7, 1, 4, 6, 6, 8], dtype=mstype.int32) + input_lengths = Tensor([5, 5, 5], dtype=mstype.int32) + label_lengths = Tensor([3, 3, 4], dtype=mstype.int32) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + net = Net() + ctc_loss = net(probs, labels, input_lengths, label_lengths) + expect_loss = [2.4099] + assert np.allclose(ctc_loss.asnumpy(), expect_loss) + + grad = GradData(net)(probs, labels, input_lengths, label_lengths) + expect_grad = [[[8.8442e-05, 1.1065e-03, 3.5867e-03, 2.1896e-03, 6.1646e-03, + 3.6738e-03, 1.6262e-03, 3.5610e-02, 9.1258e-05, -5.4134e-02], + [-3.7523e-03, 3.9386e-03, 7.9623e-04, 3.1132e-02, -6.2954e-02, + 9.4143e-03, 7.6425e-03, 1.7902e-03, 7.4211e-03, 4.5719e-03], + [6.7778e-03, 1.6178e-02, 1.0344e-02, 1.5173e-03, -6.5840e-02, + 8.1707e-03, 6.9674e-03, 4.1814e-03, 3.6026e-03, 8.0991e-03]], + + [[-1.2581e-02, 3.1057e-03, 4.9517e-03, 1.3301e-03, -2.6320e-02, + 1.5568e-02, 1.4305e-02, 9.6671e-03, 1.7262e-02, -2.7292e-02], + [-1.5566e-02, 3.3126e-03, 2.6887e-02, 6.2993e-03, -3.9716e-02, + 1.1420e-02, 7.4531e-03, -1.4252e-02, 8.5603e-03, 5.6048e-03], + [3.3483e-03, 2.0579e-02, 3.7231e-03, 1.5832e-03, 2.4837e-03, + 3.2909e-03, -7.7267e-02, 1.3861e-02, 1.3558e-02, 1.4840e-02]], + + [[-8.0007e-03, 1.2751e-02, 4.3901e-02, 5.8435e-03, -7.2627e-02, + 1.4647e-02, -8.0584e-03, 4.4595e-03, 6.5557e-03, 5.2891e-04], + [-3.6006e-02, 1.5308e-03, 9.3225e-03, 1.0969e-03, -2.5098e-03, + 2.0260e-02, 2.3419e-02, -3.0053e-02, 1.8809e-03, 1.1059e-02], + [-7.7639e-02, 1.8533e-02, 2.0764e-03, 5.9706e-03, 5.6150e-03, + 5.6868e-03, 5.2854e-03, 9.8085e-03, 2.0360e-02, 4.3053e-03]], + + [[-2.6776e-02, 1.1113e-02, 3.8314e-03, 3.9986e-02, -1.6020e-02, + 1.1579e-02, -4.1635e-02, 5.5992e-03, 2.7429e-03, 9.5786e-03], + [-6.8619e-03, -6.4066e-03, 1.0888e-02, 1.4201e-02, 1.4413e-03, + 1.3225e-02, 8.0039e-03, -4.9191e-02, 5.6352e-03, 9.0651e-03], + [5.1026e-03, 1.9343e-03, 3.2506e-03, 1.0347e-03, 2.9837e-02, + 1.7121e-03, -5.9261e-02, 9.1443e-04, 8.3608e-03, 7.1146e-03]], + + [[-2.0848e-02, 7.0754e-03, 2.7633e-03, 2.4447e-03, 3.1520e-02, + 7.5401e-03, -5.8895e-02, 8.9559e-04, 5.7796e-03, 2.1724e-02], + [-1.3499e-03, -1.0019e-01, 1.5064e-02, 1.6485e-02, 2.3104e-03, + 6.2597e-03, 1.1729e-02, 1.0275e-02, 2.6635e-02, 1.2782e-02], + [7.1796e-03, 3.2656e-03, 1.2430e-03, 1.0712e-03, 6.6856e-03, + 1.4207e-03, 1.8792e-02, 8.2297e-03, -5.5865e-02, 7.9753e-03]]] + assert np.allclose(grad[0].asnumpy(), expect_grad, atol=1e-5)