add gather op on gpu

pull/8000/head
zhouyuanshen 4 years ago
parent ff7ecebdcd
commit f0f67b8aa8

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
GatherGpuFwdKernel, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
GatherGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
GatherD,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
GatherGpuFwdKernel, half, int64_t)
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,123 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_GATHER_GPU_KERNEL_H
#define MINDSPORE_GATHER_GPU_KERNEL_H
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class GatherGpuFwdKernel : public GpuKernel {
public:
GatherGpuFwdKernel() : axis_(0), handle_(nullptr) {}
~GatherGpuFwdKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
T *input_addr = GetDeviceAddress<T>(inputs, 0);
S *index_addr = GetDeviceAddress<S>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2],
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuFwdKernel needs 2.";
}
input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
index_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
axis_ = GetAttr<int>(kernel_node, "dim");
if (axis_ < 0) {
axis_ = axis_ + SizeToInt(input_shapes_.size());
}
Reshape();
InitSizeLists();
return true;
}
protected:
void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitSizeLists() override {
size_t size = GetSize(input_shapes_, true);
input_size_list_.push_back(size);
size = GetSize(index_shapes_, false);
input_size_list_.push_back(size);
size = GetSize(output_shapes_, true);
output_size_list_.push_back(size);
}
private:
void Reshape() {
size_t dim_before_axis = 1;
for (size_t i = 0; i < IntToSize(axis_); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_of_index = output_shapes_[IntToSize(axis_)];
size_t dim_after_index = 1;
for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
dim_after_index *= output_shapes_[i];
}
dims_[0] = dim_before_axis;
dims_[1] = dim_of_index;
dims_[2] = dim_after_index;
return;
}
size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
if (shape.size() == 0) {
return 0;
}
size_t result = flag ? sizeof(T) : sizeof(S);
for (size_t i = 0; i < shape.size(); i++) {
result *= shape[i];
}
return result;
}
std::vector<size_t> input_shapes_;
std::vector<size_t> index_shapes_;
std::vector<size_t> output_shapes_;
size_t dims_[3] = {};
int axis_;
cudnnHandle_t handle_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_GATHER_GPU_KERNEL_H

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void GatherKernel(const T *input, const S *index, T *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2) {
size_t num = output_dim0 * output_dim1 * output_dim2;
size_t i, k;
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num;
id += blockDim.x * gridDim.x) {
i = id / (output_dim1 * output_dim2) % output_dim0;
k = id % output_dim2;
size_t j_read = static_cast<size_t>(index[id]);
size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k;
output[id] = input[read_id];
}
return;
}
template <typename T, typename S>
void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream) {
size_t size = output_dim0 * output_dim1 * output_dim2;
GatherKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, index, output, output_dim0, output_dim1,
output_dim2);
return;
}
template void Gather<float, int>(const float *input, const int *index, float *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
template void Gather<float, int64_t>(const float *input, const int64_t *index, float *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
template void Gather<half, int>(const half *input, const int *index, half *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
template void Gather<half, int64_t>(const half *input, const int64_t *index, half *output, const size_t output_dim0,
const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);

@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_GATHER_GPU_CU_H
#define MINDSPORE_GATHER_GPU_CU_H
template <typename T, typename S>
void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1,
const size_t output_dim2, cudaStream_t stream);
#endif

@ -37,6 +37,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimReduceSum->name(), {1});
Register(prim::kPrimReduceMean->name(), {1});
Register(prim::kPrimGatherV2->name(), {2});
Register(prim::kPrimGatherD->name(), {1});
Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5});
Register(prim::kPrimEmbeddingLookupCommGrad->name(), {1});
Register(prim::kPrimSubscalar->name(), {1});

@ -55,6 +55,12 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
continue;
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
auto ms_context = MsContext::GetInstance();
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
continue;
}
}
if (AnfAlgo::IsDynamicShape(cnode)) {
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
continue;

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -590,6 +590,13 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *te
if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) {
reg_exist = false;
}
if (op_run_info->op_name == prim::kPrimGatherD->name()) {
auto ms_context = MsContext::GetInstance();
// Gather op needs converting const input to attr on GPU device
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
reg_exist = false;
}
}
op_prim->BeginRecordAddAttr();
size_t input_num = op_run_info->op_inputs.size();

@ -84,6 +84,7 @@ inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
inline const PrimitivePtr kPrimGatherV2 = std::make_shared<Primitive>("GatherV2");
inline const PrimitivePtr kPrimGatherD = std::make_shared<Primitive>("GatherD");
inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");

@ -3978,6 +3978,7 @@ class GatherD(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""Initialize GatherD"""
self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output'])
def __infer__(self, x, dim, index):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save