From 6c8f275bdd039f3ead3f263c9ee3dd3330ab3429 Mon Sep 17 00:00:00 2001
From: Jonathan Yan
Date: Thu, 7 Jan 2021 20:28:23 -0500
Subject: [PATCH] reduction v1

Replace the dedicated ArgMaxWithValue CUDA kernel with a general reduction
implementation that selects a thread-, warp-, warp-group-, or block-level
kernel based on the length of the reduced dimension, and extend the
ArgMaxWithValue GPU tests to cover float16 and 3-D inputs.
---
 .../gpu/arrays/argmaxwithvalue_gpu_kernel.h   |   6 +-
 .../gpu/cuda_impl/argmaxwithvalue_impl.cu     |  55 ---
 .../gpu/cuda_impl/general_reduction_impl.cu   | 321 ++++++++++++++++++
 ...ue_impl.cuh => general_reduction_impl.cuh} |  10 +-
 tests/st/ops/gpu/test_argmaxwithvalue_op.py   |  96 +++++-
 5 files changed, 416 insertions(+), 72 deletions(-)
 delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu
 rename mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/{argmaxwithvalue_impl.cuh => general_reduction_impl.cuh} (61%)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h
index dd0b6f91a0..2862715508 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h
@@ -20,7 +20,7 @@
 #include <vector>
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
-#include "backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh"
 namespace mindspore {
 namespace kernel {
 template <typename T, typename S>
@@ -38,8 +38,8 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
     T *input = GetDeviceAddress<T>(inputs, 0);
     T *output = GetDeviceAddress<T>(outputs, 1);
     S *index = GetDeviceAddress<S>(outputs, 0);
-    CalArgmaxWithValue(input, bound_, outerSize_, innerSize_, index, output,
-                       reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalGeneralReduction(false, input, bound_, outerSize_, innerSize_, index, output,
+                        reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
deleted file mode 100644
index 8bafcecf1e..0000000000
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
+++ /dev/null
@@ -1,55 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "argmaxwithvalue_impl.cuh"
-#include "runtime/device/gpu/cuda_common.h"
-#include "include/cuda_fp16.h"
-template <typename T, typename S>
-__global__ void ArgmaxWithValue(const T *input, const S bound, size_t outerSize,
-                                size_t innerSize, S *index, T *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outerSize * innerSize;
-       pos += gridDim.x * blockDim.x) {
-    size_t x = pos / innerSize % outerSize;
-    size_t y = pos % innerSize;
-    S idx = 0;
-    size_t InputOffset = x * bound * innerSize + 0 * innerSize + y;
-    T maxData = input[InputOffset];
-    for (S i = 0; i < bound; i++) {
-      InputOffset = x * bound * innerSize + i * innerSize + y;
-      auto inputData = input[InputOffset];
-      idx = inputData > maxData ? i : idx;
-      maxData = inputData > maxData ? inputData : maxData;
-    }
-    output[pos] = maxData;
-    index[pos] = idx;
-  }
-  return;
-}
-
-template <typename T, typename S>
-void CalArgmaxWithValue(const T *input, const S bound_, const size_t outerSize_, const size_t innerSize_,
-                        S *index, T *output, cudaStream_t cuda_stream) {
-  ArgmaxWithValue<<<GET_BLOCKS(outerSize_ * innerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_,
-                                                                                        innerSize_, index, output);
-  return;
-}
-
-template void CalArgmaxWithValue(const float *input, const int bound_, const size_t outerSize_,
-                                 const size_t innerSize_, int *index, float *output,
-                                 cudaStream_t cuda_stream);
-template void CalArgmaxWithValue(const half *input, const int bound_, const size_t outerSize_,
-                                 const size_t innerSize_, int *index, half *output,
-                                 cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu
new file mode 100644
index 0000000000..1444b2f958
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cu
@@ -0,0 +1,321 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <limits>
+#include <type_traits>
+#include "runtime/device/gpu/cuda_common.h"
+#include "include/cuda_fp16.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh"
+
+const int kWarpSize = 32;
+const int kBlockSize = 512;
+const int kWarpGroup = 4;
+const int kNumWarps = kBlockSize / kWarpSize;   // 16
+const int kGroupSize = kWarpGroup * kWarpSize;  // 128
+
+// Mode selection constants: thresholds on the reduced dimension for kernel dispatch.
+const int kMaxThreadLoop = 4;
+const int kMaxWarpLoop = kWarpSize * 3;    // 32 * 3 = 96
+const int kMaxGroupLoop = kGroupSize * 3;  // 128 * 3 = 384
+
+template <typename T>
+struct Cmp {
+  __device__ static inline bool lt(T a, T b) { return a <= b; }
+  __device__ static inline bool gt(T a, T b) { return a >= b; }
+};
+
+template <typename T>
+inline __device__ void ConditionAssign(bool is_assign, T *x, const T &y) {
+  (*x) = is_assign ? y : (*x);
+}
+
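+// ThreadReduction: for a small reduced dimension (bound), each thread owns one
+// (outer, inner) position and walks the reduced axis serially, keeping the
+// winning value and its index.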
+template <typename T, typename S>
+__global__ void ThreadReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
+                                T *output, S *output_index, bool fp16_flag, T init_K) {
+  if (fp16_flag) {
+    init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
+  }
+
+  const S init_V = static_cast<S>(-1);
+
+  for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < outer_size * inner_size;
+       t_idx += blockDim.x * gridDim.x) {
+    int outer_id = t_idx / inner_size;
+    int inner_id = t_idx % inner_size;
+
+    T threadK = init_K;
+    S threadV = init_V;
+
+    for (int i = 0; i < bound; i++) {
+      T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
+      S other_V = i;
+      bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
+      ConditionAssign(is_winner, &threadK, other_K);
+      ConditionAssign(is_winner, &threadV, other_V);
+    }
+
+    output[outer_id * inner_size + inner_id] = threadK;
+    output_index[outer_id * inner_size + inner_id] = threadV;
+  }
+}
+
+// WarpReduction: one warp cooperates on one (outer, inner) position; each lane reduces
+// a strided slice of the axis and the per-lane winners are merged with warp shuffles.
+template <typename T, typename S>
+__global__ void WarpReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output,
+                              S *output_index, bool fp16_flag, T init_K) {
+  if (fp16_flag) {
+    init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
+  }
+  const S init_V = static_cast<S>(-1);
+
+  for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kWarpSize * outer_size * inner_size;
+       t_idx += blockDim.x * gridDim.x) {
+    int outer_id = t_idx / kWarpSize / inner_size;
+    int inner_id = t_idx / kWarpSize % inner_size;
+
+    int laneId = threadIdx.x % kWarpSize;
+
+    T threadK = init_K;
+    S threadV = init_V;
+
+    for (int i = laneId; i < bound; i += kWarpSize) {
+      T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
+      S other_V = i;
+      bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
+      ConditionAssign(is_winner, &threadK, other_K);
+      ConditionAssign(is_winner, &threadV, other_V);
+    }
+    __syncwarp();
+
+    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
+      T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
+      S other_V = __shfl_down_sync(0xffffffff, threadV, offset);
+
+      bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
+      ConditionAssign(is_winner, &threadK, other_K);
+      ConditionAssign(is_winner, &threadV, other_V);
+    }
+
+    __syncwarp();
+
+    if (laneId == 0) {
+      output[outer_id * inner_size + inner_id] = threadK;
+      output_index[outer_id * inner_size + inner_id] = threadV;
+    }
+    __syncthreads();
+  }
+}
+
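+// Warp4Reduction: a group of four warps (kGroupSize = 128 threads) cooperates on one
+// (outer, inner) position. Each warp reduces its strided slice with shuffles, the four
+// per-warp winners are staged in shared memory, and the group's first thread merges them.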
+template <typename T, typename S>
+__global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
+                               T *output, S *output_index, bool fp16_flag, T init_K) {
+  __shared__ T shared_K[kNumWarps];
+  __shared__ S shared_V[kNumWarps];
+  if (fp16_flag) {
+    init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
+  }
+  const S init_V = static_cast<S>(-1);
+
+  for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kGroupSize * outer_size * inner_size;
+       t_idx += blockDim.x * gridDim.x) {
+    int outer_id = t_idx / kGroupSize / inner_size;
+    int inner_id = t_idx / kGroupSize % inner_size;
+
+    int groupId = threadIdx.x / kGroupSize;
+    int tgId = threadIdx.x % kGroupSize;
+    int warpId = threadIdx.x / kWarpSize;
+    int laneId = threadIdx.x % kWarpSize;
+
+    T threadK = init_K;
+    S threadV = init_V;
+
+    if (laneId == 0) {
+      shared_K[warpId] = init_K;
+      shared_V[warpId] = init_V;
+    }
+    __syncthreads();
+
+    for (int i = tgId; i < bound; i += kGroupSize) {
+      T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
+      S other_V = i;
+      bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
+      ConditionAssign(is_winner, &threadK, other_K);
+      ConditionAssign(is_winner, &threadV, other_V);
+    }
+    __syncwarp();
+
+    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
+      T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
+      S other_V = __shfl_down_sync(0xffffffff, threadV, offset);
+
+      bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
+      ConditionAssign(is_winner, &threadK, other_K);
+      ConditionAssign(is_winner, &threadV, other_V);
+    }
+
+    __syncwarp();
+
+    if (laneId == 0) {
+      shared_K[warpId] = threadK;
+      shared_V[warpId] = threadV;
+    }
+    __syncthreads();
+
+    // Tree-reduce the four per-warp winners held in shared memory.
+    if (tgId < 2) {
+      bool is_winner =
+        small ? Cmp<T>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2])
+              : Cmp<T>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2]);
+      ConditionAssign(is_winner, (shared_K + (groupId * kWarpGroup) + tgId),
+                      (shared_K[(groupId * kWarpGroup) + tgId + 2]));
+      ConditionAssign(is_winner, (shared_V + (groupId * kWarpGroup) + tgId),
+                      (shared_V[(groupId * kWarpGroup) + tgId + 2]));
+    }
+    __syncwarp();
+
+    if (tgId == 0) {
+      bool is_winner =
+        small ? Cmp<T>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1])
+              : Cmp<T>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1]);
+      ConditionAssign(is_winner, (shared_K + (groupId * kWarpGroup) + tgId),
+                      (shared_K[(groupId * kWarpGroup) + tgId + 1]));
+      ConditionAssign(is_winner, (shared_V + (groupId * kWarpGroup) + tgId),
+                      (shared_V[(groupId * kWarpGroup) + tgId + 1]));
+
+      // The first thread of each group writes the output.
+      output[outer_id * inner_size + inner_id] = shared_K[groupId * kWarpGroup];
+      output_index[outer_id * inner_size + inner_id] = shared_V[groupId * kWarpGroup];
+    }
+    __syncthreads();
+  }
+}
+
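+// BlockReduction: the whole 512-thread block cooperates on one (outer, inner) position.
+// Each warp reduces its slice with shuffles, the 16 per-warp winners are staged in shared
+// memory, and warp 0 performs the final reduction before lane 0 writes the result.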
+template <typename T, typename S>
+__global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
+                               T *output, S *output_index, bool fp16_flag, T init_K) {
+  __shared__ T shared_K[kNumWarps];
+  __shared__ S shared_V[kNumWarps];
+  if (fp16_flag) {
+    init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
+  }
+  const S init_V = static_cast<S>(-1);
+
+  for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kBlockSize * outer_size * inner_size;
+       t_idx += blockDim.x * gridDim.x) {
+    int outer_id = t_idx / kBlockSize / inner_size;
+    int inner_id = t_idx / kBlockSize % inner_size;
+
+    int tgId = threadIdx.x % kBlockSize;
+    int warpId = threadIdx.x / kWarpSize;
+    int laneId = threadIdx.x % kWarpSize;
+
+    T threadK = init_K;
+    S threadV = init_V;
+
+    if (laneId == 0) {
+      shared_K[warpId] = init_K;
+      shared_V[warpId] = init_V;
+    }
+    __syncthreads();
+
+    for (int i = tgId; i < bound; i += kBlockSize) {
+      T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
+      S other_V = i;
+      bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
+      ConditionAssign(is_winner, &threadK, other_K);
+      ConditionAssign(is_winner, &threadV, other_V);
+    }
+    __syncwarp();
+
+    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
+      T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
+      S other_V = __shfl_down_sync(0xffffffff, threadV, offset);
+
+      bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
+      ConditionAssign(is_winner, &threadK, other_K);
+      ConditionAssign(is_winner, &threadV, other_V);
+    }
+
+    __syncwarp();
+
+    if (laneId == 0) {
+      shared_K[warpId] = threadK;
+      shared_V[warpId] = threadV;
+    }
+    __syncthreads();
+
+    // Shared memory reduction
+    // There are 16 per-warp results in shared memory, so one warp can finish the reduction.
+    if (warpId == 0) {
+      threadK = laneId < kNumWarps ? shared_K[laneId] : init_K;
+      threadV = laneId < kNumWarps ? shared_V[laneId] : init_V;
+    }
+    __syncwarp();
+
+    if (warpId == 0) {
+      for (int offset = kWarpSize / 4; offset > 0; offset /= 2) {
+        T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
+        S other_V = __shfl_down_sync(0xffffffff, threadV, offset);
+
+        bool is_winner = small ? Cmp<T>::gt(threadK, other_K) : Cmp<T>::lt(threadK, other_K);
+        ConditionAssign(is_winner, &threadK, other_K);
+        ConditionAssign(is_winner, &threadV, other_V);
+      }
+    }
+    __syncwarp();
+
+    if (warpId == 0 && laneId == 0) {
+      output[outer_id * inner_size + inner_id] = threadK;
+      output_index[outer_id * inner_size + inner_id] = threadV;
+    }
+  }
+}
+
+template <typename T, typename S>
+void GeneralReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output,
+                      S *output_index, cudaStream_t stream) {
+  int block_num_limit = outer_size * inner_size;
+  bool fp16_flag = false;
+  if (std::is_same<T, half>::value) {
+    fp16_flag = true;
+  }
+  T init_K = small ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();
+
+  // Pick the cheapest kernel that still covers the reduced dimension.
+  if (bound <= kMaxThreadLoop) {
+    ThreadReduction<T, S><<<GET_BLOCKS(block_num_limit), kBlockSize, 0, stream>>>(
+      small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
+  } else if (bound <= kMaxWarpLoop) {
+    WarpReduction<T, S><<<GET_BLOCKS(block_num_limit * kWarpSize), kBlockSize, 0, stream>>>(
+      small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
+  } else if (bound <= kMaxGroupLoop) {
+    Warp4Reduction<T, S><<<GET_BLOCKS(block_num_limit * kGroupSize), kBlockSize, 0, stream>>>(
+      small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
+  } else {
+    BlockReduction<T, S><<<GET_BLOCKS(block_num_limit * kBlockSize), kBlockSize, 0, stream>>>(
+      small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
+  }
+}
+
+template <typename T, typename S>
+void CalGeneralReduction(bool small, const T *input, const size_t bound, const size_t outerSize, const size_t innerSize,
+                         S *index, T *output, cudaStream_t cuda_stream) {
+  GeneralReduction(small, outerSize, bound, innerSize, input, output, index, cuda_stream);
+  return;
+}
+
+template void CalGeneralReduction(bool small, const float *input, const size_t bound_, const size_t outerSize_,
+                                  const size_t innerSize_, int *index, float *output, cudaStream_t cuda_stream);
+template void CalGeneralReduction(bool small, const half *input, const size_t bound_, const size_t outerSize_,
+                                  const size_t innerSize_, int *index, half *output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh
similarity index 61%
rename from mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh
index 2a08365f20..b09cf08e4c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh
@@ -14,9 +14,9 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GENERAL_REDUCTION_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GENERAL_REDUCTION_H_
 template <typename T, typename S>
-void CalArgmaxWithValue(const T *input, const S bound_, const size_t outerSize_, const size_t innerSize_, S *index,
-                        T *output, cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ARGMAXWITHVALUE_H_
+void CalGeneralReduction(bool small, const T *input, const size_t bound_, const size_t outerSize_,
+                         const size_t innerSize_, S *index, T *output, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GENERAL_REDUCTION_H_
diff --git a/tests/st/ops/gpu/test_argmaxwithvalue_op.py b/tests/st/ops/gpu/test_argmaxwithvalue_op.py
index 6ce729a6cb..a3db5dbdda 100644
--- a/tests/st/ops/gpu/test_argmaxwithvalue_op.py
+++ b/tests/st/ops/gpu/test_argmaxwithvalue_op.py
@@ -35,18 +35,24 @@ class NetArgmaxWithValue(nn.Cell):
         return (self.argmax1(x), self.argmax2(x), self.argmax3(x))
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_argmaxwithvalue():
+class NetArgmaxWithValueBig(nn.Cell):
+    def __init__(self, axis=0):
+        super(NetArgmaxWithValueBig, self).__init__()
+        self.argmax = P.ArgMaxWithValue(axis)
+
+    def construct(self, x):
+        return self.argmax(x)
+
+
+def argmaxwithvalue_base(data_type):
     x = Tensor(np.array([[1., 20., 5.],
                          [67., 8., 9.],
                          [130., 24., 15.],
-                         [0.3, -0.4, -15.]]).astype(np.float32))
-    expect1 = np.array([2, 2, 2]).astype(np.float32)
-    expect2 = np.array([1, 0, 0, 0]).astype(np.float32)
-    expect11 = np.array([130, 24, 15]).astype(np.float32)
-    expect22 = np.array([20, 67, 130, 0.3]).astype(np.float32)
+                         [0.3, -0.4, -15.]]).astype(data_type))
+    expect1 = np.array([2, 2, 2]).astype(data_type)
+    expect2 = np.array([1, 0, 0, 0]).astype(data_type)
+    expect11 = np.array([130, 24, 15]).astype(data_type)
+    expect22 = np.array([20, 67, 130, 0.3]).astype(data_type)
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     argmax = NetArgmaxWithValue()
     output = argmax(x)
@@ -66,3 +72,75 @@
     assert (output[1][1].asnumpy() == expect22).all()
     assert (output[2][0].asnumpy() == expect1).all()
     assert (output[2][1].asnumpy() == expect11).all()
+
+
+def argmaxwithvalue_3d(data_type, shape_x):
+    np.random.seed(876)
+    x_np = np.random.random(shape_x).astype(data_type)
+    x = Tensor(x_np)
+
+    argmax = NetArgmaxWithValueBig(0)
+    output = argmax(x)
+    expect1 = np.argmax(x_np, axis=0)
+    expect2 = np.maximum.reduce(x_np, 0)
+    assert (output[0].asnumpy() == expect1).all()
+    assert (output[1].asnumpy() == expect2).all()
+
+    argmax = NetArgmaxWithValueBig(1)
+    output = argmax(x)
+    expect1 = np.argmax(x_np, axis=1)
+    expect2 = np.maximum.reduce(x_np, 1)
+    assert (output[0].asnumpy() == expect1).all()
+    assert (output[1].asnumpy() == expect2).all()
+
+    argmax = NetArgmaxWithValueBig(2)
+    output = argmax(x)
+    expect1 = np.argmax(x_np, axis=2)
+    expect2 = np.maximum.reduce(x_np, 2)
+    assert (output[0].asnumpy() == expect1).all()
+    assert (output[1].asnumpy() == expect2).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_argmaxwithvalue_base_float32():
+    argmaxwithvalue_base(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_argmaxwithvalue_base_float16():
+    argmaxwithvalue_base(np.float16)
+
+
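+# The 3-D cases below use reduced-axis lengths from 2 up to 1024 so that the thread,
+# warp, warp-group, and block paths of the new GPU reduction kernel are all exercised.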
+@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_argmaxwithvalue_3d_float32(): + shape_x = (2, 32, 256) + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + argmaxwithvalue_3d(np.float32, shape_x) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + argmaxwithvalue_3d(np.float32, shape_x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_argmaxwithvalue_3d_float16(): + shape_x = (2, 32, 16) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + argmaxwithvalue_3d(np.float16, shape_x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_argmaxwithvalue_3d_big_float32(): + shape_x = (128, 1024, 1) + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + argmaxwithvalue_3d(np.float32, shape_x) + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + argmaxwithvalue_3d(np.float32, shape_x)