!7611 [MS][GPU] Adding new Ops - TensorDot and TensorDot Grad

Merge pull request !7611 from danishnxt/newMaster
pull/7611/MERGE
Committed by mindspore-ci-bot via Gitee
commit a3af89bd48

@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/tensordot_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
TensorDot,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TensorDotGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
TensorDot,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
TensorDotGpuKernel, half)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,212 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TENSORDOT_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
#include "utils/convert_utils.h"
namespace mindspore {
namespace kernel {
template <typename T>
class TensorDotGpuKernel : public GpuKernel {
public:
TensorDotGpuKernel()
: batch_(0),
m_(0),
n_(0),
k_(0),
is_null_input_(false),
handle_(nullptr),
dtype_a_(CUDA_R_32F),
dtype_b_(CUDA_R_32F),
dtype_c_(CUDA_R_32F),
algo_(CUBLAS_GEMM_DEFAULT) {}
~TensorDotGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *x1_input = GetDeviceAddress<T>(inputs, 0);
T *x2_input = GetDeviceAddress<T>(inputs, 1);
size_t *x1_input_shape = GetDeviceAddress<size_t>(workspace, 0);
size_t *x2_input_shape = GetDeviceAddress<size_t>(workspace, 1);
size_t *x1_input_trans_axes = GetDeviceAddress<size_t>(workspace, 2);
size_t *x2_input_trans_axes = GetDeviceAddress<size_t>(workspace, 3);
// transposed interim values moved to workspace, then multiplied
T *x1_reshape = GetDeviceAddress<T>(workspace, 4);
T *x2_reshape = GetDeviceAddress<T>(workspace, 5);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
// Transpose X1
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(x1_input_shape, &x1_input_shape_[0], x1_input_shape_.size() * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync x1_input_shape failed");
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(x1_input_trans_axes, &x1_transpose_fwd_[0], x1_input_shape_.size() * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis_x1 failed");
int size_x1 = SizeToInt(input_size_x1_ / sizeof(T));
CalTranspose(size_x1, x1_input, x1_input_shape, x1_input_trans_axes, SizeToInt(x1_input_shape_.size()), x1_reshape,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Transpose X2
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(x2_input_shape, &x2_input_shape_[0], (x2_input_shape_.size() * sizeof(size_t)),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync x2_input_shape failed");
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(x2_input_trans_axes, &x2_transpose_fwd_[0], (x2_input_shape_.size() * sizeof(size_t)),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis_x2 failed");
int size_x2 = SizeToInt(input_size_x2_ / sizeof(T));
CalTranspose(size_x2, x2_input, x2_input_shape, x2_input_trans_axes, SizeToInt(x2_input_shape_.size()), x2_reshape,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Matrix multiply the interim transposed values with GEMM
const float alpha = 1; // constants for cublas API
const float beta = 0;
const int lda = SizeToInt(k_);
const int ldb = SizeToInt(n_);
const int ldc = SizeToInt(n_);
auto stride_a = SizeToInt(m_ * k_);
auto stride_b = SizeToInt(k_ * n_);
auto stride_c = SizeToInt(m_ * n_);
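// cuBLAS assumes column-major storage, so the operands are passed in swapped
// order (x2 first): in column-major terms this computes x2^T * x1^T, which is
// exactly x1 * x2 laid out in row-major memory. The compute type stays
// CUDA_R_32F even for half inputs, so fp16 GEMMs accumulate in fp32.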
try {
CHECK_CUBLAS_RET_WITH_EXCEPT(
cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
&alpha, x2_reshape, dtype_b_, ldb, stride_b, x1_reshape, dtype_a_, lda, stride_a,
&beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
"cublasSgemm Call Fail");
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoking cublasGemmStridedBatchedEx";
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
// check for an FP16 op and switch to Tensor Cores if available
dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)));
if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) {
MS_LOG(INFO) << "Input and output type is float16, allow to use Tensor Core operations if possible";
algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
auto tmp_x1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto tmp_x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
input_size_x1_ = sizeof(T);
for (size_t i = 0; i < tmp_x1_shape.size(); i++) {
x1_input_shape_.push_back(tmp_x1_shape[i]);
input_size_x1_ *= tmp_x1_shape[i];
}
input_size_x2_ = sizeof(T);
for (size_t i = 0; i < tmp_x2_shape.size(); i++) {
x2_input_shape_.push_back(tmp_x2_shape[i]);
input_size_x2_ *= tmp_x2_shape[i];
}
// hold attrs in temporaries before converting to size_t vectors
auto x1_transpose_fwd_temp = GetAttr<std::vector<int>>(kernel_node, "x1_transpose_fwd");
auto x2_transpose_fwd_temp = GetAttr<std::vector<int>>(kernel_node, "x2_transpose_fwd");
for (size_t i = 0; i < x1_transpose_fwd_temp.size(); i++) {
x1_transpose_fwd_.push_back(x1_transpose_fwd_temp[i]);
}
for (size_t i = 0; i < x2_transpose_fwd_temp.size(); i++) {
x2_transpose_fwd_.push_back(x2_transpose_fwd_temp[i]);
}
// values to decide multiplication call specifics
x1_reshape_fwd_ = GetAttr<std::vector<int>>(kernel_node, "x1_reshape_fwd");
x2_reshape_fwd_ = GetAttr<std::vector<int>>(kernel_node, "x2_reshape_fwd");
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
output_size_ = sizeof(T);
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
}
is_null_input_ = CHECK_NULL_INPUT(output_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "input is null";
InitSizeLists();
return true;
}
m_ = x1_reshape_fwd_[0];
k_ = x1_reshape_fwd_[1];
n_ = x2_reshape_fwd_[1];
batch_ = 1; // kept as a single multiplication operation
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
size_t size_t_size = sizeof(size_t);
input_size_list_.push_back(input_size_x1_);
input_size_list_.push_back(input_size_x2_);
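// workspace order must match the GetDeviceAddress(workspace, i) indices used in
// Launch: input shapes (0, 1), transpose permutations (2, 3), transposed copies (4, 5)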
workspace_size_list_.push_back(x1_input_shape_.size() * size_t_size);
workspace_size_list_.push_back(x2_input_shape_.size() * size_t_size);
workspace_size_list_.push_back(x1_transpose_fwd_.size() * size_t_size);
workspace_size_list_.push_back(x2_transpose_fwd_.size() * size_t_size);
workspace_size_list_.push_back(input_size_x1_);
workspace_size_list_.push_back(input_size_x2_);
output_size_list_.push_back(output_size_);
}
private:
size_t batch_;
size_t m_;
size_t n_;
size_t k_;
bool is_null_input_;
std::vector<size_t> x1_input_shape_;
std::vector<size_t> x2_input_shape_;
size_t input_size_x1_;
size_t input_size_x2_;
size_t output_size_;
std::vector<size_t> x1_transpose_fwd_; // For transpose
std::vector<size_t> x2_transpose_fwd_;
std::vector<int> x1_reshape_fwd_; // For multiplication shape
std::vector<int> x2_reshape_fwd_;
cublasHandle_t handle_;
cudaDataType_t dtype_a_;
cudaDataType_t dtype_b_;
cudaDataType_t dtype_c_;
cublasGemmAlgo_t algo_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif
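
The kernel above reduces an arbitrary tensor contraction to three steps: transpose each input so its contracted axes sit next to each other, flatten to a 2-D matrix, and run a single GEMM. A minimal numpy sketch of that algorithm, using the permutations and 2-D shapes that TensorDot's infer pass computes for the docstring example further below (tensordot_via_gemm is an illustrative name, not part of this PR's API):

import numpy as np

def tensordot_via_gemm(x1, x2, x1_perm, x2_perm, x1_2d, x2_2d, out_shape):
    # step 1: transpose so free axes of x1 lead and contracted axes of x2 lead
    a = np.transpose(x1, x1_perm).reshape(x1_2d)   # (m, k)
    b = np.transpose(x2, x2_perm).reshape(x2_2d)   # (k, n)
    # steps 2 and 3: a single GEMM, then restore the free dimensions
    return (a @ b).reshape(out_shape)

x1 = np.ones((1, 2, 3), np.float32)
x2 = np.ones((3, 1, 2), np.float32)
# parameters corresponding to axes = ((0, 1), (1, 2))
out = tensordot_via_gemm(x1, x2, (2, 0, 1), (1, 2, 0), (3, 2), (2, 3), (3, 3))
assert np.allclose(out, np.tensordot(x1, x2, axes=((0, 1), (1, 2))))
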

@ -156,6 +156,48 @@ def bprop_batchmatmul(self):
return bprop
@bprop_getters.register(P.TensorDot)
def bprop_tensordot(self):
"""Grad definition for `TensorDot` operation."""
mul_op_x1 = P.MatMul(transpose_a=False, transpose_b=True)
mul_op_x2 = P.MatMul(transpose_a=True, transpose_b=False)
invert_permutation_op = P.InvertPermutation()
transpose_op = P.Transpose()
reshape_op = P.Reshape()
# pull transformation specifics from P.TensorDot class
x1_transpose_fwd = tuple(self.x1_transpose_fwd)
x2_transpose_fwd = tuple(self.x2_transpose_fwd)
x1_reshape_fwd = tuple(self.x1_reshape_fwd)
x2_reshape_fwd = tuple(self.x2_reshape_fwd)
dout_reshape = (self.x1_reshape_fwd[0], self.x2_reshape_fwd[1])
# precomputed in the forward pass, where they are simpler to derive
x1_reshape_back = tuple(self.x1_reshape_back)
x2_reshape_back = tuple(self.x2_reshape_back)
def bprop(x1, x2, out, dout):
# reshape dy values to 2D for MatMul
dout_reshaped = reshape_op(dout, dout_reshape)
# transform inputs to forward pass equivalents
x1_transpose = transpose_op(x1, x1_transpose_fwd)
x2_transpose = transpose_op(x2, x2_transpose_fwd)
x1_reshape = reshape_op(x1_transpose, x1_reshape_fwd)
x2_reshape = reshape_op(x2_transpose, x2_reshape_fwd)
# calculate dx values for x1 and x2
dx1_interim = mul_op_x1(dout_reshaped, x2_reshape)
dx2_interim = mul_op_x2(x1_reshape, dout_reshaped)
# reverse transformations on dx values for both inputs
dx1_reshape = reshape_op(dx1_interim, x1_reshape_back)
dx2_reshape = reshape_op(dx2_interim, x2_reshape_back)
dx1_retranspose_axes = invert_permutation_op(x1_transpose_fwd)
dx2_retranspose_axes = invert_permutation_op(x2_transpose_fwd)
dx1_transpose = transpose_op(dx1_reshape, dx1_retranspose_axes)
dx2_transpose = transpose_op(dx2_reshape, dx2_retranspose_axes)
return dx1_transpose, dx2_transpose
return bprop
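
Once both inputs are in their 2-D forward form, the gradient is that of a plain matrix product: for Z = A @ B, dA = dZ @ B^T and dB = A^T @ dZ, which is what mul_op_x1 (transpose_b=True) and mul_op_x2 (transpose_a=True) compute; the transpose/reshape inversions afterwards map those back to the original input layouts. A small numpy sketch of that 2-D core (shapes are arbitrary):

import numpy as np

A, B = np.random.rand(4, 3), np.random.rand(3, 5)
dZ = np.ones((4, 5))       # upstream gradient of Z = A @ B
dA = dZ @ B.T              # matches mul_op_x1(dout_reshaped, x2_reshape)
dB = A.T @ dZ              # matches mul_op_x2(x1_reshape, dout_reshaped)
assert dA.shape == A.shape and dB.shape == B.shape
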
@bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self):
"""Grad definition for `TensorAdd` operation."""

@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, IFMR,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial)

@ -744,6 +744,126 @@ class BatchMatMul(MatMul):
'greater or equal to 3,' + f' while x size = {len(x)}, y size= {len(y)}')
class TensorDot(PrimitiveWithInfer):
"""
Computes tensor contraction over arbitrary axes between tensors `x1` and `x2`.
Contraction allows the summation of products of elements of `x1` and `x2` over the specified axes.
The same number of axes must be specified for both `x1` and `x2`, each axis value must be within
the range of dimensions of the corresponding input, and the selected dimensions in both inputs
must match in size.
axes = 0 leads to an outer product, and axes = 1 leads to ordinary matrix multiplication.
axes = 1 is the same as axes = ((1,), (0,)) where the length of the input shape is 2 for both `x1` and `x2`.
axes = 2 is the same as axes = ((1, 2), (0, 1)) where the length of the input shape is 3 for both `x1` and `x2`.
Args:
**axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]): Single value or
tuple/list of length 2 with axes specified for `x1` and `x2` respectively. If a single value `N` is
passed, the last `N` dimensions of `x1` and the first `N` dimensions of `x2` are contracted.
Inputs:
- **x1** (Tensor): First tensor in the TensorDot op, with datatype float16 or float32.
- **x2** (Tensor): Second tensor in the TensorDot op, with datatype float16 or float32.
Outputs:
Tensor, the shape of the output tensor is :math:`(N + M)`, where :math:`N` and :math:`M` are the
free (non-contracted) axes of `x1` and `x2` respectively.
Examples:
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
>>> tensordot = P.TensorDot(((0,1),(1,2)))
>>> output = tensordot(input_x1, input_x2)
"""
@prim_attr_register
def __init__(self, axes):
self.axes = axes
validator.check_value_type('axes', axes, [int, tuple, list], self.name)
if not isinstance(self.axes, int):
self.axes = list(self.axes) # to avoid immutability issues
if len(self.axes) != 2:
raise ValueError("Require two axes inputs, given less")
self.int_to_tuple_conv() # convert before length checks
if len(self.axes[0]) != len(self.axes[1]):
raise ValueError("Axes have to be the same size/length")
if len(self.axes[0]) != len(set(self.axes[0])) or len(self.axes[1]) != len(set(self.axes[1])):
raise ValueError("Axes cannot have duplicating values")
def int_to_tuple_conv(self):
"""
Converts int values in the input axes to tuples, as expected by most validation checks.
"""
for x in [0, 1]:
if isinstance(self.axes[x], int):
self.axes[x] = (self.axes[x],)
def check_input_axes(self, x1_shape, x2_shape):
"""
Converts single-int axes to a 2-tuple of axis tuples if required and checks validity against the input shapes.
"""
if isinstance(self.axes, int):
if self.axes <= 0:
# outer product, no input validation required
self.axes = ([], []) # no axes selected for either
return
if self.axes > len(x1_shape) or self.axes > len(x2_shape):
raise ValueError(
"Axes value too high for given input arrays dimensions.")
x1_ind = tuple(range(len(x1_shape))[-1 * self.axes:])
x2_ind = tuple(range(len(x2_shape))[:self.axes])
self.axes = tuple((x1_ind, x2_ind))
self.int_to_tuple_conv()
for i in range(len(self.axes[0])): # sizes already validated
if x1_shape[self.axes[0][i]] != x2_shape[self.axes[1][i]]:
raise ValueError(
"Given Axes are incompatible with given input arrays")
def calc_new_shape(self, shape, position=0):
"""
Calculates transpose and reshape parameters for the input transformations;
'position' indicates whether the tensor is first or second in the op.
"""
contraction_axes = [i if i >= 0 else i + len(shape) for i in self.axes[position]]
prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
free_axes = [i for i in range(len(shape)) if i not in contraction_axes]
free_dims = [shape[i] for i in free_axes]
prod_free = int(np.prod(free_dims))
transpose_perm = list(contraction_axes) + free_axes if position else free_axes + list(contraction_axes)
new_shape = [prod_contraction, prod_free] if position else [prod_free, prod_contraction]
return new_shape, transpose_perm, free_dims
def generate_transform_dims(self, x1_shape, x2_shape):
"""
Initiates the input transform calculations and computes parameters for the output
and for the backprop transformations.
"""
self.x1_reshape_fwd, self.x1_transpose_fwd, x1_ret = self.calc_new_shape(x1_shape, 0)
self.x2_reshape_fwd, self.x2_transpose_fwd, x2_ret = self.calc_new_shape(x2_shape, 1)
self.output_shape = x1_ret + x2_ret # combine free axes from both inputs
self.x1_reshape_back = [x1_shape[x] for x in self.x1_transpose_fwd]
self.x2_reshape_back = [x2_shape[x] for x in self.x2_transpose_fwd]
def infer_shape(self, x1, x2):
self.check_input_axes(x1, x2)
self.generate_transform_dims(x1, x2)
# processed parameters for reading directly into kernel
self.add_prim_attr('x1_transpose_fwd', self.x1_transpose_fwd)
self.add_prim_attr('x2_transpose_fwd', self.x2_transpose_fwd)
self.add_prim_attr('x1_reshape_fwd', self.x1_reshape_fwd)
self.add_prim_attr('x2_reshape_fwd', self.x2_reshape_fwd)
return self.output_shape
def infer_dtype(self, x1, x2):
args = {"x1": x1, "x2": x2}
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
return x1
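
For the docstring example (`x1` of shape (1, 2, 3), `x2` of shape (3, 1, 2), axes ((0, 1), (1, 2))), calc_new_shape permutes `x1` by (2, 0, 1) and flattens it to (3, 2), permutes `x2` by (1, 2, 0) and flattens it to (2, 3), and the free dims [3] and [3] combine into the output shape (3, 3). A standalone numpy sketch of the method (written here as a free function purely for illustration):

import numpy as np

def calc_new_shape(shape, axes, position=0):
    # free-function mirror of TensorDot.calc_new_shape
    contraction_axes = [i if i >= 0 else i + len(shape) for i in axes[position]]
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = [i for i in range(len(shape)) if i not in contraction_axes]
    free_dims = [shape[i] for i in free_axes]
    prod_free = int(np.prod(free_dims))
    perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    new_shape = [prod_contraction, prod_free] if position else [prod_free, prod_contraction]
    return new_shape, perm, free_dims

print(calc_new_shape((1, 2, 3), ((0, 1), (1, 2)), 0))  # ([3, 2], [2, 0, 1], [3])
print(calc_new_shape((3, 1, 2), ((0, 1), (1, 2)), 1))  # ([2, 3], [1, 2, 0], [3])
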
class CumSum(PrimitiveWithInfer):
"""
Computes the cumulative sum of input tensor along axis.
