new gpu op for cbg repeat_elements

fixed ci

fixed ci

addressed comments
pull/7537/MERGE^2
Peilin Wang 4 years ago
parent db0868d745
commit bd0b462691

@@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include "backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(RepeatElements, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      RepeatElementsGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(RepeatElements, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      RepeatElementsGpuKernel, int32_t)
} // namespace kernel
} // namespace mindspore
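
Only two kernels are registered above: float16 (instantiated as half on device) and int32, matching the dtype constraint documented for the Python primitive further down. As a hedged illustration from the Python side (assuming a PyNative GPU context), a float32 tensor would have to be cast before the op can dispatch:

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

repeat_elements = P.RepeatElements(rep=2, axis=1)

# float32 has no kernel registered above, so cast to float16 first.
x = Tensor(np.random.rand(2, 3).astype(np.float32))
y = repeat_elements(P.Cast()(x, mindspore.float16))  # shape (2, 6)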

@@ -0,0 +1,161 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GPU_KERNEL_H_

#include <cuda_runtime.h>
#include <algorithm>
#include <vector>

#include "backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class RepeatElementsGpuKernel : public GpuKernel {
 public:
  RepeatElementsGpuKernel() : rep_(1), axis_(0), input_dim_(0), input_size_(1), output_size_(0) {}
  ~RepeatElementsGpuKernel() = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input_device_address = GetDeviceAddress<T>(inputs, 0);
    T *output_device_address = GetDeviceAddress<T>(outputs, 0);

    switch (input_dim_) {
      case 1:
        CalRepeatElements1d(input_device_address, rep_, axis_, output_device_address, output_size_,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
        break;
      case 2:
        CalRepeatElements2d(input_device_address, input_shape_[1], rep_, axis_, output_device_address,
                            output_shape_[1], output_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
        break;
      case 3:
        CalRepeatElements3d(input_device_address, input_shape_[1], input_shape_[2], rep_, axis_,
                            output_device_address, output_shape_[1], output_shape_[2], output_size_,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
        break;
      case 4:
        CalRepeatElements4d(input_device_address, input_shape_[1], input_shape_[2], input_shape_[3], rep_, axis_,
                            output_device_address, output_shape_[1], output_shape_[2], output_shape_[3], output_size_,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
        break;
      case 5:
        CalRepeatElements5d(input_device_address, input_shape_[1], input_shape_[2], input_shape_[3], input_shape_[4],
                            rep_, axis_, output_device_address, output_shape_[1], output_shape_[2], output_shape_[3],
                            output_shape_[4], output_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
        break;
      default:
        // Ranks above 5 take the generic path: the shapes and cumulative
        // products are staged in device workspaces before the launch.
        int *input_shape_device_address = GetDeviceAddress<int>(workspace, 0);
        int *output_shape_device_address = GetDeviceAddress<int>(workspace, 1);
        int *input_shape_cumulative_product_device_address = GetDeviceAddress<int>(workspace, 2);

        CHECK_CUDA_RET_WITH_EXCEPT(
          cudaMemcpyAsync(input_shape_device_address, input_shape_.data(), workspace_size_list_[0],
                          cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
          "cudaMemcpyAsync input_shape failed");
        CHECK_CUDA_RET_WITH_EXCEPT(
          cudaMemcpyAsync(output_shape_device_address, output_shape_.data(), workspace_size_list_[1],
                          cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
          "cudaMemcpyAsync output_shape failed");
        CHECK_CUDA_RET_WITH_EXCEPT(
          cudaMemcpyAsync(input_shape_cumulative_product_device_address, input_shape_cumulative_product_.data(),
                          workspace_size_list_[2], cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
          "cudaMemcpyAsync input_shape_cumulative_product failed");

        CalRepeatElements(input_device_address, input_dim_, input_shape_device_address,
                          input_shape_cumulative_product_device_address, rep_, axis_, output_device_address,
                          output_shape_device_address, output_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
        break;
    }

    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_count != 1) {
      MS_LOG(EXCEPTION) << input_count << " arguments were provided, but RepeatElementsGpuKernel expects 1.";
    }

    std::vector<size_t> temp_input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    input_dim_ = temp_input_shape.size();
    for (size_t e : temp_input_shape) {
      input_size_ *= e;
      input_shape_.push_back(e);
    }

    // Cache the row-major strides (cumulative products of the trailing
    // dimensions) used by the generic n-d kernel path.
    int cumulative_product = 1;
    for (size_t i = input_dim_ - 1; i > 0; i--) {
      cumulative_product *= input_shape_[i];
      input_shape_cumulative_product_.push_back(cumulative_product);
    }
    std::reverse(input_shape_cumulative_product_.begin(), input_shape_cumulative_product_.end());

    axis_ = GetAttr<int>(kernel_node, "axis");
    if (axis_ < 0) {
      axis_ += input_dim_;
    }
    rep_ = GetAttr<int>(kernel_node, "rep");
    output_size_ = input_size_ * rep_;
    output_shape_ = input_shape_;
    output_shape_[axis_] *= rep_;

    InitSizeLists();
    return true;
  }
 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_ * sizeof(T));
    output_size_list_.push_back(output_size_ * sizeof(T));

    // Workspaces for the input shape, output shape and cumulative products
    // (only used by the generic path for ranks above 5).
    workspace_size_list_.push_back(input_dim_ * sizeof(int));
    workspace_size_list_.push_back(input_dim_ * sizeof(int));
    workspace_size_list_.push_back((input_dim_ - 1) * sizeof(int));
  }
 private:
  int rep_;
  int axis_;
  int input_dim_;
  std::vector<int> input_shape_;
  std::vector<int> input_shape_cumulative_product_;
  std::vector<int> output_shape_;
  size_t input_size_;
  size_t output_size_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GPU_KERNEL_H_
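
The device code these branches call lives in repeat_elements_impl.cu, whose diff is suppressed below as too large. To make the generic n-d path concrete, here is a hedged host-side Python sketch of the same output-to-input index mapping, using the cumulative products that Init() caches; it is an illustration, not the committed CUDA kernel:

import numpy as np

def repeat_elements_ref(x, rep, axis):
    """Reference for the generic n-d path: for each flat output index,
    recover the output coordinates, integer-divide the coordinate along
    `axis` by `rep`, and re-flatten into the input."""
    output_shape = list(x.shape)
    output_shape[axis] *= rep
    out = np.empty(output_shape, dtype=x.dtype)

    # Row-major strides; in_strides[:-1] corresponds to
    # input_shape_cumulative_product_ (the final stride is always 1
    # and is not stored by Init()).
    in_strides = [int(np.prod(x.shape[d + 1:])) for d in range(x.ndim)]
    out_strides = [int(np.prod(output_shape[d + 1:])) for d in range(len(output_shape))]

    flat_in, flat_out = x.ravel(), out.ravel()
    for out_idx in range(flat_out.size):  # one CUDA thread per output element
        rem, in_idx = out_idx, 0
        for d in range(len(output_shape)):
            coord, rem = divmod(rem, out_strides[d])
            if d == axis:
                coord //= rep  # rep consecutive output slices share one input slice
            in_idx += coord * in_strides[d]
        flat_out[out_idx] = flat_in[in_idx]
    return out

For any shape this agrees with NumPy: np.array_equal(repeat_elements_ref(a, rep, axis), np.repeat(a, rep, axis=axis)).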

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_H_
#include <cuda_runtime.h>
#define REPEAT_ELEMENTS_MAX_INPUT_DIM 100
template <typename T>
void CalRepeatElements1d(const T *input, const int rep, const int axis, T *output, const int output_size,
                         cudaStream_t cuda_stream);

template <typename T>
void CalRepeatElements2d(const T *input, const int input_d1, const int rep, const int axis, T *output,
                         const int output_d1, const int output_size, cudaStream_t cuda_stream);

template <typename T>
void CalRepeatElements3d(const T *input, const int input_d1, const int input_d2, const int rep, const int axis,
                         T *output, const int output_d1, const int output_d2, const int output_size,
                         cudaStream_t cuda_stream);

template <typename T>
void CalRepeatElements4d(const T *input, const int input_d1, const int input_d2, const int input_d3, const int rep,
                         const int axis, T *output, const int output_d1, const int output_d2, const int output_d3,
                         const int output_size, cudaStream_t cuda_stream);

template <typename T>
void CalRepeatElements5d(const T *input, const int input_d1, const int input_d2, const int input_d3,
                         const int input_d4, const int rep, const int axis, T *output, const int output_d1,
                         const int output_d2, const int output_d3, const int output_d4, const int output_size,
                         cudaStream_t cuda_stream);

template <typename T>
void CalRepeatElements(const T *input, const int input_dim, const int *const input_shape,
                       const int *const input_shape_cumulative_product, const int rep, const int axis, T *output,
                       const int *const output_shape, const int output_size, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_H_
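
The five fixed-rank entry points exist so that shapes of rank 5 or less can be passed as scalar kernel arguments; only the generic CalRepeatElements needs the shape arrays that Launch stages through the three device workspaces. A sketch of the 2-d fast path's arithmetic (parameter names follow the declaration above; an illustration, not the .cu source):

def repeat_elements_2d_ref(x_flat, input_d1, rep, axis, output_d1, output_size):
    """Mirror of CalRepeatElements2d: with the trailing dimension passed
    as a plain scalar, no device-side shape arrays are required."""
    out = [0] * output_size
    for tid in range(output_size):  # tid plays the role of the CUDA thread index
        row, col = divmod(tid, output_d1)
        if axis == 0:
            row //= rep             # repeating whole rows
        else:
            col //= rep             # repeating elements within each row
        out[tid] = x_flat[row * input_d1 + col]
    return out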

@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
                        UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
                        Unique, GatherD, Identity)
                        Unique, GatherD, Identity, RepeatElements)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                       _MirrorOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice,
@@ -381,7 +381,8 @@ __all__ = [
    "Push",
    "Pull",
    "ReLUV2",
    'SparseToDense',
    "SparseToDense",
    "RepeatElements",
]

__all__.sort()

@@ -4022,3 +4022,52 @@ class Identity(PrimitiveWithInfer):
               'dtype': x['dtype'],
               'value': None}
        return out


class RepeatElements(PrimitiveWithInfer):
    """
    Repeat elements of a tensor along an axis, like np.repeat.

    Args:
        rep (int): The number of times to repeat, must be positive, required.
        axis (int): The axis along which to repeat, default 0.

    Inputs:
        - **x** (Tensor) - The tensor whose elements are repeated. Must be of type int32 or float16.

    Outputs:
        One tensor with values repeated along the specified axis. If x has shape
        (s1, s2, ..., sn) and axis is i, the output will have shape (s1, s2, ..., si * rep, ..., sn).

    Examples:
        >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
        >>> repeat_elements = P.RepeatElements(rep=2, axis=0)
        >>> output = repeat_elements(x)
        [[0, 1, 2],
         [0, 1, 2],
         [3, 4, 5],
         [3, 4, 5]]
    """

    @prim_attr_register
    def __init__(self, rep, axis=0):
        self.init_prim_io_names(inputs=["x"], outputs=["output"])
        validator.check_value_type("rep", rep, [int], self.name)
        self.rep = rep
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

    def infer_shape(self, x_shape):
        validator.check("rep", self.rep, "", 0, Rel.GT, self.name)
        validator.check("axis", self.axis, "dimension of x", len(x_shape), Rel.LT, self.name)
        validator.check("axis", self.axis, "negative dimension of x", -len(x_shape), Rel.GE, self.name)

        x_shape[self.axis] *= self.rep
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
        return x_dtype
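
Since the docstring frames the op as np.repeat along an axis, a quick cross-check against NumPy (a sketch, assuming a PyNative GPU context) would be:

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

x_np = np.arange(24, dtype=np.int32).reshape(2, 3, 4)
output = P.RepeatElements(rep=3, axis=1)(Tensor(x_np))

# infer_shape multiplies the repeated axis by rep: (2, 3, 4) -> (2, 9, 4).
expected = np.repeat(x_np, 3, axis=1)
assert np.array_equal(output.asnumpy(), expected)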

File diff suppressed because it is too large