From 46afb18e2586bf61b311d9c26e2b4514ace9d34c Mon Sep 17 00:00:00 2001 From: VectorSL Date: Tue, 16 Jun 2020 14:56:59 +0800 Subject: [PATCH] update argmaxwithvalue --- .../gpu/arrays/argmaxwithvalue_gpu_kernel.h | 39 +++++++------------ .../gpu/cuda_impl/argmaxwithvalue_impl.cu | 6 +-- .../gpu/cuda_impl/argmaxwithvalue_impl.cuh | 2 +- 3 files changed, 17 insertions(+), 30 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h index 9d42f31eb6..fb7796b022 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h @@ -26,15 +26,7 @@ namespace kernel { template class ArgmaxWithValueGpuKernel : public GpuKernel { public: - ArgmaxWithValueGpuKernel() - : input_size_(0), - output_size_(0), - workspace_size_(0), - axis_(0), - dims_(1), - bound_(0), - outerSize_(0), - innerSize_(0) {} + ArgmaxWithValueGpuKernel() : input_size_(0), output_size_(0), bound_(0), outerSize_(0), innerSize_(0) {} ~ArgmaxWithValueGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -46,37 +38,36 @@ class ArgmaxWithValueGpuKernel : public GpuKernel { T *input = GetDeviceAddress(inputs, 0); T *output = GetDeviceAddress(outputs, 1); S *index = GetDeviceAddress(outputs, 0); - CalArgmaxWithValue(input_size_ / sizeof(T), input, bound_, outerSize_, innerSize_, axis_, dims_, index, output, + CalArgmaxWithValue(input_size_ / sizeof(T), input, bound_, outerSize_, innerSize_, index, output, reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { - shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + std::vector shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1); - dims_ = shape_.size(); - - axis_ = GetAttr(kernel_node, "axis"); - if (axis_ < 0) { - axis_ += dims_; + int dims = shape.size(); + int axis = GetAttr(kernel_node, "axis"); + if (axis < 0) { + axis += dims; } input_size_ = sizeof(T); - for (auto x : shape_) { + for (auto x : shape) { input_size_ *= x; } output_size_ = sizeof(S); for (auto x : output_shape) { output_size_ *= x; } - bound_ = shape_[axis_]; + bound_ = shape[axis]; outerSize_ = 1; - for (int i = axis_ - 1; i >= 0; i--) { - outerSize_ *= shape_[i]; + for (int i = axis - 1; i >= 0; i--) { + outerSize_ *= shape[i]; } innerSize_ = 1; - for (int i = axis_ + 1; i < dims_; i++) { - innerSize_ *= shape_[i]; + for (int i = axis + 1; i < dims; i++) { + innerSize_ *= shape[i]; } InitSizeLists(); return true; @@ -92,13 +83,9 @@ class ArgmaxWithValueGpuKernel : public GpuKernel { private: size_t input_size_; size_t output_size_; - size_t workspace_size_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; - std::vector shape_; - int axis_; - int dims_; int bound_; int outerSize_; int innerSize_; diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu index 47a794cdcd..a0687a2768 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu @@ -44,15 +44,15 @@ __global__ void ArgmaxWithValue(size_t size, const T* input, const int bound, in template void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_, - int axis_, int dims_, S* index, T* output, cudaStream_t cuda_stream) { + S* index, T* output, cudaStream_t cuda_stream) { ArgmaxWithValue<<>>(size, input, bound_, outerSize_, innerSize_, index, output); return; } template void CalArgmaxWithValue(size_t size, const float* input, const int bound_, const int outerSize_, - const int innerSize_, int axis_, int dims_, int* index, float* output, + const int innerSize_, int* index, float* output, cudaStream_t cuda_stream); template void CalArgmaxWithValue(size_t size, const half* input, const int bound_, const int outerSize_, - const int innerSize_, int axis_, int dims_, int* index, half* output, + const int innerSize_, int* index, half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh index eebe4c8fa6..0d4f4b62a3 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh @@ -18,5 +18,5 @@ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_ template void CalArgmaxWithValue(size_t size, const T* input, const int bound_, const int outerSize_, const int innerSize_, - int axis_, int dims_, S* index, T* output, cudaStream_t cuda_stream); + S* index, T* output, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_