diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.cc new file mode 100644 index 0000000000..62acfb46fd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + TensorDot, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + TensorDotGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + TensorDot, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + TensorDotGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h new file mode 100644 index 0000000000..3e5b022fcd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h @@ -0,0 +1,212 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +template +class TensorDotGpuKernel : public GpuKernel { + public: + TensorDotGpuKernel() + : batch_(0), + m_(0), + n_(0), + k_(0), + is_null_input_(false), + handle_(nullptr), + dtype_a_(CUDA_R_32F), + dtype_b_(CUDA_R_32F), + dtype_c_(CUDA_R_32F), + algo_(CUBLAS_GEMM_DEFAULT) {} + ~TensorDotGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *x1_input = GetDeviceAddress(inputs, 0); + T *x2_input = GetDeviceAddress(inputs, 1); + size_t *x1_input_shape = GetDeviceAddress(workspace, 0); + size_t *x2_input_shape = GetDeviceAddress(workspace, 1); + size_t *x1_input_trans_axes = GetDeviceAddress(workspace, 2); + size_t *x2_input_trans_axes = GetDeviceAddress(workspace, 3); + // transposed interim values moved to workspace, then multiplied + T *x1_reshape = GetDeviceAddress(workspace, 4); + T *x2_reshape = GetDeviceAddress(workspace, 5); + T *output_addr = GetDeviceAddress(outputs, 0); + + // Transpose X1 + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(x1_input_shape, &x1_input_shape_[0], x1_input_shape_.size() * sizeof(size_t), + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync x1_input_shape failed"); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(x1_input_trans_axes, &x1_transpose_fwd_[0], x1_input_shape_.size() * sizeof(size_t), + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_axis_x1 failed"); + int size_x1 = SizeToInt(input_size_x1_ / sizeof(T)); + CalTranspose(size_x1, x1_input, x1_input_shape, x1_input_trans_axes, SizeToInt(x1_input_shape_.size()), x1_reshape, + reinterpret_cast(stream_ptr)); + + // Transpose X2 + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(x2_input_shape, &x2_input_shape_[0], (x2_input_shape_.size() * sizeof(size_t)), + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync x2_input_shape failed"); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(x2_input_trans_axes, &x2_transpose_fwd_[0], (x2_input_shape_.size() * sizeof(size_t)), + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_axis_x2 failed"); + int size_x2 = SizeToInt(input_size_x2_ / sizeof(T)); + CalTranspose(size_x2, x2_input, x2_input_shape, x2_input_trans_axes, SizeToInt(x2_input_shape_.size()), x2_reshape, + reinterpret_cast(stream_ptr)); + + // Matrix Mulitply interim transposed values with GEMM + const float alpha = 1; // constants for cublas API + const float beta = 0; + const int lda = SizeToInt(k_); + const int ldb = SizeToInt(n_); + const int ldc = n_; + auto stride_a = SizeToInt(m_ * k_); + auto stride_b = SizeToInt(k_ * n_); + auto stride_c = SizeToInt(m_ * n_); + try { + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), + &alpha, x2_reshape, dtype_b_, ldb, stride_b, x1_reshape, dtype_a_, lda, stride_a, + &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), + "cublasSgemm Call Fail"); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx"; + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + // checking for FP16 op, switch to Tensor Core if available + dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); + dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); + if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) { + MS_LOG(INFO) << "Input and output type is float16, allow to use Tensor Core operations if possible"; + algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + + auto tmp_x1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto tmp_x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + input_size_x1_ = sizeof(T); + for (size_t i = 0; i < tmp_x1_shape.size(); i++) { + x1_input_shape_.push_back(tmp_x1_shape[i]); + input_size_x1_ *= tmp_x1_shape[i]; + } + input_size_x2_ = sizeof(T); + for (size_t i = 0; i < tmp_x2_shape.size(); i++) { + x2_input_shape_.push_back(tmp_x2_shape[i]); + input_size_x2_ *= tmp_x2_shape[i]; + } + + // holding in temp values to convert to size_t vectors + auto x1_transpose_fwd_temp = GetAttr>(kernel_node, "x1_transpose_fwd"); + auto x2_transpose_fwd_temp = GetAttr>(kernel_node, "x2_transpose_fwd"); + + for (size_t i = 0; i < x1_transpose_fwd_temp.size(); i++) { + x1_transpose_fwd_.push_back(x1_transpose_fwd_temp[i]); + } + + for (size_t i = 0; i < x2_transpose_fwd_temp.size(); i++) { + x2_transpose_fwd_.push_back(x2_transpose_fwd_temp[i]); + } + + // values to decide multiplication call specifics + x1_reshape_fwd_ = GetAttr>(kernel_node, "x1_reshape_fwd"); + x2_reshape_fwd_ = GetAttr>(kernel_node, "x2_reshape_fwd"); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + output_size_ = sizeof(T); + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= output_shape[i]; + } + is_null_input_ = CHECK_NULL_INPUT(output_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "input is null"; + InitSizeLists(); + return true; + } + m_ = x1_reshape_fwd_[0]; + k_ = x1_reshape_fwd_[1]; + n_ = x2_reshape_fwd_[1]; + batch_ = 1; // kept as a single multiplication operation + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t size_t_size = sizeof(size_t); + input_size_list_.push_back(input_size_x1_); + input_size_list_.push_back(input_size_x2_); + workspace_size_list_.push_back(x1_input_shape_.size() * size_t_size); + workspace_size_list_.push_back(x2_input_shape_.size() * size_t_size); + workspace_size_list_.push_back(x1_transpose_fwd_.size() * size_t_size); + workspace_size_list_.push_back(x2_transpose_fwd_.size() * size_t_size); + workspace_size_list_.push_back(input_size_x1_); + workspace_size_list_.push_back(input_size_x2_); + output_size_list_.push_back(output_size_); + } + + private: + size_t batch_; + size_t m_; + size_t n_; + size_t k_; + bool is_null_input_; + std::vector x1_input_shape_; + std::vector x2_input_shape_; + size_t input_size_x1_; + size_t input_size_x2_; + size_t output_size_; + std::vector x1_transpose_fwd_; // For transpose + std::vector x2_transpose_fwd_; + std::vector x1_reshape_fwd_; // For mulitplication shape + std::vector x2_reshape_fwd_; + cublasHandle_t handle_; + cudaDataType_t dtype_a_; + cudaDataType_t dtype_b_; + cudaDataType_t dtype_c_; + cublasGemmAlgo_t algo_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index dbd0df22b0..1267681c9b 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -156,6 +156,48 @@ def bprop_batchmatmul(self): return bprop +@bprop_getters.register(P.TensorDot) +def bprop_tensordot(self): + """Grad definition for `TensorDot` operation.""" + mul_op_x1 = P.MatMul(transpose_a=False, transpose_b=True) + mul_op_x2 = P.MatMul(transpose_a=True, transpose_b=False) + invert_permutation_op = P.InvertPermutation() + transpose_op = P.Transpose() + reshape_op = P.Reshape() + + # pull transformation specifics from P.TensorDot class + x1_transpose_fwd = tuple(self.x1_transpose_fwd) + x2_transpose_fwd = tuple(self.x2_transpose_fwd) + x1_reshape_fwd = tuple(self.x1_reshape_fwd) + x2_reshape_fwd = tuple(self.x2_reshape_fwd) + dout_reshape = (self.x1_reshape_fwd[0], self.x2_reshape_fwd[1]) + + # precalculated in fwd pass due to easier computation + x1_reshape_back = tuple(self.x1_reshape_back) + x2_reshape_back = tuple(self.x2_reshape_back) + + def bprop(x1, x2, out, dout): + # reshape dy values to 2D for MatMul + dout_reshaped = reshape_op(dout, dout_reshape) + # transform inputs to forward pass equivalents + x1_transpose = transpose_op(x1, x1_transpose_fwd) + x2_transpose = transpose_op(x2, x2_transpose_fwd) + x1_reshape = reshape_op(x1_transpose, x1_reshape_fwd) + x2_reshape = reshape_op(x2_transpose, x2_reshape_fwd) + # calculate dx values for x1 and x2 + dx1_interim = mul_op_x1(dout_reshaped, x2_reshape) + dx2_interim = mul_op_x2(x1_reshape, dout_reshaped) + # reverse transformations on dx values for both inputs + dx1_reshape = reshape_op(dx1_interim, x1_reshape_back) + dx2_reshape = reshape_op(dx2_interim, x2_reshape_back) + dx1_retranspose_axes = invert_permutation_op(x1_transpose_fwd) + dx2_retranspose_axes = invert_permutation_op(x2_transpose_fwd) + dx1_transpose = transpose_op(dx1_reshape, dx1_retranspose_axes) + dx2_transpose = transpose_op(dx2_reshape, dx2_retranspose_axes) + return dx1_transpose, dx2_transpose + return bprop + + @bprop_getters.register(P.TensorAdd) def get_bprop_tensor_add(self): """Grad definition for `TensorAdd` operation.""" diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index a504f53e3e..e98484c5f0 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, IFMR, - Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan) + Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot) from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, RandomCategorical, StandardLaplace, Multinomial) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 3e123c6a7d..7ae10528db 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -744,6 +744,126 @@ class BatchMatMul(MatMul): 'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}') +class TensorDot(PrimitiveWithInfer): + """ + Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`. + + Contraction allows for the summation of products of elements of `a` and `b` on specified axes. + The same number of axes must be specified for both x1 and x2, and values must be within range + of number of dims of both `a` and `b`. + + Selected dims in both inputs must also match. + + axes = 0 leads to outer product, and axes = 1 leads to normal matrix multiplication. + axes = 1 is the same as axes = ((0,),(1,) where length of input shape is 2 for both `a` and `b` + axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b` + + Args: + **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or + tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed, + automatically picks up first N dims from `a` input shape and last N dims from `b` input shape. + + Inputs: + - **x1** (Tensor): First tensor in TensorDot op with datatype float16 or float32 + - **x2** (Tensor): Second tensor in TensorDot op with datatype float16 or float32 + + Outputs: + Tensor, the shape of the output tensor is :math:`(N + M)`. Where :math:`N` and :math:`M` are the free axes not + contracted in both inputs + + Examples: + >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32) + >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32) + >>> tensordot = P.TensorDot(((0,1),(1,2))) + >>> output = tensordot(input_x1, input_x2) + """ + + @prim_attr_register + def __init__(self, axes): + self.axes = axes + validator.check_value_type('axes', axes, [int, tuple, list], self.name) + if not isinstance(self.axes, int): + self.axes = list(self.axes) # to avoid immutability issues + if len(self.axes) != 2: + raise ValueError("Require two axes inputs, given less") + self.int_to_tuple_conv() # convert before length checks + if len(self.axes[0]) != len(self.axes[1]): + raise ValueError("Axes have to be the same size/length") + if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])): + raise ValueError("Axes cannot have duplicating values") + + def int_to_tuple_conv(self): + """ + Converts ints to tuples in input axes, expected by most validation checks. + """ + for x in [0, 1]: + if isinstance(self.axes[x], int): + self.axes[x] = (self.axes[x],) + + def check_input_axes(self, x1_shape, x2_shape): + """ + Convert from single int axes to 2d tuple if required and check for validity with inputs. + """ + if isinstance(self.axes, int): + if self.axes <= 0: + # outer product, no input validation required + self.axes = ([], []) # no axes selected for either + return + if self.axes > len(x1_shape) or self.axes > len(x2_shape): + raise ValueError( + "Axes value too high for given input arrays dimensions.") + x1_ind = tuple(range(len(x1_shape))[-1 * self.axes:]) + x2_ind = tuple(range(len(x2_shape))[:self.axes]) + self.axes = tuple((x1_ind, x2_ind)) + self.int_to_tuple_conv() + for i in range(len(self.axes[0])): # sizes already validated + if x1_shape[self.axes[0][i]] != x2_shape[self.axes[1][i]]: + raise ValueError( + "Given Axes are incompatible with given input arrays") + + def calc_new_shape(self, shape, position=0): + """ + Calculate transpose and reshape parameters for input transformations, + 'position' refers to whether tensor is first or second in the op. + """ + contraction_axes = [i if i >= 0 else i + len(shape) for i in self.axes[position]] + prod_contraction = int(np.prod([shape[i] for i in contraction_axes])) + free_axes = [i for i in range(len(shape)) if i not in contraction_axes] + free_dims = [shape[i] for i in free_axes] + prod_free = int(np.prod(free_dims)) + + transpose_perm = list(contraction_axes) + free_axes if position else free_axes + list(contraction_axes) + new_shape = [prod_contraction, prod_free] if position else [prod_free, prod_contraction] + return new_shape, transpose_perm, free_dims + + def generate_transform_dims(self, x1_shape, x2_shape): + """ + Initiate calls for input transform calculations and calculate paramters for output + and for backprop tranformations. + """ + self.x1_reshape_fwd, self.x1_transpose_fwd, x1_ret = self.calc_new_shape(x1_shape, 0) + self.x2_reshape_fwd, self.x2_transpose_fwd, x2_ret = self.calc_new_shape(x2_shape, 1) + self.output_shape = x1_ret + x2_ret # combine free axes from both inputs + self.x1_reshape_back = [x1_shape[x] for x in self.x1_transpose_fwd] + self.x2_reshape_back = [x2_shape[x] for x in self.x2_transpose_fwd] + + def infer_shape(self, x1, x2): + self.check_input_axes(x1, x2) + self.generate_transform_dims(x1, x2) + # processed parameters for reading directly into kernel + self.add_prim_attr('x1_transpose_fwd', self.x1_transpose_fwd) + self.add_prim_attr('x2_transpose_fwd', self.x2_transpose_fwd) + self.add_prim_attr('x1_reshape_fwd', self.x1_reshape_fwd) + self.add_prim_attr('x2_reshape_fwd', self.x2_reshape_fwd) + return self.output_shape + + def infer_dtype(self, x1, x2): + args = {"x1": x1, "x2": x2} + valid_types = [mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) + return x1 + + class CumSum(PrimitiveWithInfer): """ Computes the cumulative sum of input tensor along axis. diff --git a/tests/st/ops/gpu/test_tensordot_op.py b/tests/st/ops/gpu/test_tensordot_op.py new file mode 100644 index 0000000000..2a0e5991c8 --- /dev/null +++ b/tests/st/ops/gpu/test_tensordot_op.py @@ -0,0 +1,339 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +import numpy as np + +import mindspore +from mindspore import Tensor +import mindspore.nn as nn +import mindspore.context as context +from mindspore.ops import operations as P +from mindspore.ops import composite as C + + +class NetTensorDot(nn.Cell): + def __init__(self, axes): + super(NetTensorDot, self).__init__() + self.td = P.TensorDot(axes) + + def construct(self, x, y): + return self.td(x, y) + + +class GradNetwork(nn.Cell): + def __init__(self, network): + super(GradNetwork, self).__init__() + self.grad = C.GradOperation(get_all=True, sens_param=True) + self.network = network + + def construct(self, input_data_a, input_data_b, sens): + gout = self.grad(self.network)(input_data_a, input_data_b, sens) + return gout + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tensor_dot_fp32(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + np.random.seed(12876) + shape_x1 = (1, 3, 9, 7) + shape_x2 = (9, 7, 3, 1) + axes = ((1, 3), (2, 1)) + x1 = np.random.random(shape_x1).astype(np.float32) + x2 = np.random.random(shape_x2).astype(np.float32) + x1_tensor = Tensor(x1, dtype=mindspore.float32) + x2_tensor = Tensor(x2, dtype=mindspore.float32) + + network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() + np_result = np.tensordot(x1, x2, axes) + np.testing.assert_array_almost_equal(ms_result_np, np_result) + + # 1D + shape_x1 = (200) + shape_x2 = (200) + axes = 1 + x1 = np.random.random(shape_x1).astype(np.float32) + x2 = np.random.random(shape_x2).astype(np.float32) + x1_tensor = Tensor(x1, dtype=mindspore.float32) + x2_tensor = Tensor(x2, dtype=mindspore.float32) + + network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() + np_result = np.tensordot(x1, x2, axes) + np.allclose(ms_result_np, np_result) + + # 2D + shape_x1 = (100, 300) + shape_x2 = (300, 700) + axes = ([1], [0]) + x1 = np.random.random(shape_x1).astype(np.float32) + x2 = np.random.random(shape_x2).astype(np.float32) + x1_tensor = Tensor(x1, dtype=mindspore.float32) + x2_tensor = Tensor(x2, dtype=mindspore.float32) + + network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() + np_result = np.tensordot(x1, x2, axes) + np.allclose(ms_result_np, np_result) + + # 3D + shape_x1 = (110, 30, 900) + shape_x2 = (900, 70, 30) + axes = ((1, 2), (2, 0)) + x1 = np.random.random(shape_x1).astype(np.float32) + x2 = np.random.random(shape_x2).astype(np.float32) + x1_tensor = Tensor(x1, dtype=mindspore.float32) + x2_tensor = Tensor(x2, dtype=mindspore.float32) + + network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() + np_result = np.tensordot(x1, x2, axes) + np.allclose(ms_result_np, np_result) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tensor_dot_fp16(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + np.random.seed(41329) + shape_x1 = (1, 3, 4, 1) + shape_x2 = (4, 1, 7, 5) + axes = 2 # select first N from + x1 = np.random.random(shape_x1).astype(np.float16) + x2 = np.random.random(shape_x2).astype(np.float16) + x1_tensor = Tensor(x1, dtype=mindspore.float16) + x2_tensor = Tensor(x2, dtype=mindspore.float16) + + network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() + np_result = np.tensordot(x1, x2, axes) + np.testing.assert_array_almost_equal(ms_result_np, np_result) + + # 1D + shape_x1 = (300) + shape_x2 = (300) + axes = 1 + x1 = np.random.random(shape_x1).astype(np.float16) + x2 = np.random.random(shape_x2).astype(np.float16) + x1_tensor = Tensor(x1, dtype=mindspore.float16) + x2_tensor = Tensor(x2, dtype=mindspore.float16) + + network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() + np_result = np.tensordot(x1, x2, axes) + np.testing.assert_array_almost_equal(ms_result_np, np_result) + + # 2D + shape_x1 = (100, 300) + shape_x2 = (300, 100) + axes = ([1], [0]) + x1 = np.random.random(shape_x1).astype(np.float16) + x2 = np.random.random(shape_x2).astype(np.float16) + x1_tensor = Tensor(x1, dtype=mindspore.float16) + x2_tensor = Tensor(x2, dtype=mindspore.float16) + + network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() + np_result = np.tensordot(x1, x2, axes) + np.testing.assert_array_almost_equal(ms_result_np, np_result) + + # 3D + shape_x1 = (60, 30, 450) + shape_x2 = (450, 90, 30) + axes = ((1, 2), (2, 0)) + x1 = np.random.random(shape_x1).astype(np.float16) + x2 = np.random.random(shape_x2).astype(np.float16) + x1_tensor = Tensor(x1, dtype=mindspore.float16) + x2_tensor = Tensor(x2, dtype=mindspore.float16) + + network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() + np_result = np.tensordot(x1, x2, axes) + np.testing.assert_array_almost_equal(ms_result_np, np_result) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tensor_dot_outer(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + np.random.seed(2746) + shape_x1 = (1, 2, 3) # incompatable dims for x1 and x2 + shape_x2 = (4, 5, 6) + axes = 0 # outer product does not require multiplicable dims + x1 = np.random.random(shape_x1).astype(np.float32) + x2 = np.random.random(shape_x2).astype(np.float32) + x1_tensor = Tensor(x1, dtype=mindspore.float32) + x2_tensor = Tensor(x2, dtype=mindspore.float32) + + network = NetTensorDot(axes) + ms_result_np = network(x1_tensor, x2_tensor).asnumpy() + np_result = np.tensordot(x1, x2, axes) + np.testing.assert_array_almost_equal(ms_result_np, np_result) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tensor_dot_backprop(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + # TEST 1 + shape_x1 = (2, 4, 2) + shape_x2 = (3, 2, 3) + axes = ((0,), (1,)) # select first N from + network = NetTensorDot(axes) + + np.random.seed(115) + x1 = np.random.random(shape_x1).astype(np.float16) + np.random.seed(1467) + x2 = np.random.random(shape_x2).astype(np.float16) + x1_tensor = Tensor(x1, dtype=mindspore.float16) + x2_tensor = Tensor(x2, dtype=mindspore.float16) + + np.random.seed(157) + grad = np.random.random((4, 2, 3, 3)) + grad_tensor = Tensor(grad, dtype=mindspore.float16) + grad_network = GradNetwork(network) + dx1, dx2 = grad_network(x1_tensor, x2_tensor, grad_tensor) + dx1, dx2 = dx1.asnumpy(), dx2.asnumpy() + + # precomputed + expect_dx1 = np.array([[[2.0293, 2.4473], + [2.9727, 1.4873], + [1.7910, 3.4727], + [2.4160, 1.7227]], + + [[2.5547, 2.5039], + [3.4062, 2.3320], + [2.6270, 3.1543], + [2.1406, 1.7666]]]) + expect_dx2 = np.array([[[2.1523, 2.9199, 0.8350], + [2.0254, 2.7734, 1.3213]], + + [[2.6836, 2.4707, 1.0156], + [2.9746, 3.0254, 1.9199]], + + [[1.8545, 1.7803, 1.3457], + [2.2676, 2.1797, 1.2764]]]) + np.allclose(dx1, expect_dx1) + np.allclose(dx2, expect_dx2) + + # TEST 2 + shape_x1 = (10, 35) + shape_x2 = (20, 10) + axes = ((0,), (1,)) # select first N from + network = NetTensorDot(axes) + + np.random.seed(215) + x1 = np.random.random(shape_x1).astype(np.float16) + np.random.seed(2467) + x2 = np.random.random(shape_x2).astype(np.float16) + x1_tensor = Tensor(x1, dtype=mindspore.float16) + x2_tensor = Tensor(x2, dtype=mindspore.float16) + + np.random.seed(257) + grad = np.random.random((35, 20)) + grad_tensor = Tensor(grad, dtype=mindspore.float16) + grad_network = GradNetwork(network) + dx1, dx2 = grad_network(x1_tensor, x2_tensor, grad_tensor) + dx1, dx2 = dx1.asnumpy(), dx2.asnumpy() + + # precomputed + expect_dx1 = np.array([[5.9727, 4.6484, 5.1836, 4.3906, 5.1641, 5.1406, 5.1211, 6.5352, 4.9922, + 4.4297, 4.4648, 6.5469, 6.2305, 4.8789, 6.8320, 5.3906, 4.7383, 6.0352, + 4.7383, 4.4844, 5.3711, 6.2617, 4.6484, 5.8672, 4.7500, 6.0234, 3.6387, + 5.3789, 5.9727, 5.7227, 6.0234, 4.9609, 5.0117, 5.4141, 5.1406], + [5.2305, 4.0078, 4.6328, 3.9238, 4.2773, 4.2539, 4.6797, 5.1289, 3.7910, + 3.8887, 3.2930, 5.5898, 5.4219, 3.6211, 5.5234, 3.5391, 4.8516, 4.7539, + 4.2500, 2.9785, 4.8867, 5.4648, 5.0195, 6.0195, 4.7109, 3.9727, 3.4922, + 4.1484, 4.7969, 5.3555, 4.9414, 5.2969, 3.1992, 5.2031, 4.4648], + [5.2266, 5.2617, 5.3750, 4.7930, 4.9062, 5.4102, 4.9336, 6.9414, 4.4961, + 4.4023, 4.7344, 5.8125, 4.9180, 4.7891, 5.9805, 5.2383, 4.6445, 6.1172, + 4.8477, 3.7578, 4.3047, 5.7969, 4.5859, 6.0273, 4.3438, 4.7305, 4.0938, + 4.8398, 5.8320, 5.3438, 5.3281, 4.8320, 4.0938, 4.9375, 5.3281], + [7.4297, 5.1484, 6.3477, 5.4844, 5.7852, 6.3906, 5.5234, 7.2383, 5.2969, + 4.9844, 4.5625, 7.3047, 7.3789, 6.4453, 8.2266, 6.6172, 5.5547, 7.0234, + 4.8594, 4.9531, 6.0469, 6.9258, 6.1055, 6.7539, 6.6953, 6.0430, 4.5117, + 5.7344, 7.4297, 6.4219, 6.8125, 6.4141, 5.2773, 6.8828, 6.0430], + [5.7969, 4.7109, 5.8281, 4.5703, 5.5078, 6.4219, 4.8359, 7.1484, 4.2617, + 4.8477, 4.2539, 5.6016, 6.4414, 5.7305, 6.4766, 5.4648, 4.5859, 6.5547, + 5.5156, 3.3848, 5.1523, 5.5352, 4.9531, 6.5938, 5.2969, 4.6055, 5.2109, + 4.4961, 5.8984, 5.4531, 5.8086, 5.7930, 5.0742, 5.4102, 4.9453], + [7.2188, 5.8789, 6.9453, 6.0039, 6.7188, 7.3359, 6.7695, 8.6172, 5.6680, + 6.4219, 6.1836, 7.7695, 7.5391, 6.5312, 8.2812, 7.5352, 5.8867, 7.7070, + 6.0039, 5.1172, 6.4844, 7.4297, 5.9219, 7.5078, 6.3125, 6.9805, 5.3750, + 5.9805, 7.2148, 7.6484, 7.8828, 6.7695, 5.7109, 6.8828, 6.9023], + [5.7656, 4.3633, 4.5039, 4.4375, 4.3867, 5.4336, 4.3672, 5.5469, 3.5742, + 4.0508, 3.7402, 5.9141, 5.7734, 4.5781, 5.6719, 4.5625, 4.5391, 5.1719, + 4.3945, 3.4844, 4.9297, 5.7227, 4.8203, 5.8125, 4.8633, 4.3125, 3.6641, + 4.3789, 5.6133, 5.1758, 4.9141, 5.8008, 4.0391, 5.8984, 4.3594], + [4.7734, 3.4238, 4.3477, 3.6270, 4.4883, 5.2031, 3.9023, 5.0078, 2.9355, + 3.8477, 3.4648, 5.1445, 4.8398, 4.4297, 5.1641, 4.2422, 4.2695, 4.6992, + 4.5039, 2.5176, 4.2500, 5.6680, 4.1875, 5.4141, 3.6094, 3.1758, 3.8398, + 3.9180, 5.3320, 4.6523, 3.9531, 4.8281, 3.9863, 4.8867, 4.3711], + [6.7578, 5.3164, 6.0000, 4.4531, 5.8789, 6.3750, 5.1094, 7.0391, 4.5781, + 4.8633, 4.5156, 6.6641, 6.3594, 5.5664, 6.9453, 5.5820, 5.1992, 6.9570, + 5.3242, 3.8574, 5.1445, 6.0547, 5.0273, 6.9180, 5.1914, 4.6914, 4.6445, + 5.1289, 5.8711, 6.2070, 6.1953, 5.7695, 4.7617, 5.5898, 4.9492], + [4.9180, 4.0117, 4.1211, 3.4629, 3.6445, 4.6602, 3.7031, 4.9062, 4.1133, + 3.0020, 3.2246, 4.6562, 4.4727, 3.3828, 5.2695, 4.0078, 3.2559, 4.9688, + 3.5742, 3.1133, 3.8223, 4.7578, 3.7949, 4.8438, 4.0664, 4.4336, 3.0957, + 4.4375, 4.2969, 4.1758, 4.5234, 4.2930, 3.9434, 4.8281, 3.0703]]) + expect_dx2 = np.array([[6.7930, 7.0000, 8.8203, 9.7031, 8.1250, + 6.7422, 8.4844, 8.7031, 7.2891, 10.1484], + [8.5781, 8.1641, 9.9609, 9.2344, 9.3281, + 8.1484, 9.8984, 9.0391, 7.9805, 11.0469], + [8.1016, 7.0781, 8.9688, 10.0938, 9.6641, + 7.1523, 8.2969, 8.8594, 8.3047, 10.2578], + [7.0938, 7.3477, 9.3594, 8.2422, 7.9141, + 6.5156, 8.2812, 8.2266, 6.9766, 8.5703], + [9.2891, 9.2500, 11.6875, 9.5234, 10.1172, + 8.8125, 9.5781, 9.5547, 8.9688, 11.2266], + [9.3594, 7.7539, 9.2500, 9.2500, 8.1094, + 8.0859, 8.7344, 8.2031, 8.5859, 10.3203], + [8.7344, 7.7227, 10.2578, 10.1641, 9.3984, + 8.1719, 8.0156, 8.6953, 8.6797, 10.6875], + [8.8750, 7.9922, 10.2422, 10.3984, 9.5234, + 8.5156, 8.7266, 8.8125, 8.2578, 10.2578], + [9.5703, 8.9844, 10.0547, 10.3047, 10.4062, + 8.2422, 10.7031, 9.7891, 9.2969, 11.0078], + [9.2891, 9.5391, 10.5938, 10.5078, 9.8203, + 8.5156, 9.0859, 9.0703, 8.7812, 10.8750], + [8.6094, 8.2734, 10.2734, 9.7891, 9.4531, + 7.5820, 8.4609, 8.6094, 7.7578, 10.3438], + [8.2891, 8.7578, 9.3906, 9.6016, 9.4375, + 7.1016, 8.6875, 8.1875, 8.2188, 9.3672], + [7.2969, 6.6953, 9.3984, 8.2422, 8.3438, + 7.5547, 7.6445, 7.5820, 7.5156, 9.0781], + [8.3906, 7.3516, 8.5938, 9.2422, 8.7734, + 8.0781, 9.1250, 7.8359, 7.7891, 10.9375], + [9.9219, 8.8281, 9.4141, 10.2500, 9.8047, + 8.5234, 8.5391, 8.4609, 8.5859, 11.2422], + [6.8984, 6.4570, 8.0000, 6.4688, 7.4609, + 6.6016, 7.0352, 6.6797, 6.5586, 7.7070], + [8.0625, 7.4805, 8.7578, 8.3281, 8.2188, + 7.4023, 8.5312, 7.5312, 7.1445, 10.3750], + [7.7773, 6.6484, 9.1094, 8.0078, 7.8281, + 7.1016, 8.2422, 8.1562, 6.8828, 10.3281], + [8.3281, 8.3672, 9.7656, 10.4922, 8.2500, + 7.5625, 8.4922, 8.9844, 8.0703, 10.3438], + [7.5195, 7.0430, 7.9453, 8.4375, 7.6641, + 6.9688, 7.7734, 8.7734, 6.3672, 9.4766]]) + np.allclose(dx1, expect_dx1) + np.allclose(dx2, expect_dx2)