!370 Gpu Support UnsortedSegmentSum kernel
Merge pull request !370 from chenweifeng/unsorted_segment_sumpull/370/MERGE
commit
0edc6d254a
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentSum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentSumGpuKernel, float, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentSum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnsortedSegmentSumGpuKernel, float, int64_t)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentSum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentSumGpuKernel, int, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
UnsortedSegmentSum,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
UnsortedSegmentSumGpuKernel, int, int64_t)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,90 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
|
||||
|
||||
#include <vector>
|
||||
#include "kernel/gpu/gpu_kernel.h"
|
||||
#include "kernel/gpu/gpu_kernel_factory.h"
|
||||
#include "kernel/gpu/cuda_impl/unsorted_segment_sum.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class UnsortedSegmentSumGpuKernel : public GpuKernel {
|
||||
public:
|
||||
UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {}
|
||||
~UnsortedSegmentSumGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemsetAsync(output_addr, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemSet Failed");
|
||||
UnsortedSegmentSum(input_dim0_, input_dim1_, output_dim0_, output_dim1_, input_addr, indices_addr, output_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
||||
input_dim0_ = input_shapes[0];
|
||||
for (size_t i = 1; i < input_shapes.size(); i++) {
|
||||
input_dim1_ *= input_shapes[i];
|
||||
}
|
||||
|
||||
output_dim0_ = output_shapes[0];
|
||||
for (size_t i = 1; i < output_shapes.size(); i++) {
|
||||
output_dim1_ *= output_shapes[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_dim0_ * input_dim1_ * sizeof(T));
|
||||
input_size_list_.push_back(output_dim0_ * sizeof(S));
|
||||
input_size_list_.push_back(output_dim0_ * sizeof(int));
|
||||
output_size_list_.push_back(output_dim0_ * output_dim1_ * sizeof(S));
|
||||
}
|
||||
|
||||
private:
|
||||
size_t input_dim0_;
|
||||
size_t input_dim1_;
|
||||
size_t output_dim0_;
|
||||
size_t output_dim1_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
|
@ -0,0 +1,56 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/cuda_impl/unsorted_segment_sum.cuh"
|
||||
|
||||
template<typename T, typename S>
|
||||
__global__ void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
|
||||
T* input_addr, S* ids_addr, T* output_addr) {
|
||||
for (int input_index = blockIdx.x * blockDim.x + threadIdx.x; input_index < input_dim0 * input_dim1;
|
||||
input_index += blockDim.x * gridDim.x) {
|
||||
size_t j = input_index / input_dim1;
|
||||
size_t k = input_index % input_dim1;
|
||||
|
||||
S i = ids_addr[j];
|
||||
if (i < 0 || i >= output_dim0) {
|
||||
continue;
|
||||
}
|
||||
size_t output_index = i * output_dim1 + k;
|
||||
atomicAdd(output_addr + output_index, input_addr[input_index]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename S>
|
||||
void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
|
||||
T* input_addr, S* ids_addr, T* output_addr, cudaStream_t stream) {
|
||||
int size = input_dim0 * input_dim1;
|
||||
UnsortedSegmentSum<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input_dim0, input_dim1,
|
||||
output_dim0, output_dim1, input_addr, ids_addr, output_addr);
|
||||
return;
|
||||
}
|
||||
|
||||
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
|
||||
float* input_addr, int* ids_addr, float* output_addr, cudaStream_t stream);
|
||||
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
|
||||
float* input_addr, int64_t* ids_addr, float* output_addr, cudaStream_t stream);
|
||||
|
||||
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
|
||||
int* input_addr, int* ids_addr, int* output_addr, cudaStream_t stream);
|
||||
template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
|
||||
int* input_addr, int64_t* ids_addr, int* output_addr, cudaStream_t stream);
|
||||
|
||||
|
||||
|
@ -0,0 +1,27 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
template<typename T, typename S>
|
||||
void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
|
||||
T* input_addr, S* ids, T* output_addr, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_
|
@ -0,0 +1,111 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
context.set_context(device_target='GPU')
|
||||
|
||||
class UnsortedSegmentSumNet(nn.Cell):
|
||||
def __init__(self, num_segments):
|
||||
super(UnsortedSegmentSumNet, self).__init__()
|
||||
self.unsorted_segment_sum = P.UnsortedSegmentSum()
|
||||
self.num_segments = num_segments
|
||||
|
||||
def construct(self, data, ids):
|
||||
return self.unsorted_segment_sum(data, ids, self.num_segments)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_1D():
|
||||
input_x = Tensor([1, 2, 3, 4], mstype.float32)
|
||||
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
|
||||
num_segments = 4
|
||||
|
||||
net = UnsortedSegmentSumNet(num_segments)
|
||||
output = net(input_x, segment_ids)
|
||||
expect = [3, 3, 4, 0]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_2D():
|
||||
input_x = Tensor([[1, 2, 3, 4],
|
||||
[5, 6, 7, 8],
|
||||
[9, 10, 11, 12]], mstype.float32)
|
||||
segment_ids = Tensor([2, 1, 1], mstype.int32)
|
||||
num_segments = 4
|
||||
|
||||
net = UnsortedSegmentSumNet(num_segments)
|
||||
output = net(input_x, segment_ids)
|
||||
expect = [[ 0, 0, 0, 0],
|
||||
[14, 16, 18, 20],
|
||||
[ 1, 2, 3, 4],
|
||||
[ 0, 0, 0, 0]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_3D():
|
||||
input_x = Tensor(np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3))
|
||||
segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
|
||||
num_segments = 5
|
||||
|
||||
net = UnsortedSegmentSumNet(num_segments)
|
||||
output = net(input_x, segment_ids)
|
||||
expect = [[[ 0., 0., 0.],
|
||||
[ 0., 0., 0.],
|
||||
[ 0., 0., 0.],
|
||||
[ 0., 0., 0.],
|
||||
[ 0., 0., 0.]],
|
||||
|
||||
[[45., 47., 49.],
|
||||
[51., 53., 55.],
|
||||
[57., 59., 61.],
|
||||
[63., 65., 67.],
|
||||
[69., 71., 73.]],
|
||||
|
||||
[[ 0., 1., 2.],
|
||||
[ 3., 4., 5.],
|
||||
[ 6., 7., 8.],
|
||||
[ 9., 10., 11.],
|
||||
[12., 13., 14.]],
|
||||
|
||||
[[ 0., 0., 0.],
|
||||
[ 0., 0., 0.],
|
||||
[ 0., 0., 0.],
|
||||
[ 0., 0., 0.],
|
||||
[ 0., 0., 0.]],
|
||||
|
||||
[[ 0., 0., 0.],
|
||||
[ 0., 0., 0.],
|
||||
[ 0., 0., 0.],
|
||||
[ 0., 0., 0.],
|
||||
[ 0., 0., 0.]]]
|
||||
assert (output.asnumpy() == expect).all()
|
Loading…
Reference in new issue