parent
45ad430af2
commit
5c0962acfa
@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SplitGpuFwdKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(Split,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SplitGpuFwdKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SplitGpuFwdKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,153 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class SplitGpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
SplitGpuFwdKernel()
|
||||
: axis_(0),
|
||||
output_num_(1),
|
||||
input_size_(1),
|
||||
axis_step_(1),
|
||||
all_size_before_axis_(1),
|
||||
all_size_axis_(1),
|
||||
outputs_host_(nullptr) {}
|
||||
~SplitGpuFwdKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T **outputs_device = GetDeviceAddress<T *>(workspace, 0);
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
outputs_host_[i] = GetDeviceAddress<T>(outputs, i);
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_device, outputs_host_.get(), sizeof(T *) * output_num_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"Split opt cudaMemcpyAsync outputs failed");
|
||||
SplitKernel(input_size_, axis_step_, all_size_before_axis_, all_size_axis_, input, outputs_device,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
axis_ = GetAttr<int>(kernel_node, "axis");
|
||||
if (axis_ < 0) {
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
axis_ += SizeToInt(input_shape.size());
|
||||
}
|
||||
output_num_ = GetAttr<int>(kernel_node, "output_num");
|
||||
|
||||
if (!CheckParam(kernel_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
input_size_ = 1;
|
||||
all_size_before_axis_ = 1;
|
||||
all_size_axis_ = 1;
|
||||
|
||||
for (int i = 0; i < SizeToInt(input_shape.size()); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
if (i > axis_) {
|
||||
all_size_before_axis_ *= input_shape[i];
|
||||
all_size_axis_ *= input_shape[i];
|
||||
}
|
||||
if (i == axis_) {
|
||||
all_size_before_axis_ *= input_shape[i];
|
||||
}
|
||||
}
|
||||
input_size_list_.push_back(IntToSize(input_size_ * sizeof(T)));
|
||||
axis_step_ = input_shape[axis_] / output_num_;
|
||||
|
||||
for (int i = 0; i < output_num_; i++) {
|
||||
size_t output_size = 1;
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
|
||||
for (size_t j = 0; j < output_shape.size(); j++) {
|
||||
output_size *= output_shape[j];
|
||||
}
|
||||
output_size_list_.push_back(output_size * sizeof(T));
|
||||
}
|
||||
workspace_size_list_.push_back(sizeof(T *) * output_num_);
|
||||
InitSizeLists();
|
||||
outputs_host_ = std::make_unique<T *[]>(output_num_);
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {}
|
||||
|
||||
private:
|
||||
bool CheckParam(const CNodePtr &kernel_node) {
|
||||
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
int dims = SizeToInt(input_shape.size());
|
||||
int output_num = SizeToInt(AnfAlgo::GetOutputTensorNum(kernel_node));
|
||||
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but Split needs 1 input.";
|
||||
return false;
|
||||
}
|
||||
if (dims == 0) {
|
||||
MS_LOG(ERROR) << "Input dims is " << dims << ", scalar is not supported.";
|
||||
return false;
|
||||
}
|
||||
if (axis_ < -dims || axis_ >= dims) {
|
||||
MS_LOG(ERROR) << "Attr axis " << axis_ << " must be in " << -dims << "~" << dims;
|
||||
return false;
|
||||
}
|
||||
if (output_num_ > SizeToInt(input_shape[axis_])) {
|
||||
MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must less than" << input_shape[axis_];
|
||||
return false;
|
||||
}
|
||||
if (input_shape[axis_] % output_num_ != 0) {
|
||||
MS_LOG(ERROR) << "Attr output_num " << output_num_ << "must be divided by" << input_shape[axis_];
|
||||
return false;
|
||||
}
|
||||
if (output_num_ != output_num) {
|
||||
MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
int axis_;
|
||||
int output_num_;
|
||||
int input_size_;
|
||||
int axis_step_;
|
||||
int all_size_before_axis_;
|
||||
int all_size_axis_;
|
||||
std::unique_ptr<T *[]> outputs_host_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
|
@ -0,0 +1,50 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
|
||||
template <typename T>
|
||||
__global__ void Split(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const T* input, T** outputs) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
int num = pos % all_size_before_axis / all_size_axis;
|
||||
int block = num / axis_step;
|
||||
int block_pos = pos / all_size_before_axis * axis_step * all_size_axis +
|
||||
num % axis_step * all_size_axis + pos % all_size_axis;
|
||||
outputs[block][block_pos] = input[pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream) {
|
||||
Split<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, axis_step, all_size_before_axis,
|
||||
all_size_axis, input, outputs);
|
||||
return;
|
||||
}
|
||||
|
||||
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const float* input, float** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const int* input, int** outputs,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const half* input, half** outputs,
|
||||
cudaStream_t cuda_stream);
|
@ -0,0 +1,24 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
|
||||
const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
|
@ -0,0 +1,58 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0, out_nums=1):
|
||||
super(Net, self).__init__()
|
||||
self.split = P.Split(axis, out_nums)
|
||||
|
||||
def construct(self, x):
|
||||
return self.split(x)
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_split():
|
||||
x = np.array([[[1, -1, 1], [2, -2, 2]],
|
||||
[[3, -3, 3], [4, -4, 4]],
|
||||
[[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
|
||||
|
||||
split_op = Net(0, 3)
|
||||
outputs = split_op(Tensor(x))
|
||||
for i, out in enumerate(outputs):
|
||||
assert (out.asnumpy() == x[i]).all()
|
||||
|
||||
|
||||
def test_split_4d():
|
||||
x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
|
||||
y = np.split(x_np, 3, axis=1)
|
||||
|
||||
split_op = Net(1, 3)
|
||||
outputs = split_op(Tensor(x_np))
|
||||
|
||||
for i, out in enumerate(outputs):
|
||||
assert (out.asnumpy() == y[i]).all()
|
Loading…
Reference in new issue