!12329 GPU index_add forward op

From: @tom__chen
Reviewed-by: @liangchenghui,@robingrosman
Signed-off-by: @liangchenghui
pull/12329/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 52eae8d4ad

@ -0,0 +1,106 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
__global__ void InitErrorCode(IndexAddErrorCode *error_code) {
*error_code = IndexAddErrorCode::kOk;
}
__global__ void ValidateIndexValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
IndexAddErrorCode *error_code) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_axis_size; pos += blockDim.x * gridDim.x) {
const int idx_value = index[pos];
if (idx_value < 0 || idx_value >= dst_axis_size) {
*error_code = IndexAddErrorCode::kIndexOutOfRange;
return;
}
}
return;
}
template <typename T>
__global__ void IndexAddAtomic(T *dst, const int *index, const T *src, const size_t src_size, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_size; pos += blockDim.x * gridDim.x) {
const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
const size_t src_outer_idx = pos / (src_axis_size * inner_size);
const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
const size_t dst_inner_idx = pos % inner_size;
const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
MsAtomicAdd(&dst[dst_idx], src[pos]);
}
return;
}
template <typename T>
__global__ void IndexAdd(T *dst, const int *index, const T *src, const size_t src_size, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < src_size; pos += blockDim.x * gridDim.x) {
const size_t src_axis_idx = (pos / inner_size) % src_axis_size;
const size_t src_outer_idx = pos / (src_axis_size * inner_size);
const size_t dst_axis_idx = static_cast<size_t>(index[src_axis_idx]);
const size_t dst_inner_idx = pos % inner_size;
const size_t dst_idx = src_outer_idx * (dst_axis_size * inner_size) + dst_axis_idx * inner_size + dst_inner_idx;
dst[dst_idx] += src[pos];
}
return;
}
void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
IndexAddErrorCode *error_code, cudaStream_t cuda_stream) {
InitErrorCode<<<1, 1, 0, cuda_stream>>>(error_code);
ValidateIndexValues<<<GET_BLOCKS(src_axis_size), GET_THREADS, 0, cuda_stream>>>(index, src_axis_size, dst_axis_size,
error_code);
}
template <typename T>
void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream) {
size_t src_size = outer_size * src_axis_size * inner_size;
if (use_lock) {
IndexAddAtomic<<<GET_BLOCKS(src_size), GET_THREADS, 0, cuda_stream>>>(dst, index, src, src_size, outer_size,
src_axis_size, dst_axis_size, inner_size);
} else {
IndexAdd<<<GET_BLOCKS(src_size), GET_THREADS, 0, cuda_stream>>>(dst, index, src, src_size, outer_size,
src_axis_size, dst_axis_size, inner_size);
}
return;
}
template void CalIndexAdd<double>(double *dst, const int *index, const double *src, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size, const bool use_lock,
cudaStream_t cuda_stream);
template void CalIndexAdd<float>(float *dst, const int *index, const float *src, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size, const bool use_lock,
cudaStream_t cuda_stream);
template void CalIndexAdd<half>(half *dst, const int *index, const half *src, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size, const bool use_lock,
cudaStream_t cuda_stream);
template void CalIndexAdd<int>(int *dst, const int *index, const int *src, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size, const bool use_lock,
cudaStream_t cuda_stream);
template void CalIndexAdd<int16_t>(int16_t *dst, const int *index, const int16_t *src, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size, const bool use_lock,
cudaStream_t cuda_stream);
template void CalIndexAdd<int8_t>(int8_t *dst, const int *index, const int8_t *src, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size, const bool use_lock,
cudaStream_t cuda_stream);
template void CalIndexAdd<uint8_t>(uint8_t *dst, const int *index, const uint8_t *src, const size_t outer_size,
const size_t src_axis_size, const size_t dst_axis_size, const size_t inner_size, const bool use_lock,
cudaStream_t cuda_stream);

@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_
enum class IndexAddErrorCode {
kOk = 0,
kIndexOutOfRange
};
void ValidateIndexAddInputValues(const int *index, const size_t src_axis_size, const size_t dst_axis_size,
IndexAddErrorCode *error_code, cudaStream_t cuda_stream);
template <typename T>
void CalIndexAdd(T *dst, const int *index, const T *src, const size_t outer_size, const size_t src_axis_size,
const size_t dst_axis_size, const size_t inner_size, const bool use_lock, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_INDEXADD_H_

@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/index_add_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(IndexAdd,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
IndexAddGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
IndexAddGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
IndexAddGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
IndexAddGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
IndexAddGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
IndexAddGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(IndexAdd,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
IndexAddGpuKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,155 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_INDEX_ADD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_INDEX_ADD_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/index_add_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class IndexAddGpuKernel : public GpuKernel {
public:
IndexAddGpuKernel()
: dst_size_(0),
index_size_(0),
src_size_(0),
output_size_(0),
outer_size_(0),
src_axis_size_(0),
dst_axis_size_(0),
inner_size_(0),
use_lock_(true),
check_index_bound_(true) {}
~IndexAddGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *dst = GetDeviceAddress<T>(inputs, 0);
int *index = GetDeviceAddress<int>(inputs, 1);
T *src = GetDeviceAddress<T>(inputs, 2);
T *dst_out = GetDeviceAddress<T>(outputs, 0);
if (check_index_bound_) {
IndexAddErrorCode *error_code_addr = GetDeviceAddress<IndexAddErrorCode>(workspace, 0);
IndexAddErrorCode error_code = IndexAddErrorCode::kOk;
ValidateIndexAddInputValues(index, src_axis_size_, dst_axis_size_, error_code_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_ERROR(kernel_node_,
cudaMemcpyAsync(&error_code, error_code_addr, sizeof(IndexAddErrorCode),
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to copy error code to host.");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
LogExceptionIfNotOk(error_code);
}
CalIndexAdd(dst, index, src, outer_size_, src_axis_size_, dst_axis_size_, inner_size_, use_lock_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&dst_out[0], &dst[0], dst_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but index add needs 3 inputs.";
return false;
}
std::vector<size_t> dst_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
std::vector<size_t> index_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
std::vector<size_t> src_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
int64_t src_rank = src_shape.size();
int64_t axis = GetAttr<int64_t>(kernel_node, "axis");
if (axis < 0) {
axis += src_rank;
}
outer_size_ = 1;
for (int64_t i = axis - 1; i >= 0; i--) {
outer_size_ *= src_shape[i];
}
inner_size_ = 1;
for (int64_t i = axis + 1; i < src_rank; i++) {
inner_size_ *= src_shape[i];
}
src_axis_size_ = src_shape[axis];
dst_axis_size_ = dst_shape[axis];
dst_size_ = sizeof(T);
for (auto x : dst_shape) {
dst_size_ *= x;
}
index_size_ = sizeof(int);
for (auto x : index_shape) {
index_size_ *= x;
}
src_size_ = sizeof(T);
for (auto x : src_shape) {
src_size_ *= x;
}
output_size_ = dst_size_;
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(dst_size_);
input_size_list_.push_back(index_size_);
input_size_list_.push_back(src_size_);
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(sizeof(IndexAddErrorCode));
}
private:
void LogExceptionIfNotOk(IndexAddErrorCode error_code) {
switch (error_code) {
case IndexAddErrorCode::kOk:
return;
case IndexAddErrorCode::kIndexOutOfRange:
MS_LOG(EXCEPTION) << "gpu IndexAdd op error: values of index tensor is out of range";
break;
default:
MS_LOG(EXCEPTION) << "gpu IndexAdd op unknown error";
}
}
size_t dst_size_;
size_t index_size_;
size_t src_size_;
size_t output_size_;
size_t outer_size_;
size_t src_axis_size_;
size_t dst_axis_size_;
size_t inner_size_;
bool use_lock_;
bool check_index_bound_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_INDEX_ADD_GPU_KERNEL_H_

@ -56,7 +56,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
MatrixInverse)
MatrixInverse, IndexAdd)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@ -418,6 +418,7 @@ __all__ = [
"SparseToDense",
"MatrixInverse",
"Range",
"IndexAdd",
]
__all__.sort()

@ -4199,3 +4199,68 @@ class MatrixInverse(PrimitiveWithInfer):
validator.check_int(len(x_shape), 2, Rel.GE, self.name, None)
validator.check_equal_int(x_shape[-1], x_shape[-2], self.name, None)
return x_shape
class IndexAdd(PrimitiveWithInfer):
"""
Adds tenosr y to specified axis and indices of tensor x.
Args:
axis (int): The dimension along wich to index.
Inputs:
- **input_x** (Tensor) - The input tensor to add to, with data type float64, float32, float16, int32, int16,
int8, uint8.
- **indices** (Tensor) - The index of `input_x` on the `axis`th dimension to add to, with data type int32.
The `indices` must be 1D with the size same as the size of the `axis`th dimension of `input_y`. The values
of `indices` should be in the range of 0 to the size of the `axis`th dimension of `input_x`.
- **input_y** (Tensor) - The input tensor with the value to add. Must have same data type as `input_x`.
The shape must be the same as `input_x` except the `axis`th dimension.
Outputs:
Tensor, has the same shape and dtype as input_x.
Supported Platforms:
``GPU``
Examples:
>>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [6, 7, 8]]), mindspore.float32)
>>> input_y = Tensor(np.array([[0.5, 1.0], [1.0, 1.5], [2.0, 2.5]]), mindspore.float32)
>>> indices = Tensor(np.array([0, 2]), mindspore.int32)
>>> index_add = ops.IndexAdd(axis=1)
>>> output = index_add(input_x, indices, input_y)
>>> print(output)
[[ 1.5 2. 4. ]
[ 5. 5. 7.5]
[ 8. 7. 10.5]]
"""
@prim_attr_register
def __init__(self, axis, use_lock=True, check_index_bound=True):
"""Initialize InplaceAdd"""
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
self.axis = axis
validator.check_value_type('axis', axis, [int], self.name)
def infer_dtype(self, x_dtype, idx_type, y_dtype):
args = {'x': x_dtype, 'y': y_dtype}
valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.int32, mstype.int16, mstype.int8,
mstype.uint8]
validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
valid_idx_type = [mstype.int32]
validator.check_tensor_dtype_valid("idx_type", idx_type, valid_idx_type, self.name)
return x_dtype
def infer_shape(self, x_shape, idx_shape, y_shape):
validator.check("x rank", len(x_shape), "y rank", len(y_shape), Rel.EQ, self.name)
validator.check("size of indices", idx_shape[0], "dimension of y[axis]", y_shape[self.axis],
Rel.EQ, self.name)
x_rank = len(x_shape)
validator.check_int_range(self.axis, -x_rank - 1, x_rank, Rel.INC_BOTH, 'axis', self.name)
axis = self.axis if self.axis >= 0 else x_rank + self.axis
for dim in range(x_rank):
if dim == axis:
validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.GE, self.name)
else:
validator.check('x dim %d' % dim, x_shape[dim], "y dim %d" % dim, y_shape[dim], Rel.EQ, self.name)
return x_shape

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save