From 5bcf605b4569410913d445cb4b761d3a9d51a5ea Mon Sep 17 00:00:00 2001
From: wandongdong
Date: Sun, 17 Jan 2021 22:23:53 -0800
Subject: [PATCH] add int data support for opencl

---
 .../src/runtime/kernel/opencl/cl/argminmax.cl | 11 ++--
 .../runtime/kernel/opencl/kernel/argminmax.cc | 40 ++++++++-------
 .../runtime/kernel/opencl/kernel/argminmax.h  |  1 +
 .../kernel/opencl/kernel/depthwise_conv2d.cc  | 11 ++--
 .../runtime/kernel/opencl/opencl_fusion.cc    |  3 +-
 .../runtime/kernel/opencl/opencl_subgraph.cc  | 45 +++++++++++------
 .../runtime/kernel/opencl/opencl_subgraph.h   |  8 ++-
 .../lite/src/runtime/kernel/opencl/utils.cc   | 18 +++++++
 .../lite/src/runtime/kernel/opencl/utils.h    |  2 +
 .../runtime/kernel/opencl/argminmax_tests.cc  | 50 ++++++++++++++++++-
 10 files changed, 142 insertions(+), 47 deletions(-)

diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/argminmax.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/argminmax.cl
index 743eabe5bd..76e566f322 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/argminmax.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/argminmax.cl
@@ -11,23 +11,28 @@
 __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
 __kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global FLT *buf, __global int *ids,
                         int4 shape, int4 src_size, int4 cus_size, int4 strides, int4 flags) {
-  int X = get_global_id(0);  // reduce len
+  int X = get_global_id(0);  // lower reduce stride
   int Y = get_global_id(1);  // upper axis accumulation
   if (X >= src_size.x || Y >= src_size.y) {
     return;
   }
   int offset = X + Y * src_size.z;
-  int align_c4 = (flags.z != 3) ? (X / shape.w) * (shape.x) : 0;
+  int align_c4 = (flags.z != 3) ? (X / shape.w) * (C4NUM - shape.w & 0x00000003) : 0;
   int align_in = 0;
   int align_out = 0;
+  bool keep_dims = cus_size.y;
+  int width = shape.z * shape.w;
   if (flags.z == 3) {
     align_in = (Y / shape.z) * cus_size.z;
     align_out = (Y / shape.z) * cus_size.w;
   }
   if (flags.z == 0) {
-    align_in = X / (shape.y) * cus_size.z;
+    align_in = X / (width)*cus_size.z;
     align_out = align_in;
   }
+  if (flags.z == 2 && !keep_dims) {
+    align_out = (Y / shape.y) * cus_size.w;
+  }
   for (int k = 0; k < src_size.w; ++k) {
     int idx0 = (X + k * strides.x) + Y * strides.y + (align_c4 + align_in);
     int idx1 = offset + k * src_size.x;
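The new `align_in`/`align_out` terms above compensate for OpenCL image row pitch: a flat index that has crossed full image rows must skip the padding tail of every row, and `cus_size.z`/`cus_size.w` carry exactly that per-row padding for the input and output images. A minimal host-side sketch of the arithmetic (illustrative names and values, not code from this patch):

```cpp
// Sketch: OpenCL image rows are padded out to the device row pitch, so a
// dense buffer index must add (pitch - valid) elements per row crossed.
#include <cstdio>

// idx: dense index; valid: payload elements per row; pitch_elems: padded row size
size_t padded_offset(size_t idx, size_t valid, size_t pitch_elems) {
  size_t rows = idx / valid;                 // full rows already crossed
  return idx + rows * (pitch_elems - valid); // mirrors align_in = X / width * cus_size.z
}

int main() {
  // e.g. 42 valid floats per row, pitch padded to 64 floats
  printf("%zu\n", padded_offset(100, 42, 64));  // 100 + 2 * 22 = 144
  return 0;
}
```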
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc
index 5e8e0a95de..8614f9b2b9 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc
@@ -61,8 +61,6 @@ void ArgMinMaxOpenCLKernel::SetConstArgs() {
   auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
   cl_int4 in_shape{static_cast<cl_int>(im_in_.N), static_cast<cl_int>(im_in_.H), static_cast<cl_int>(im_in_.W),
                    static_cast<cl_int>(im_in_.C)};
-  in_shape.s[0] = UP_ROUND(im_in_.C, C4NUM) - im_in_.C;
-  in_shape.s[1] = im_in_.W * im_in_.C;
   cl_int4 flags = {param->out_value_, param->get_max_, param->axis_, param->topk_};
   int arg_cnt = 2;
   ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, lite::opencl::MemType::BUF);
@@ -77,17 +75,20 @@
 void ArgMinMaxOpenCLKernel::SetGlobalLocal() {
   auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
   im_in_ = GpuTensorInfo(in_tensors_[0]);
+  im_out_ = GpuTensorInfo(out_tensors_[0]);
   std::vector<size_t> in_shape = {im_in_.N, im_in_.H, im_in_.W, im_in_.C};
   auto in_shape_align = in_shape;
   in_shape_align[3] = UP_ROUND(in_shape[3], C4NUM);
-  auto out_shape_align = in_shape_align;
-  out_shape_align.at(param->axis_) = param->axis_ == 3 ? UP_ROUND(param->topk_, C4NUM) : param->topk_;
+  std::vector<size_t> out_shape = {im_out_.N, im_out_.H, im_out_.W, im_out_.C};
+  auto out_shape_align = out_shape;
+  out_shape_align[3] = UP_ROUND(out_shape[3], C4NUM);
   int reduce_len = GetUpPow2(in_shape.at(param->axis_));
   int dtype_size = in_tensors_[0]->data_type() == kNumberTypeFloat16 ? sizeof(int16_t) : sizeof(float);
-  cus_size_ = {reduce_len, static_cast<int>(im_in_.RowPitch() / dtype_size), 1, 1};
-  cus_size_.s[2] = UP_ROUND(im_in_.width * C4NUM, cus_size_.s[1]) - im_in_.width * C4NUM;
-  cus_size_.s[3] = im_in_.W * UP_ROUND(param->topk_, C4NUM);
-  cus_size_.s[3] = UP_ROUND(cus_size_.s[3], cus_size_.s[1]) - cus_size_.s[3];
+  int in_pitch = im_in_.RowPitch() / dtype_size;
+  int out_pitch = im_out_.RowPitch() / dtype_size;
+  cus_size_ = {reduce_len, param->keep_dims_, 1, 1};
+  cus_size_.s[2] = in_pitch - im_in_.width * C4NUM;
+  cus_size_.s[3] = out_pitch - im_out_.width * C4NUM;
   src_size_ = {std::accumulate(in_shape.begin() + param->axis_ + 1, in_shape.end(), 1, std::multiplies<int>()),
                std::accumulate(in_shape.begin(), in_shape.begin() + param->axis_, 1, std::multiplies<int>()),
                std::accumulate(in_shape.begin() + param->axis_, in_shape.end(), 1, std::multiplies<int>()),
@@ -100,22 +101,25 @@ void ArgMinMaxOpenCLKernel::SetGlobalLocal() {
   };
   switch (param->axis_) {
     case 0:
-      strides_.s[0] = UP_ROUND(strides_.s[0] / im_in_.H, cus_size_.s[1]) * im_in_.H;
+      strides_.s[0] = UP_ROUND(strides_.s[0] / im_in_.H, in_pitch) * im_in_.H;
       strides_.s[1] = strides_.s[0] * im_in_.N;
-      strides_.s[2] = UP_ROUND(strides_.s[2] / im_in_.H, cus_size_.s[1]) * im_in_.H;
+      strides_.s[2] = UP_ROUND(strides_.s[2] / im_in_.H, out_pitch) * im_in_.H;
       strides_.s[3] = strides_.s[2] * param->topk_;
       break;
     case 1:
-      strides_.s[0] = UP_ROUND(strides_.s[0], cus_size_.s[1]);
-      strides_.s[1] = UP_ROUND(strides_.s[1] / im_in_.H, cus_size_.s[1]) * im_in_.H;
-      strides_.s[2] = UP_ROUND(strides_.s[2], cus_size_.s[1]);
-      strides_.s[3] = UP_ROUND(strides_.s[3] / param->topk_, cus_size_.s[1]) * param->topk_;
+      strides_.s[0] = UP_ROUND(strides_.s[0], in_pitch);
+      strides_.s[1] = UP_ROUND(strides_.s[1] / im_in_.H, in_pitch) * im_in_.H;
+      // org dim(4,3) org axis(1,0)
+      strides_.s[2] = UP_ROUND(strides_.s[2], out_pitch);
+      strides_.s[3] = UP_ROUND(strides_.s[3] / param->topk_, out_pitch) * param->topk_;
       break;
     case 2:
-      strides_.s[1] = UP_ROUND(strides_.s[1], cus_size_.s[1]);
-      strides_.s[3] = UP_ROUND(strides_.s[3], cus_size_.s[1]);
+      strides_.s[1] = UP_ROUND(strides_.s[1], in_pitch);
+      // org dim(4,3,2) org axis(2,1,0)
+      strides_.s[3] = param->keep_dims_ ? UP_ROUND(strides_.s[3], out_pitch) : strides_.s[2];
       break;
     default:  // 3
+      // org dim(4,3,2,1) org axis(3,2,1,0)
       break;
   }
   local_size_ = {1, 1, 1};
@@ -147,8 +151,10 @@ int ArgMinMaxOpenCLKernel::Prepare() {
   auto *param = reinterpret_cast<ArgMinMaxParameter *>(this->op_parameter_);
   param->dims_size_ = in_tensors_[0]->shape().size();
   param->axis_ = (param->axis_ + param->dims_size_) % param->dims_size_;
-  param->axis_ = (4 - param->dims_size_) + param->axis_;
+  param->axis_ = GetBroadcastGpuAxis(param->dims_size_, param->axis_);
   param->get_max_ = (Type() == PrimitiveType_ArgMax);
+  param->keep_dims_ =
+    param->keep_dims_ || param->topk_ > 1 || in_tensors_[0]->shape().size() == out_tensors_[0]->shape().size();
 
   InitWeights();
   SetGlobalLocal();
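`Prepare()` now folds three conditions into `keep_dims_` and remaps the reduce axis onto the padded 4D GPU layout via `GetBroadcastGpuAxis()` (defined in utils.cc further down). A standalone sketch of that normalization with example values; the helper name `NormalizeAxis` is hypothetical:

```cpp
// Sketch of the axis/keep_dims normalization done in Prepare() above.
#include <cstdio>

int NormalizeAxis(int axis, int dims_size) {
  return (axis + dims_size) % dims_size;  // wrap negative axes, as in Prepare()
}

int main() {
  int dims_size = 3, axis = -1, topk = 1;
  bool keep_dims = false;
  size_t in_rank = 3, out_rank = 2;
  axis = NormalizeAxis(axis, dims_size);                      // -1 -> 2
  // keep_dims is forced on when topk > 1 or the output rank matches the input rank
  keep_dims = keep_dims || topk > 1 || in_rank == out_rank;   // stays false here
  printf("axis=%d keep_dims=%d\n", axis, keep_dims);
  return 0;
}
```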
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h
index 6b7ce95095..51d7c07858 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.h
@@ -44,6 +44,7 @@ class ArgMinMaxOpenCLKernel : public OpenCLKernel {
   void *buff_{nullptr};
   void *ids_{nullptr};
   GpuTensorInfo im_in_{GpuTensorInfo(nullptr)};
+  GpuTensorInfo im_out_{GpuTensorInfo(nullptr)};
   cl_int4 src_size_;
   cl_int4 cus_size_;
   cl_int4 strides_;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
index 373963494d..58e653f82e 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
@@ -105,6 +105,7 @@
 int DepthwiseConv2dOpenCLKernel::InitWeights() {
   auto allocator = ocl_runtime_->GetAllocator();
   bool is_fp16 = ocl_runtime_->GetFp16Enable();
+  size_t dtype_size = is_fp16 ? sizeof(int16_t) : sizeof(float);
   auto out_info = GpuTensorInfo(out_tensors_[0]);
   // weight: o, h, w, i; o == group, i == 1
   void *origin_weight = in_tensors_.at(kWeightIndex)->data_c();
@@ -121,7 +122,7 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
     size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
     img_size = {(size_t)plane_out / C4NUM, (size_t)out_info.N * CO4, img_dtype};
   }
-  pack_weight_size = is_fp16 ? pack_weight_size * sizeof(int16_t) : pack_weight_size * sizeof(float);
+  pack_weight_size = pack_weight_size * dtype_size;
   auto ConvertFilter = [](void *src, void *dst, TypeId src_type, TypeId dst_type, size_t plane_in, size_t plane_out,
                           size_t channel) {
     if (dst_type == kNumberTypeFloat16) {
@@ -173,18 +174,14 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
       memcpy(dst, src, size * dtype_size);
     }
   };
-  size_t dtype_size = sizeof(float);
-  if (is_fp16 && in_tensors_.at(kBiasIndex)->data_type() == kNumberTypeFloat16) {
-    dtype_size = sizeof(int16_t);
-  }
-  std::vector<char> temp_bias(pack_weight_size, 0);
+  size_t bias_size = C4NUM * CO4 * dtype_size;
+  std::vector<char> temp_bias(bias_size, 0);
   if (in_tensors_.size() == 3) {
     src_type = in_tensors_.at(kBiasIndex)->data_type();
     dst_type = is_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32;
     auto element_size = in_tensors_.at(kBiasIndex)->ElementsNum();
     ConvertBias(in_tensors_.at(kBiasIndex)->data_c(), temp_bias.data(), element_size, dtype_size, src_type, dst_type);
   }
-  size_t bias_size = C4NUM * CO4 * dtype_size;
   bias_data_ = allocator->Malloc(bias_size, {}, temp_bias.data());
   if (bias_data_ == nullptr) {
     return RET_ERROR;
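The depthwise change hoists a single `dtype_size` to the top of `InitWeights()` and sizes the bias staging buffer by its real footprint, `C4NUM * CO4 * dtype_size`, instead of borrowing `pack_weight_size`. A self-contained sketch of that sizing (`UP_DIV`/`C4NUM` follow MindSpore Lite's op_base macros; the channel count is an example):

```cpp
// Sketch of the new bias buffer sizing: one FLT4 slice per 4 output channels,
// with the tail zero-padded.
#include <cstdio>
#include <vector>

#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main() {
  size_t out_channels = 14;
  bool is_fp16 = false;
  size_t CO4 = UP_DIV(out_channels, C4NUM);                 // 4 slices
  size_t dtype_size = is_fp16 ? sizeof(short) : sizeof(float);  // fp16 stored as 16-bit
  size_t bias_size = C4NUM * CO4 * dtype_size;              // bytes
  std::vector<char> temp_bias(bias_size, 0);                // staging buffer, as in the patch
  printf("bias bytes: %zu\n", temp_bias.size());            // 64
  return 0;
}
```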
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
index c5e962c579..7642eb685a 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_fusion.cc
@@ -538,7 +538,7 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::vector<LiteKernel *> *nodes, s
 }  // namespace
 
-void OpenCLSubGraph::Fusion() {
+int OpenCLSubGraph::FusionPass() {
   MS_LOG(DEBUG) << "start Fusion";
 
   std::vector<LiteKernel *> input_nodes;
@@ -657,6 +657,7 @@
   nodes_.erase(
     std::remove_if(nodes_.begin(), nodes_.end(), [&](LiteKernel *node) { return AIsInB(node, &removed_set); }),
     nodes_.end());
   MS_LOG(DEBUG) << "number of kernels(after fusion) : " << nodes_.size();
+  return RET_OK;
 }
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
index cea8eb7fb5..6fed22b189 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.cc
@@ -16,6 +16,8 @@
 #include "src/runtime/kernel/opencl/opencl_subgraph.h"
 #include <set>
+#include <map>
+#include <functional>
 #include "src/runtime/opencl/opencl_executor.h"
 #include "src/runtime/kernel/opencl/utils.h"
 #include "include/errorcode.h"
@@ -189,19 +191,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors,
   }
   return RET_OK;
 }
-
-int OpenCLSubGraph::Init() {
-  allocator_ = ocl_runtime_->GetAllocator();
-  MS_LOG(DEBUG) << "input num=" << in_tensors_.size() << ", output num=" << out_tensors_.size();
-  for (const auto tensor : in_tensors_) {
-    MS_ASSERT(tensor);
-    tensor->set_allocator(allocator_);
-  }
-  for (const auto tensor : out_tensors_) {
-    MS_ASSERT(tensor);
-    tensor->set_allocator(allocator_);
-  }
-
+int OpenCLSubGraph::InsertOpsPass() {
   GetInOutNodes();
 
   std::vector<std::vector<kernel::LiteKernel *>> from_kernels_;
@@ -222,12 +212,34 @@
   }
   nodes_.insert(nodes_.end(), out_convert_ops_.begin(), out_convert_ops_.end());
   GetInOutNodes();
-  UpdateTensorDataType();
-  Fusion();
+  return RET_OK;
+}
+int OpenCLSubGraph::Init() {
+  allocator_ = ocl_runtime_->GetAllocator();
+  MS_LOG(DEBUG) << "input num=" << in_tensors_.size() << ", output num=" << out_tensors_.size();
+  for (const auto tensor : in_tensors_) {
+    MS_ASSERT(tensor);
+    tensor->set_allocator(allocator_);
+  }
+  for (const auto tensor : out_tensors_) {
+    MS_ASSERT(tensor);
+    tensor->set_allocator(allocator_);
+  }
+  std::map<std::string, std::function<int()>> pass_manager{
+    {"InsertOpsPass", std::bind(&OpenCLSubGraph::InsertOpsPass, this)},
+    {"UpdateTensorDataTypePass", std::bind(&OpenCLSubGraph::UpdateTensorDataTypePass, this)},
+    {"FusionPass", std::bind(&OpenCLSubGraph::FusionPass, this)}};
+  for (auto iv : pass_manager) {
+    auto ret = iv.second();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Run Pass: " << iv.first << " failed.";
+      return RET_ERROR;
+    }
+  }
   return RET_OK;
 }
 
-void OpenCLSubGraph::UpdateTensorDataType() {
+int OpenCLSubGraph::UpdateTensorDataTypePass() {
   bool is_fp16 = ocl_runtime_->GetFp16Enable();
   MS_ASSERT(in_tensors_[0]);
   if (is_fp16 && (in_tensors_[0]->data_type() == kNumberTypeFloat32)) {
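`Init()` now drives three named passes through a `std::map` of `std::function<int()>` and stops at the first non-`RET_OK` result. Worth noting: `std::map` iterates in lexicographic key order ("FusionPass" sorts before "InsertOpsPass"), so if insertion order is the intended execution order, an ordered container is the safer choice. A standalone sketch of the same pattern, using a vector of pairs to keep the order explicit (names mirror the patch; the runner itself is illustrative):

```cpp
// Pass-runner sketch: run named passes in a fixed order, abort on first failure.
#include <cstdio>
#include <functional>
#include <string>
#include <utility>
#include <vector>

constexpr int RET_OK = 0;

int RunPasses(const std::vector<std::pair<std::string, std::function<int()>>> &passes) {
  for (const auto &pass : passes) {
    if (pass.second() != RET_OK) {
      printf("Run Pass: %s failed.\n", pass.first.c_str());
      return -1;  // RET_ERROR analogue
    }
  }
  return RET_OK;
}

int main() {
  auto ok = [] { return RET_OK; };
  // InsertOps -> UpdateTensorDataType -> Fusion, preserved by vector order
  return RunPasses({{"InsertOpsPass", ok}, {"UpdateTensorDataTypePass", ok}, {"FusionPass", ok}});
}
```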
@@ -245,6 +257,7 @@
       }
     }
   }
+  return RET_OK;
 }
 
 void OpenCLSubGraph::GetKernelFromToTensor(const std::vector<lite::Tensor *> &in_tensors,
diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h
index bcaf85bf34..7e737968e3 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h
@@ -46,10 +46,11 @@ class OpenCLSubGraph : public SubGraphKernel {
   int ReSize() override;
   int Run() override;
   int Run(const KernelCallBack &before, const KernelCallBack &after) override { return this->Run(); };
+  int InsertOpsPass();
 
  private:
   void UnInit();
-  void UpdateTensorDataType();
+  int UpdateTensorDataTypePass();
   void ReplaceOutTensorAndKernelToNull(const std::vector<lite::Tensor *> &in_tensors,
                                        const std::vector<std::vector<kernel::LiteKernel *>> &in_kernels,
                                        lite::opencl::MemType mem_type);
@@ -64,7 +65,10 @@ class OpenCLSubGraph : public SubGraphKernel {
   void GetKernelFromToTensor(const std::vector<lite::Tensor *> &in_tensors,
                              const std::vector<kernel::LiteKernel *> &in_kernels,
                              std::vector<std::vector<kernel::LiteKernel *>> *out_kernels, bool is_from);
-  void Fusion();
+  int FusionPass();
+
+ public:
+  using PassFunc = int (OpenCLSubGraph::*)(void);
 
  private:
  lite::opencl::OpenCLAllocator *allocator_{nullptr};
diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.cc b/mindspore/lite/src/runtime/kernel/opencl/utils.cc
index 57cba111b2..2537c3e9ae 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/utils.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/utils.cc
@@ -330,4 +330,22 @@ std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape
   }
   return {image_x, image_y};
 }
+int GetBroadcastGpuAxis(int ndim, int ori_axis) {
+  if (ori_axis >= ndim) {
+    return ndim - 1;
+  }
+  int axis = 0;
+  if (ndim == 1) {
+    axis = 3;
+  } else if (ndim == 2) {
+    axis = ori_axis == 0 ? 0 : 3;
+  } else if (ndim == 3) {
+    axis = ori_axis == 0 ? 0 : ori_axis == 1 ? 2 : 3;
+  } else if (ndim == 4) {
+    axis = ori_axis;
+  } else if (ndim > 4) {
+    MS_LOG(ERROR) << "GPU doesn't support ndim>=" << ndim;
+  }
+  return axis;
+}
 }  // namespace mindspore::kernel
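`GetBroadcastGpuAxis()` right-aligns a 1-3D shape into the 4D NHWC GPU layout (1D maps to C; 2D to N,C; 3D to N,W,C) and remaps the reduce axis accordingly. A usage sketch with a local copy of the function (error logging omitted; the mapping table in the comments is the point):

```cpp
// Axis remap demo: original (ndim, axis) -> axis in the padded NHWC layout.
#include <cstdio>

int GetBroadcastGpuAxis(int ndim, int ori_axis) {
  if (ori_axis >= ndim) return ndim - 1;
  if (ndim == 1) return 3;                                          // (C)     -> C
  if (ndim == 2) return ori_axis == 0 ? 0 : 3;                      // (N,C)   -> N or C
  if (ndim == 3) return ori_axis == 0 ? 0 : ori_axis == 1 ? 2 : 3;  // (N,W,C) -> N, W, or C
  return ori_axis;                                                  // already 4D
}

int main() {
  // 3D axis 1 -> 2 (W), 2D axis 1 -> 3 (C), 1D axis 0 -> 3 (C)
  printf("%d %d %d\n", GetBroadcastGpuAxis(3, 1), GetBroadcastGpuAxis(2, 1), GetBroadcastGpuAxis(1, 0));
  return 0;
}
```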
diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.h b/mindspore/lite/src/runtime/kernel/opencl/utils.h
index 8142f3562a..4d99899744 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/utils.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/utils.h
@@ -61,6 +61,8 @@ std::vector<int> GetNHWCShape(const std::vector<int> &tensor_shape);
 
 std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape, schema::Format format);
 
+int GetBroadcastGpuAxis(int ndim, int ori_axis);
+
 template <class T1, class T2>
 void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane_in, int plane_out, int channel,
                       const std::function<T2(T1)> &to_dtype) {
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc
index 126c115d13..540eb0d6b1 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc
@@ -185,7 +185,7 @@ TEST_F(TestOpenCL_ArgMinMax, axis3topk2value) {
     TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
   }
 }
-TEST_F(TestOpenCL_ArgMinMax, axis1topk1index) {
+TEST_F(TestOpenCL_ArgMinMax, dim32axis1topk1index) {
   schema::PrimitiveType type = schema::PrimitiveType_ArgMax;
   int axis = 1;
   int topk = 1;
@@ -200,4 +200,52 @@ TEST_F(TestOpenCL_ArgMinMax, axis1topk1index) {
     TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
   }
 }
+TEST_F(TestOpenCL_ArgMinMax, dim43axis2topk1index) {
+  schema::PrimitiveType type = schema::PrimitiveType_ArgMax;
+  int axis = 2;
+  int topk = 1;
+  bool out_value = false;
+  std::vector<int> input_shape = {2, 2, 2, 14};
+  std::vector<int> output_shape = {2, 2, 14};
+  float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1,  50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15,
+                        1,  50, 30, 45, 25, 10, 20, 30, 40, 90, 20, 11, 15, 1,  50, 30, 45, 25, 50, 30, 10, 20, 30,
+                        40, 90, 20, 11, 15, 1,  50, 30, 45, 25, 10, 20, 30, 40, 90, 20, 11, 15, 1,  50, 30, 45, 25,
+                        50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1,  50, 30, 45, 25, 10, 20, 30, 40, 90, 20, 11, 15,
+                        1,  50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1,  50, 30, 45, 25};
+  float output_data[] = {1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0,
+                         1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0};
+  for (auto fp16_enable : {false, true}) {
+    auto *param = CreateParameter(type, axis, topk, out_value);
+    TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
+  }
+}
+TEST_F(TestOpenCL_ArgMinMax, dim21axis2topk1index) {
+  schema::PrimitiveType type = schema::PrimitiveType_ArgMax;
+  int axis = 0;
+  int topk = 1;
+  bool out_value = false;
+  std::vector<int> input_shape = {2, 14};
+  std::vector<int> output_shape = {14};
+  float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1,  50, 30, 45, 25, 50,
+                        30, 10, 20, 30, 40, 90, 20, 11, 15, 1,  50, 30, 45, 25};
+  float output_data[] = {1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0};
+  for (auto fp16_enable : {false, true}) {
+    auto *param = CreateParameter(type, axis, topk, out_value);
+    TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
+  }
+}
+TEST_F(TestOpenCL_ArgMinMax, dim10axis2topk1index) {
+  schema::PrimitiveType type = schema::PrimitiveType_ArgMax;
+  int axis = 0;
+  int topk = 1;
+  bool out_value = false;
+  std::vector<int> input_shape = {14};
+  std::vector<int> output_shape = {1};
+  float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50};
+  float output_data[] = {4};
+  for (auto fp16_enable : {false, true}) {
+    auto *param = CreateParameter(type, axis, topk, out_value);
+    TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
+  }
+}
 }  // namespace mindspore::lite::opencl::test
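As a quick sanity check of the new 1D test: the largest of the 14 inputs is 90, at flat index 4, which matches `output_data[] = {4}`. A host-side verification sketch (not part of the test suite):

```cpp
// Recompute the expected argmax index for the dim10axis2topk1index test data.
#include <algorithm>
#include <cstdio>
#include <iterator>

int main() {
  float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50};
  auto it = std::max_element(std::begin(input_data), std::end(input_data));
  printf("argmax index = %td\n", it - std::begin(input_data));  // prints 4
  return 0;
}
```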