diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/reshape.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/reshape.cl
index bb9892b575..b51c514856 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/reshape.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/reshape.cl
@@ -1,3 +1,4 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
 __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
 __kernel void reshape(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {
   int X = get_global_id(0);
diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl
index 05f903602e..0076b5fdb5 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/transpose.cl
@@ -1,3 +1,4 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
 __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
 __kernel void transpose_IMG(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 HW, int2 C) {
   int X = get_global_id(0);
@@ -75,8 +76,8 @@ __kernel void transpose_BUF(__read_only image2d_t src_data, global FLT4 *dst_dat
   result[3].z = x2.w;
   result[3].w = x3.w;
 
-  dst_data[4 * Y * HW.y + X] = result[0];
-  dst_data[(4 * Y + 1) * HW.y + X] = result[1];
-  dst_data[(4 * Y + 2) * HW.y + X] = result[2];
-  dst_data[(4 * Y + 3) * HW.y + X] = result[3];
+  if (4 * Y < C.x) dst_data[4 * Y * HW.y + X] = result[0];
+  if (4 * Y + 1 < C.x) dst_data[(4 * Y + 1) * HW.y + X] = result[1];
+  if (4 * Y + 2 < C.x) dst_data[(4 * Y + 2) * HW.y + X] = result[2];
+  if (4 * Y + 3 < C.x) dst_data[(4 * Y + 3) * HW.y + X] = result[3];
 }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc
index db644507ca..3254d0758d 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc
@@ -33,6 +33,7 @@ namespace mindspore::kernel {
 int ReshapeOpenCLKernel::Init() {
   std::string kernel_name = "reshape";
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  enable_fp16_ = ocl_runtime->GetFp16Enable();
   in_ori_format_ = in_tensors_[0]->GetFormat();
   out_ori_format_ = out_tensors_[0]->GetFormat();
   if (in_ori_format_ != schema::Format_NHWC4 && in_ori_format_ != schema::Format_NHWC) {
@@ -73,11 +74,10 @@ int ReshapeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size)
   int c = shapex[3];
   im_dst_x = w * UP_DIV(c, C4NUM);
   im_dst_y = h;
-#ifdef ENABLE_FP16
-  size_t img_dtype = CL_HALF_FLOAT;
-#else
   size_t img_dtype = CL_FLOAT;
-#endif
+  if (enable_fp16_) {
+    img_dtype = CL_HALF_FLOAT;
+  }
   img_size->clear();
   std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
   *img_size = vec;
@@ -121,4 +121,5 @@ kernel::LiteKernel *OpenCLReshapeKernelCreator(const std::vector<lite::tensor::
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
+  enable_fp16_ = ocl_runtime->GetFp16Enable();
   if (!is_image_out_) {
     kernel_name += "_BUF";
   } else {
@@ -70,11 +71,10 @@ int TransposeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_siz
   size_t im_dst_x, im_dst_y;
   im_dst_x = UP_DIV(out_tensors_[0]->Height() * out_tensors_[0]->Width(), C4NUM);
   im_dst_y = out_tensors_[0]->Channel();
-#ifdef ENABLE_FP16
-  size_t img_dtype = CL_HALF_FLOAT;
-#else
   size_t img_dtype = CL_FLOAT;
-#endif
+  if (enable_fp16_) {
+    img_dtype = CL_HALF_FLOAT;
+  }
   img_size->clear();
   std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
   *img_size = vec;
@@ -82,6 +82,7 @@ int TransposeOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_siz
 }
 
 int TransposeOpenCLKernel::Run() {
+  // notice: input image2d size = {c/4, h * w}
   MS_LOG(DEBUG) << this->name() << " Running!";
   std::vector<int> shapex = in_tensors_[0]->shape();
   int h = shapex[1];
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h
index 2efce26672..708acbfaae 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h
@@ -38,7 +38,8 @@ class TransposeOpenCLKernel : public OpenCLKernel {
 
  private:
   cl::Kernel kernel_;
-  bool is_image_out_ = false;
+  bool is_image_out_{false};
+  bool enable_fp16_{false};
 };
 }  // namespace mindspore::kernel
 
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
index 1bdb6f8575..cf37a850e4 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
@@ -31,14 +31,14 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest {
 };
 
 void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data, void *weight_data, void *bias_data,
-                                void *output_data, bool fp16) {
+                                void *output_data, bool enable_fp16) {
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  ocl_runtime->Init();
   size_t dtype_size = sizeof(float);
-  if (fp16) {
+  if (enable_fp16) {
     ocl_runtime->SetFp16Enable(true);
     dtype_size = sizeof(float16_t);
   }
-  ocl_runtime->Init();
   auto allocator = ocl_runtime->GetAllocator();
   int pad = shape[0];
   int n = shape[1];
@@ -52,7 +52,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
   int ow = 2 * w - 1 + 2 * (kw - 1 - pad) - kw + 1;
   std::vector<int> input_shape = {n, h, w, ci};
   auto tensor_x_ptr =
-    std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape);
+    std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape);
   auto tensor_x = tensor_x_ptr.get();
   if (tensor_x == nullptr) {
     MS_LOG(ERROR) << "tensor_x create error.";
@@ -61,7 +61,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
 
   std::vector<int> weight_shape = {co, kh, kw, ci};
   auto tensor_w_ptr =
-    std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape);
+    std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape);
   auto tensor_w = tensor_w_ptr.get();
   if (tensor_w == nullptr) {
     MS_LOG(ERROR) << "tensor_w create error.";
@@ -71,7 +71,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
 
   std::vector<int> bias_shape = {co};
   auto tensor_bias_ptr =
-    std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), bias_shape);
+    std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), bias_shape);
   auto tensor_bias = tensor_bias_ptr.get();
   if (tensor_bias == nullptr) {
     MS_LOG(ERROR) << "tensor_bias create error.";
@@ -81,7 +81,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
 
   std::vector<int> out_shape = {1, oh, ow, co};
   auto tensor_out_ptr =
-    std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape);
+    std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape);
   auto tensor_out = tensor_out_ptr.get();
   if (tensor_out == nullptr) {
     MS_LOG(ERROR) << "tensor_out create error.";
@@ -126,7 +126,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
   pGraph->Init();
   memcpy(inputs[0]->Data(), input_data, n * h * w * ci * dtype_size);
   pGraph->Run();
-  if (fp16) {
+  if (enable_fp16) {
     CompareOutput(outputs[0]->Data(), output_data, n * oh * ow * co, static_cast<float16_t>(1e-3), 2e-2);
   } else {
     CompareOutput(outputs[0]->Data(), output_data, n * oh * ow * co, static_cast<float>(1e-5));
@@ -137,7 +137,8 @@ void RunTestCaseConv2dTranspose(const std::vector<int> &shape, void *input_data,
   lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 
-void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) {
+void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<std::string> file_path,
+                                bool enable_fp16) {
   size_t input_size;
   std::string input_path = file_path[0];
   auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size);
@@ -168,7 +169,7 @@ void RunTestCaseConv2dTranspose(const std::vector<int> shape, const std::vector<
     MS_LOG(ERROR) << "output_data load error.";
     return;
   }
-  RunTestCaseConv2dTranspose(shape, input_data, weight_data, bias_data, output_data, fp16);
+  RunTestCaseConv2dTranspose(shape, input_data, weight_data, bias_data, output_data, enable_fp16);
 }
 
 TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
index 9540fca960..8b21e5f845 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
@@ -29,32 +29,21 @@ class TestMatMulOpenCL : public mindspore::CommonTest {
   TestMatMulOpenCL() {}
 };
 
-void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::string> file_path, bool fp16) {
+void RunTestCaseMatMul(const std::vector<int> &shape, void *input_data, void *weight_data, void *output_data,
+                       bool enable_fp16) {
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   ocl_runtime->Init();
-  if (fp16) {
+  size_t dtype_size = sizeof(float);
+  if (enable_fp16) {
     ocl_runtime->SetFp16Enable(true);
+    dtype_size = sizeof(float16_t);
   }
   auto allocator = ocl_runtime->GetAllocator();
-  size_t input_size;
   int ci = shape[0];
   int co = shape[1];
-  std::string input_path = file_path[0];
-  auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size);
-  if (input_data == nullptr) {
-    MS_LOG(ERROR) << "input_data load error.";
-    return;
-  }
-  size_t weight_size;
-  std::string weight_path = file_path[1];
-  auto weight_data = mindspore::lite::ReadFile(weight_path.c_str(), &weight_size);
-  if (weight_data == nullptr) {
-    MS_LOG(ERROR) << "weight_data load error.";
-    return;
-  }
   std::vector<int> input_shape = {1, ci};
-  auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
-                                                             input_shape, schema::Format_NC);
+  auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(
+    TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape, schema::Format_NC);
   auto tensor_x = tensor_x_ptr.get();
   if (tensor_x == nullptr) {
     MS_LOG(ERROR) << "tensor_x create error.";
@@ -63,7 +52,7 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri
 
   std::vector<int> w_shape = {co, ci};
   auto tensor_w_ptr =
-    std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), w_shape);
+    std::make_unique<lite::tensor::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), w_shape);
   auto tensor_w = tensor_w_ptr.get();
   if (tensor_w == nullptr) {
     MS_LOG(ERROR) << "tensor_w create error.";
@@ -72,8 +61,8 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri
   }
   tensor_w->SetData(weight_data);
   std::vector<int> out_shape = {1, co};
-  auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32),
-                                                               out_shape, schema::Format_NC);
+  auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(
+    TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape, schema::Format_NC);
   auto tensor_out = tensor_out_ptr.get();
   if (tensor_out == nullptr) {
     MS_LOG(ERROR) << "tensor_out create error.";
@@ -100,12 +89,12 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri
     return;
   }
   pGraph->Init();
-  memcpy(inputs[0]->Data(), input_data, input_size);
+  memcpy(inputs[0]->Data(), input_data, ci * dtype_size);
   pGraph->Run();
-  if (fp16) {
-    CompareOutput(tensor_out, file_path[2], static_cast<float16_t>(1e-3), 2e-2);
+  if (enable_fp16) {
+    CompareOutput(outputs[0]->Data(), output_data, co, static_cast<float16_t>(1e-3), 2e-2);
   } else {
-    CompareOutput(tensor_out, file_path[2], static_cast<float>(1e-5));
+    CompareOutput(outputs[0]->Data(), output_data, co, static_cast<float>(1e-5));
   }
 
   tensor_x->SetData(nullptr);
@@ -114,6 +103,31 @@ void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::stri
   lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 
+void RunTestCaseMatMul(const std::vector<int> shape, const std::vector<std::string> file_path, bool enable_fp16) {
+  size_t input_size;
+  std::string input_path = file_path[0];
+  auto input_data = mindspore::lite::ReadFile(input_path.c_str(), &input_size);
+  if (input_data == nullptr) {
+    MS_LOG(ERROR) << "input_data load error.";
+    return;
+  }
+  size_t weight_size;
+  std::string weight_path = file_path[1];
+  auto weight_data = mindspore::lite::ReadFile(weight_path.c_str(), &weight_size);
+  if (weight_data == nullptr) {
+    MS_LOG(ERROR) << "weight_data load error.";
+    return;
+  }
+  size_t output_size;
+  std::string output_path = file_path[2];
+  auto output_data = mindspore::lite::ReadFile(output_path.c_str(), &output_size);
+  if (output_data == nullptr) {
+    MS_LOG(ERROR) << "output_data load error.";
+    return;
+  }
+  RunTestCaseMatMul(shape, input_data, weight_data, output_data, enable_fp16);
+}
+
 TEST_F(TestMatMulOpenCL, MatMulFp32) {
   int ci = 1280;
   int co = 1001;
@@ -133,4 +147,26 @@ TEST_F(TestMatMulOpenCL, MatMulFp16) {
                                          "./test_data/matmul/matmul_fp16_output.bin"};
   RunTestCaseMatMul(shape, file_path, true);
 }
+
+TEST_F(TestMatMulOpenCL, MatMulFp32_2) {
+  int ci = 5;
+  int co = 3;
+  std::vector<int> shape = {ci, co};
+  std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f};
+  std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
+                                    1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+  std::vector<float> output_data = {10.f, 10.f, 10.f};
+  RunTestCaseMatMul(shape, input_data.data(), weight_data.data(), output_data.data(), false);
+}
+
+TEST_F(TestMatMulOpenCL, MatMulFp16_2) {
+  int ci = 5;
+  int co = 3;
+  std::vector<int> shape = {ci, co};
+  std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f};
+  std::vector<float16_t> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
+                                        1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+  std::vector<float16_t> output_data = {10.f, 10.f, 10.f};
+  RunTestCaseMatMul(shape, input_data.data(), weight_data.data(), output_data.data(), true);
+}
 }  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc
index d0e0b344f1..0172df1a31 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc
@@ -21,6 +21,7 @@
 #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
 #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
 #include "mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.h"
+#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
 
 namespace mindspore {
 class TestReshapeOpenCL : public mindspore::CommonTest {
@@ -28,29 +29,27 @@ class TestReshapeOpenCL : public mindspore::CommonTest {
   TestReshapeOpenCL() {}
 };
 
-TEST_F(TestReshapeOpenCL, ReshapeFp32) {
+void RunTestCaseReshape(const std::vector<int> &shape, void *input_data, void *output_data, bool enable_fp16) {
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   ocl_runtime->Init();
-  auto allocator = ocl_runtime->GetAllocator();
-  int c = 63;
-  size_t input_size;
-  std::string input_path = "./test_data/reshape/reshape_fp32_input.bin";
-  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
-  if (input_data == nullptr) {
-    MS_LOG(ERROR) << "input_data load error.";
-    return;
+  size_t dtype_size = sizeof(float);
+  if (enable_fp16) {
+    ocl_runtime->SetFp16Enable(true);
+    dtype_size = sizeof(float16_t);
   }
+  auto allocator = ocl_runtime->GetAllocator();
+  int c = shape[0];
   std::vector<int> input_shape = {1, 1, 1, c};
-  auto tensor_x_ptr =
-    std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC);
+  auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(
+    TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape, schema::Format_NHWC);
   auto tensor_x = tensor_x_ptr.get();
   if (tensor_x == nullptr) {
     MS_LOG(ERROR) << "tensor_x create error.";
     return;
   }
   std::vector<int> out_shape = {1, c};
-  auto tensor_out_ptr =
-    std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape, schema::Format_NC);
+  auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(
+    TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape, schema::Format_NC);
   auto tensor_out = tensor_out_ptr.get();
   if (tensor_out == nullptr) {
     MS_LOG(ERROR) << "tensor_out create error.";
@@ -76,36 +75,36 @@ TEST_F(TestReshapeOpenCL, ReshapeFp32) {
     return;
   }
   pGraph->Init();
-  memcpy(inputs[0]->Data(), input_data, input_size);
+  memcpy(inputs[0]->Data(), input_data, c * dtype_size);
   pGraph->Run();
 
-  size_t output_size;
-  std::string output_path = "./test_data/reshape/reshape_fp32_output.bin";
-  auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
-  if (correct_data == nullptr) {
-    MS_LOG(ERROR) << "correct_data create error.";
-    return;
-  }
-  printf("==================output data=================\n");
-  float *output_data = reinterpret_cast<float *>(tensor_out->Data());
-  std::cout << std::endl;
-  int size_n = c;
-  size_n = size_n > 100 ? 100 : size_n;
-  for (int i = 0; i < size_n; i++) {
-    std::cout << output_data[i] << " ";
-    if ((i + 1) % c == 0) {
-      std::cout << std::endl;
-    }
+  if (enable_fp16) {
+    CompareOutput(outputs[0]->Data(), output_data, c, static_cast<float16_t>(1e-3), 2e-2);
+  } else {
+    CompareOutput(outputs[0]->Data(), output_data, c, static_cast<float>(1e-5));
   }
-  std::cout << std::endl;
-
-  // compare
-  CompareOutputData(output_data, correct_data, c, 0.00001);
-
   inputs[0]->SetData(nullptr);
   outputs[0]->SetData(nullptr);
 
   MS_LOG(INFO) << "Test ReshapeFp32 passed";
   lite::opencl::OpenCLRuntime::DeleteInstance();
 }
+
+TEST_F(TestReshapeOpenCL, ReshapeFp32) {
+  int c = 7;
+  std::vector<int> shape = {c};
+  std::vector<float> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
+  std::vector<float> output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+  RunTestCaseReshape(shape, input_data.data(), output_data.data(), false);
+}
+
+TEST_F(TestReshapeOpenCL, ReshapeFp16) {
+  int c = 7;
+  std::vector<int> shape = {c};
+  std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
+  std::vector<float16_t> output_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+  RunTestCaseReshape(shape, input_data.data(), output_data.data(), true);
+}
 }  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
index 0cd2fa8536..925b45de20 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
@@ -21,6 +21,7 @@
 #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
 #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
 #include "mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h"
+#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
 
 namespace mindspore {
 class TestTransposeOpenCL : public mindspore::CommonTest {
@@ -28,31 +29,29 @@ class TestTransposeOpenCL : public mindspore::CommonTest {
   TestTransposeOpenCL() {}
 };
 
-TEST_F(TestTransposeOpenCL, TransposeFp32) {
+void RunTestTranspose(const std::vector<int> &shape, void *input_data, void *output_data, bool enable_fp16) {
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   ocl_runtime->Init();
-  auto allocator = ocl_runtime->GetAllocator();
-  int h = 64;
-  int w = 1;
-  int c = 7360;
-  size_t input_size;
-  std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
-  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
-  if (input_data == nullptr) {
-    MS_LOG(ERROR) << "input_data load error.";
-    return;
+  size_t dtype_size = sizeof(float);
+  if (enable_fp16) {
+    ocl_runtime->SetFp16Enable(true);
+    dtype_size = sizeof(float16_t);
   }
+  auto allocator = ocl_runtime->GetAllocator();
+  int h = shape[0];
+  int w = shape[1];
+  int c = shape[2];
   std::vector<int> input_shape = {1, h, w, c};
-  auto tensor_x_ptr =
-    std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC);
+  auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(
+    TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape, schema::Format_NHWC);
   auto tensor_x = tensor_x_ptr.get();
   if (tensor_x == nullptr) {
     MS_LOG(ERROR) << "tensor_x create error.";
     return;
   }
   std::vector<int> out_shape = {1, c, h, w};
-  auto tensor_out_ptr =
-    std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape, schema::Format_NCHW);
+  auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(
+    TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape, schema::Format_NCHW);
   auto tensor_out = tensor_out_ptr.get();
   if (tensor_out == nullptr) {
     MS_LOG(ERROR) << "tensor_out create error.";
@@ -78,9 +77,35 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
     return;
   }
   pGraph->Init();
-  memcpy(inputs[0]->Data(), input_data, input_size);
+  memcpy(inputs[0]->Data(), input_data, h * w * c * dtype_size);
   pGraph->Run();
 
+  if (enable_fp16) {
+    CompareOutput(outputs[0]->Data(), output_data, h * w * c, static_cast<float16_t>(1e-3), 2e-2);
+  } else {
+    CompareOutput(outputs[0]->Data(), output_data, h * w * c, static_cast<float>(1e-5));
+  }
+
+  inputs[0]->SetData(nullptr);
+  outputs[0]->SetData(nullptr);
+
+  MS_LOG(INFO) << "Test TransposeFp32 passed";
+  lite::opencl::OpenCLRuntime::DeleteInstance();
+}
+
+TEST_F(TestTransposeOpenCL, TransposeFp32) {
+  int h = 64;
+  int w = 1;
+  int c = 7360;
+  std::vector<int> shape = {h, w, c};
+  size_t input_size;
+  std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
+  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
+  if (input_data == nullptr) {
+    MS_LOG(ERROR) << "input_data load error.";
+    return;
+  }
+
   size_t output_size;
   std::string output_path = "./test_data/transpose/transpose_fp32_output.bin";
   auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
@@ -88,26 +113,17 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
     MS_LOG(ERROR) << "correct_data create error.";
     return;
   }
-  printf("==================output data=================\n");
-  float *output_data = reinterpret_cast<float *>(tensor_out->Data());
-  std::cout << std::endl;
-  int size_n = h * w * c;
-  size_n = size_n > 100 ? 100 : size_n;
-  for (int i = 0; i < size_n; i++) {
-    std::cout << output_data[i] << " ";
-    if ((i + 1) % c == 0) {
-      std::cout << std::endl;
-    }
-  }
-  std::cout << std::endl;
-
-  // compare
-  CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
+  RunTestTranspose(shape, input_data, correct_data, false);
+}
 
-  inputs[0]->SetData(nullptr);
-  outputs[0]->SetData(nullptr);
+TEST_F(TestTransposeOpenCL, TransposeFp16) {
+  int h = 4;
+  int w = 1;
+  int c = 3;
+  std::vector<int> shape = {h, w, c};
+  std::vector<float16_t> input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
+  std::vector<float16_t> output_data = {0.0f, 3.0f, 6.0f, 9.0f, 1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f, 8.0f, 11.0f};
 
-  MS_LOG(INFO) << "Test TransposeFp32 passed";
-  lite::opencl::OpenCLRuntime::DeleteInstance();
+  RunTestTranspose(shape, input_data.data(), output_data.data(), true);
 }
 }  // namespace mindspore
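Reviewer note: the core kernel-side change in this patch is replacing the compile-time "#ifdef ENABLE_FP16" selection of the OpenCL image channel type with a check of the runtime FP16 flag. The snippet below is a minimal standalone sketch of that pattern, not the MindSpore sources themselves; MakeImageSize is a hypothetical helper invented for illustration, while CL_FLOAT, CL_HALF_FLOAT, and the idea of querying the runtime flag come from the diff above.

#include <cstddef>
#include <vector>
#include <CL/cl.h>

// Sketch of the runtime FP16 switch used by the GetImageSize() changes above:
// the image2d channel data type is chosen per run from the FP16 flag instead
// of being fixed at build time with #ifdef ENABLE_FP16.
std::vector<size_t> MakeImageSize(size_t im_dst_x, size_t im_dst_y, bool enable_fp16) {
  size_t img_dtype = CL_FLOAT;      // default: 32-bit float image channels
  if (enable_fp16) {
    img_dtype = CL_HALF_FLOAT;      // half-precision channels, selected at runtime
  }
  return {im_dst_x, im_dst_y, img_dtype};
}

The guards added to transpose_BUF in transpose.cl follow a related defensive idea: when the channel count C.x is not a multiple of 4, the "if (4 * Y + k < C.x)" checks skip the padding rows of the last channel slice instead of writing past the end of dst_data.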