!11089 General reduction with hybrid mode

From: @jonwe
Reviewed-by: @robingrosman,@tom__chen,@robingrosman,@tom__chen
Signed-off-by: @tom__chen
pull/11089/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit f07bb3bd04

@ -20,7 +20,7 @@
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
@ -38,8 +38,8 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 1);
S *index = GetDeviceAddress<S>(outputs, 0);
CalArgmaxWithValue(input, bound_, outerSize_, innerSize_, index, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalGeneralReduction(false, input, bound_, outerSize_, innerSize_, index, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

@ -1,55 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "argmaxwithvalue_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
template <typename T, typename S>
__global__ void ArgmaxWithValue(const T *input, const S bound, size_t outerSize,
size_t innerSize, S *index, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize;
pos += gridDim.x * blockDim.x) {
size_t x = pos / innerSize % outerSize;
size_t y = pos % innerSize;
S idx = 0;
size_t InputOffset = x * bound * innerSize + 0 * innerSize + y;
T maxData = input[InputOffset];
for (S i = 0; i < bound; i++) {
InputOffset = x * bound * innerSize + i * innerSize + y;
auto inputData = input[InputOffset];
idx = inputData > maxData ? i : idx;
maxData = inputData > maxData ? inputData : maxData;
}
output[pos] = maxData;
index[pos] = idx;
}
return;
}
template <typename T, typename S>
void CalArgmaxWithValue(const T *input, const S bound_, const size_t outerSize_, const size_t innerSize_,
S *index, T *output, cudaStream_t cuda_stream) {
ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_,
index, output);
return;
}
template void CalArgmaxWithValue<float, int>(const float *input, const int bound_, const size_t outerSize_,
const size_t innerSize_, int *index, float *output,
cudaStream_t cuda_stream);
template void CalArgmaxWithValue<half, int>(const half *input, const int bound_, const size_t outerSize_,
const size_t innerSize_, int *index, half *output,
cudaStream_t cuda_stream);

@ -14,9 +14,9 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GENERAL_REDUCTION_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GENERAL_REDUCTION_H_
template <typename T, typename S>
void CalArgmaxWithValue(const T *input, const S bound_, const size_t outerSize_, const size_t innerSize_, S *index,
T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
void CalGeneralReduction(bool small, const T *input, const size_t bound_, const size_t outerSize_,
const size_t innerSize_, S *index, T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GENERAL_REDUCTION_H_

@ -35,18 +35,24 @@ class NetArgmaxWithValue(nn.Cell):
return (self.argmax1(x), self.argmax2(x), self.argmax3(x))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue():
class NetArgmaxWithValueBig(nn.Cell):
def __init__(self, axis=0):
super(NetArgmaxWithValueBig, self).__init__()
self.argmax = P.ArgMaxWithValue(axis)
def construct(self, x):
return self.argmax(x)
def argmaxwithvalue_base(data_type):
x = Tensor(np.array([[1., 20., 5.],
[67., 8., 9.],
[130., 24., 15.],
[0.3, -0.4, -15.]]).astype(np.float32))
expect1 = np.array([2, 2, 2]).astype(np.float32)
expect2 = np.array([1, 0, 0, 0]).astype(np.float32)
expect11 = np.array([130, 24, 15]).astype(np.float32)
expect22 = np.array([20, 67, 130, 0.3]).astype(np.float32)
[0.3, -0.4, -15.]]).astype(data_type))
expect1 = np.array([2, 2, 2]).astype(data_type)
expect2 = np.array([1, 0, 0, 0]).astype(data_type)
expect11 = np.array([130, 24, 15]).astype(data_type)
expect22 = np.array([20, 67, 130, 0.3]).astype(data_type)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
argmax = NetArgmaxWithValue()
output = argmax(x)
@ -66,3 +72,75 @@ def test_argmaxwithvalue():
assert (output[1][1].asnumpy() == expect22).all()
assert (output[2][0].asnumpy() == expect1).all()
assert (output[2][1].asnumpy() == expect11).all()
def argmaxwithvalue_3d(data_type, shape_x):
np.random.seed(876)
x_np = np.random.random(shape_x).astype(data_type)
x = Tensor(x_np)
argmax = NetArgmaxWithValueBig(0)
output = argmax(x)
expect1 = np.argmax(x_np, axis=0)
expect2 = np.maximum.reduce(x_np, 0)
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
argmax = NetArgmaxWithValueBig(1)
output = argmax(x)
expect1 = np.argmax(x_np, axis=1)
expect2 = np.maximum.reduce(x_np, 1)
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
argmax = NetArgmaxWithValueBig(2)
output = argmax(x)
expect1 = np.argmax(x_np, axis=2)
expect2 = np.maximum.reduce(x_np, 2)
assert (output[0].asnumpy() == expect1).all()
assert (output[1].asnumpy() == expect2).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_base_float32():
argmaxwithvalue_base(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_base_float16():
argmaxwithvalue_base(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_float32():
shape_x = (2, 32, 256)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
argmaxwithvalue_3d(np.float32, shape_x)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
argmaxwithvalue_3d(np.float32, shape_x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_float16():
shape_x = (2, 32, 16)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
argmaxwithvalue_3d(np.float16, shape_x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_argmaxwithvalue_3d_big_float32():
shape_x = (128, 1024, 1)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
argmaxwithvalue_3d(np.float32, shape_x)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
argmaxwithvalue_3d(np.float32, shape_x)

Loading…
Cancel
Save