diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/batchnorm.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/batchnorm.cl index 554674da6a..c3cc89ad1f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/batchnorm.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/batchnorm.cl @@ -1,3 +1,4 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable #define INT4 int4 #define INT2 int2 __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; @@ -11,16 +12,16 @@ __kernel void batch_normalization(__read_only image2d_t input, __read_only image if (X >= input_shape.y || Y >= input_shape.z || Z >= input_shape.w) { return; } - FLT4 result = read_imagef(input, smp_none, (int2)((Y)*input_shape.w + Z, (X))); + FLT4 result = READ_IMAGE(input, smp_none, (int2)((Y)*input_shape.w + Z, (X))); - FLT4 result_mean = read_imagef(mean, smp_none, (int2)((Z), (0))); - FLT4 result_var = read_imagef(variance, smp_none, (int2)((Z), (0))); - FLT4 result_scale = read_imagef(scale, smp_none, (int2)((Z), (0))); - FLT4 result_offset = read_imagef(offset, smp_none, (int2)((Z), (0))); + FLT4 result_mean = READ_IMAGE(mean, smp_none, (int2)((Z), (0))); + FLT4 result_var = READ_IMAGE(variance, smp_none, (int2)((Z), (0))); + FLT4 result_scale = READ_IMAGE(scale, smp_none, (int2)((Z), (0))); + FLT4 result_offset = READ_IMAGE(offset, smp_none, (int2)((Z), (0))); result.x = result_scale.x * ((result.x - result_mean.x) / sqrt(result_var.x + epsilon)) + result_offset.x; result.y = result_scale.y * ((result.y - result_mean.y) / sqrt(result_var.y + epsilon)) + result_offset.y; result.z = result_scale.z * ((result.z - result_mean.z) / sqrt(result_var.z + epsilon)) + result_offset.z; result.w = result_scale.w * ((result.w - result_mean.w) / sqrt(result_var.w + epsilon)) + result_offset.w; - write_imagef(output, (int2)((Y)*input_shape.w + Z, (X)), result); + WRITE_IMAGE(output, (int2)((Y)*input_shape.w + Z, (X)), result); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl index c16daf9c81..c2ae7c9106 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/concat.cl @@ -1,4 +1,4 @@ -// #pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_fp16 : enable __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; __kernel void Concat(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output, @@ -10,11 +10,11 @@ __kernel void Concat(__read_only image2d_t input0, __read_only image2d_t input1, return; } if (Z < input_channels.x) { - FLT4 result = read_imagef(input0, smp_none, (int2)((Y)*input_channels.x + Z, (X))); - write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result); + FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*input_channels.x + Z, (X))); + WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); } else { - FLT4 result = read_imagef(input1, smp_none, (int2)((Y)*input_channels.y + Z - input_channels.x, (X))); - write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result); + FLT4 result = READ_IMAGE(input1, smp_none, (int2)((Y)*input_channels.y + Z - input_channels.x, (X))); + WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); } } @@ -27,14 +27,14 @@ __kernel void Concat3input(__read_only image2d_t input0, __read_only image2d_t i return; } if (Z < input_channels.x) { - FLT4 result0 = read_imagef(input0, smp_none, (int2)((Y)*input_channels.x + Z, (X))); - write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result0); + FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*input_channels.x + Z, (X))); + WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result0); } else if (Z < (input_channels.x + input_channels.y)) { - FLT4 result1 = read_imagef(input1, smp_none, (int2)((Y)*input_channels.y + Z - input_channels.x, (X))); - write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result1); + FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*input_channels.y + Z - input_channels.x, (X))); + WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result1); } else { FLT4 result2 = - read_imagef(input2, smp_none, (int2)((Y)*input_channels.z + Z - input_channels.x - input_channels.y, (X))); - write_imagef(output, (int2)((Y)*output_shape.w + Z, (X)), result2); + READ_IMAGE(input2, smp_none, (int2)((Y)*input_channels.z + Z - input_channels.x - input_channels.y, (X))); + WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result2); } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/slice.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/slice.cl index 5fc704ae4c..72a20cd293 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/slice.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/slice.cl @@ -1,6 +1,6 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable #define INT2 int2 #define INT4 int4 -#define FLT4 float4 __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; __kernel void slice(__read_only image2d_t input, __write_only image2d_t output, INT4 input_shape, INT4 out_shape, INT4 begin, INT2 sharedNoUpdiv) { @@ -12,46 +12,43 @@ __kernel void slice(__read_only image2d_t input, __write_only image2d_t output, FLT4 result; if (sharedNoUpdiv.x % 4 == 0) { for (int i = 0; i < out_shape.w; i++) { - result = read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (i + begin.w), (X + begin.y))); - write_imagef(output, (INT2)((Y)*out_shape.w + i, (X)), result); + result = READ_IMAGE(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (i + begin.w), (X + begin.y))); + WRITE_IMAGE(output, (INT2)((Y)*out_shape.w + i, (X)), result); } } else { int begin_postion = sharedNoUpdiv.y % 4; - FLT4 first = read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + begin.w, (X + begin.y))); + FLT4 first = READ_IMAGE(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + begin.w, (X + begin.y))); if (begin_postion == 1) { for (int i = 1; i <= out_shape.w; i++) { - FLT4 second = - read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y))); + FLT4 second = READ_IMAGE(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y))); result.x = first.y; result.y = first.z; result.z = first.w; result.w = second.x; - write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result); + WRITE_IMAGE(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result); first.y = second.y; first.z = second.z; first.w = second.w; } } else if (begin_postion == 2) { for (int i = 1; i <= out_shape.w; i++) { - FLT4 second = - read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y))); + FLT4 second = READ_IMAGE(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y))); result.x = first.z; result.y = first.w; result.z = second.x; result.w = second.y; - write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result); + WRITE_IMAGE(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result); first.z = second.z; first.w = second.w; } } else { for (int i = 1; i <= out_shape.w; i++) { - FLT4 second = - read_imagef(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y))); + FLT4 second = READ_IMAGE(input, smp_none, (INT2)((Y + begin.z) * input_shape.w + (begin.w + i), (X + begin.y))); result.x = first.w; result.y = second.x; result.z = second.y; result.w = second.z; - write_imagef(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result); + WRITE_IMAGE(output, (INT2)((Y)*out_shape.w + i - 1, (X)), result); first.w = second.w; } } @@ -64,18 +61,18 @@ __kernel void slice(__read_only image2d_t input, __write_only image2d_t output, result_fill0.y = 0; result_fill0.z = 0; result_fill0.w = 0; - write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0); + WRITE_IMAGE(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0); } else if (size == 2) { result_fill0.x = result.x; result_fill0.y = result.y; result_fill0.z = 0; result_fill0.w = 0; - write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0); + WRITE_IMAGE(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0); } else if (size == 3) { result_fill0.x = result.x; result_fill0.y = result.y; result_fill0.z = result.z; result_fill0.w = 0; - write_imagef(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0); + WRITE_IMAGE(output, (INT2)((Y)*out_shape.w + out_shape.w - 1, (X)), result_fill0); } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc index c0036a160f..ee8eba8ad3 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc @@ -38,11 +38,12 @@ int BatchNormOpenCLKernel::GetImageSize(size_t idx, std::vector *img_siz im_dst_y = out_tensors_[0]->Height() * CO4; im_dst_x = out_tensors_[0]->Width(); } -#ifdef ENABLE_FP16 - size_t img_dtype = CL_HALF_FLOAT; -#else size_t img_dtype = CL_FLOAT; -#endif + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto enable_fp16_ = ocl_runtime->GetFp16Enable(); + if (enable_fp16_) { + img_dtype = CL_HALF_FLOAT; + } img_size->clear(); std::vector vec{im_dst_x, im_dst_y, img_dtype}; *img_size = vec; @@ -148,4 +149,5 @@ kernel::LiteKernel *OpenCLBatchnormKernelCreator(const std::vector *img_size) im_dst_y = out_tensors_[0]->Height() * CO4; im_dst_x = out_tensors_[0]->Width(); } -#ifdef ENABLE_FP16 - size_t img_dtype = CL_HALF_FLOAT; -#else size_t img_dtype = CL_FLOAT; -#endif + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto enable_fp16_ = ocl_runtime->GetFp16Enable(); + if (enable_fp16_) { + img_dtype = CL_HALF_FLOAT; + } img_size->clear(); std::vector vec{im_dst_x, im_dst_y, img_dtype}; *img_size = vec; @@ -225,4 +226,5 @@ kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector *img_size) { im_dst_y = out_tensors_[0]->Height() * CO4; im_dst_x = out_tensors_[0]->Width(); } -#ifdef ENABLE_FP16 - size_t img_dtype = CL_HALF_FLOAT; -#else size_t img_dtype = CL_FLOAT; -#endif + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto enable_fp16_ = ocl_runtime->GetFp16Enable(); + if (enable_fp16_) { + img_dtype = CL_HALF_FLOAT; + } img_size->clear(); std::vector vec{im_dst_x, im_dst_y, img_dtype}; *img_size = vec; @@ -143,4 +144,6 @@ kernel::LiteKernel *OpenCLSliceKernelCreator(const std::vector @@ -35,30 +39,153 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou ASSERT_LE(abs, err_bound); } } +TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) { + MS_LOG(INFO) << "begin test"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->SetFp16Enable(true); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + MS_LOG(INFO) << "Read tensors from .bin"; + std::vector input_shape = {1, 256, 256, 48}; + std::vector output_shape = {1, 256, 256, 48}; + auto data_type = kNumberTypeFloat32; + auto tensor_type = schema::NodeType_ValueNode; + + // get the input from .bin + size_t input_size, output_size; + std::string input_path = "./test_data/batchnorm_in_datafp16.bin"; + std::string mean_path = "./test_data/batchnorm_meanfp16.bin"; + std::string var_path = "./test_data/batchnorm_varfp16.bin"; + std::string offset_path = "./test_data/batchnorm_offsetfp16.bin"; + std::string scale_path = "./test_data/batchnorm_scalefp16.bin"; + std::string output_path = "./test_data/batchnorm_out_datafp16.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); + size_t mean_size, var_size, scale_size, offset_size; + auto mean_data = reinterpret_cast(mindspore::lite::ReadFile(mean_path.c_str(), &mean_size)); + auto var_data = reinterpret_cast(mindspore::lite::ReadFile(var_path.c_str(), &var_size)); + auto scale_data = reinterpret_cast(mindspore::lite::ReadFile(scale_path.c_str(), &scale_size)); + auto offset_data = reinterpret_cast(mindspore::lite::ReadFile(offset_path.c_str(), &offset_size)); + + MS_LOG(INFO) << "construct tensors"; + lite::tensor::Tensor *tensor_data = + new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type); + lite::tensor::Tensor *tensor_mean = + new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type); + lite::tensor::Tensor *tensor_var = + new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type); + lite::tensor::Tensor *tensor_scale = + new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type); + lite::tensor::Tensor *tensor_offset = + new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type); + if (tensor_data == nullptr || tensor_mean == nullptr || tensor_var == nullptr || tensor_scale == nullptr || + tensor_offset == nullptr) { + MS_LOG(INFO) << "init tensor failed"; + return; + } + auto *output_tensor = + new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type); + if (output_tensor == nullptr) { + MS_LOG(INFO) << "init tensor failed"; + delete tensor_data; + delete tensor_mean; + delete tensor_var; + delete tensor_scale; + delete tensor_offset; + return; + } + std::vector inputs = {tensor_data, tensor_scale, tensor_offset, tensor_mean, tensor_var}; + std::vector outputs{output_tensor}; + + MS_LOG(INFO) << "initialize tensors"; + auto param = new (std::nothrow) BatchNormParameter(); + if (param == nullptr) { + MS_LOG(INFO) << "new BatchNormParameter failed"; + for (auto tensor : outputs) { + delete tensor; + } + return; + } + param->epsilon_ = pow(10, -5); + auto *batchnorm_kernel = + new (std::nothrow) kernel::BatchNormOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (batchnorm_kernel == nullptr) { + MS_LOG(INFO) << "new kernel::BatchNorm_kernel failed"; + for (auto tensor : outputs) { + delete tensor; + } + delete param; + return; + } + batchnorm_kernel->Init(); -TEST_F(TestBatchnormOpenCL, Batchnorminput_dim4) { + // to do allocate memory for inputs and outputs + for (auto &input_tensor : inputs) { + input_tensor->MallocData(allocator); + } + + MS_LOG(INFO) << "initialize sub_graph"; + std::vector kernels{batchnorm_kernel}; + auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + MS_LOG(INFO) << "new kernel::SubGraphOpenCLKernel failed"; + for (auto tensor : outputs) { + delete tensor; + } + delete param; + delete batchnorm_kernel; + return; + } + sub_graph->Init(); + MS_LOG(INFO) << "init tensors"; + memcpy(inputs[0]->Data(), input_data, input_size); + memcpy(inputs[1]->Data(), scale_data, scale_size); + memcpy(inputs[2]->Data(), offset_data, offset_size); + memcpy(inputs[3]->Data(), mean_data, mean_size); + memcpy(inputs[4]->Data(), var_data, var_size); + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + + auto *output_data_gpu = reinterpret_cast(output_tensor->Data()); + CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + delete batchnorm_kernel; + delete sub_graph; +} +TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) { MS_LOG(INFO) << "begin test"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); MS_LOG(INFO) << "Read tensors from .bin"; - std::vector input_shape = {1, 256, 256, 16}; - std::vector output_shape = {1, 256, 256, 16}; + std::vector input_shape = {1, 256, 256, 47}; + std::vector output_shape = {1, 256, 256, 47}; auto data_type = kNumberTypeFloat32; auto tensor_type = schema::NodeType_ValueNode; // get the input from .bin size_t input_size, output_size; - std::string input_path = "./test_data/in_data.bin"; - std::string mean_path = "./test_data/mean.bin"; - std::string var_path = "./test_data/var.bin"; - std::string output_path = "./test_data/out_data.bin"; + std::string input_path = "./test_data/batchnorm_in_datafp32.bin"; + std::string mean_path = "./test_data/batchnorm_meanfp32.bin"; + std::string var_path = "./test_data/batchnorm_varfp32.bin"; + std::string offset_path = "./test_data/batchnorm_offsetfp32.bin"; + std::string scale_path = "./test_data/batchnorm_scalefp32.bin"; + std::string output_path = "./test_data/batchnorm_out_datafp32.bin"; auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); - size_t mean_size, var_size; + size_t mean_size, var_size, scale_size, offset_size; auto mean_data = reinterpret_cast(mindspore::lite::ReadFile(mean_path.c_str(), &mean_size)); auto var_data = reinterpret_cast(mindspore::lite::ReadFile(var_path.c_str(), &var_size)); + auto scale_data = reinterpret_cast(mindspore::lite::ReadFile(scale_path.c_str(), &scale_size)); + auto offset_data = reinterpret_cast(mindspore::lite::ReadFile(offset_path.c_str(), &offset_size)); MS_LOG(INFO) << "construct tensors"; lite::tensor::Tensor *tensor_data = @@ -131,14 +258,9 @@ TEST_F(TestBatchnormOpenCL, Batchnorminput_dim4) { } sub_graph->Init(); MS_LOG(INFO) << "init tensors"; - std::cout << "init tensors" << std::endl; memcpy(inputs[0]->Data(), input_data, input_size); - auto &temp = inputs[1]; - auto tensor_temp = reinterpret_cast(temp->Data()); - int UPDIV_tensor_scale = UP_DIV(tensor_scale->ElementsNum(), C4NUM) * 4; - for (int i = 0; i < UPDIV_tensor_scale; ++i) { - tensor_temp[i] = static_cast(1); - } + memcpy(inputs[1]->Data(), scale_data, scale_size); + memcpy(inputs[2]->Data(), offset_data, offset_size); memcpy(inputs[3]->Data(), mean_data, mean_size); memcpy(inputs[4]->Data(), var_data, var_size); std::cout << "==================output data================" << std::endl; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc index c59a375615..b843b339a7 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc @@ -21,9 +21,10 @@ #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" #include "mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h" -void ConcatComputeByCPU_2input_dim4_axis3(const float *input0, const float *input1, float *output, - std::vector input_shape0, std::vector input_shape1, - std::vector output_shape, const int axis) { +template +void ConcatComputeByCPU_2input_dim4_axis3(const T *input0, const T *input1, T *output, std::vector input_shape0, + std::vector input_shape1, std::vector output_shape, + const int axis) { int postion, index0 = 0, index1 = 0; for (int i = 0; i < output_shape[0]; i++) { for (int j = 0; j < output_shape[1]; j++) { @@ -43,10 +44,10 @@ void ConcatComputeByCPU_2input_dim4_axis3(const float *input0, const float *inpu } } } -void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *input2, float *output, - std::vector input_shape0, std::vector input_shape1, - std::vector input_shape2, std::vector output_shape, - const int axis) { +template +void ConcatComputeByCPU_3input_dim4_axis3(T *input0, T *input1, T *input2, T *output, std::vector input_shape0, + std::vector input_shape1, std::vector input_shape2, + std::vector output_shape, const int axis) { int postion, index0 = 0, index1 = 0, index2 = 0; for (int i = 0; i < output_shape[0]; i++) { for (int j = 0; j < output_shape[1]; j++) { @@ -82,9 +83,13 @@ void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *i } namespace mindspore { -class TestConcatOpenCL : public mindspore::CommonTest { +class TestConcatOpenCLfp32 : public mindspore::CommonTest { + public: + TestConcatOpenCLfp32() {} +}; +class TestConcatOpenCLfp16 : public mindspore::CommonTest { public: - TestConcatOpenCL() {} + TestConcatOpenCLfp16() {} }; template @@ -94,18 +99,138 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou ASSERT_LE(abs, err_bound); } } +TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) { + MS_LOG(INFO) << "begin test"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->SetFp16Enable(true); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + MS_LOG(INFO) << "init tensors"; + constexpr int INPUT_NUM = 3; + std::array, INPUT_NUM> input_shapes = { + std::vector{1, 16, 256, 80}, std::vector{1, 16, 256, 80}, std::vector{1, 16, 256, 80}}; + std::vector output_shape = {1, 16, 256, 240}; + auto data_type = kNumberTypeFloat16; + auto tensor_type = schema::NodeType_ValueNode; + std::vector inputs; + for (auto &shape : input_shapes) { + auto input_temp = new (std::nothrow) lite::tensor::Tensor(data_type, shape, schema::Format_NHWC4, tensor_type); + inputs.push_back(input_temp); + if (input_temp == nullptr) { + MS_LOG(INFO) << "new input_tensor failed"; + return; + } + } + auto *output_tensor = + new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type); + if (output_tensor == nullptr) { + MS_LOG(INFO) << "new output_tensor failed"; + for (auto tensor : inputs) { + delete tensor; + } + return; + } + std::vector outputs{output_tensor}; + MS_LOG(INFO) << "input_shapes size=: " << input_shapes.size(); + + MS_LOG(INFO) << "initialize tensors"; + auto param = new (std::nothrow) ConcatParameter(); + if (param == nullptr) { + MS_LOG(INFO) << "new ConcatParameter failed"; + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + return; + } + param->axis_ = 3; + auto *concat_kernel = + new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (concat_kernel == nullptr) { + MS_LOG(INFO) << "new kernel::ConcatOpenCLKernel failed"; + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + return; + } + concat_kernel->Init(); + // to do allocate memory for inputs and outputs + for (auto &input_tensor : inputs) { + input_tensor->MallocData(allocator); + } + + MS_LOG(INFO) << "initialize sub_graph"; + std::vector kernels{concat_kernel}; + auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + MS_LOG(INFO) << "new kernel::SubGraphOpenCLKernel failed"; + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + delete concat_kernel; + return; + } + sub_graph->Init(); + unsigned int seed = 123; + MS_LOG(INFO) << "initialize input data"; + for (auto &input_tensor : inputs) { + auto input_data = reinterpret_cast(input_tensor->Data()); + for (int i = 0; i < input_tensor->ElementsNum(); ++i) { + input_data[i] = static_cast(rand_r(&seed) % 10 + 1); + } + } + + // compute the result for CPU + auto *input_data0 = reinterpret_cast(inputs[0]->Data()); + auto *input_data1 = reinterpret_cast(inputs[1]->Data()); + std::vector output_data_cpu(output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3]); + if (inputs.size() == 2) { + ConcatComputeByCPU_2input_dim4_axis3(input_data0, input_data1, output_data_cpu.data(), input_shapes[0], + input_shapes[1], output_shape, param->axis_); + } + if (inputs.size() == 3) { + auto *input_data2 = reinterpret_cast(inputs[2]->Data()); + ConcatComputeByCPU_3input_dim4_axis3(input_data0, input_data1, input_data2, output_data_cpu.data(), input_shapes[0], + input_shapes[1], input_shapes[2], output_shape, param->axis_); + } + + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + auto *output_data_gpu = reinterpret_cast(output_tensor->Data()); + CompareOutputData1(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001); + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + delete concat_kernel; + delete sub_graph; +} -TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { +TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) { MS_LOG(INFO) << "begin test"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); MS_LOG(INFO) << "init tensors"; - constexpr int INPUT_NUM = 2; - std::array, INPUT_NUM> input_shapes = {std::vector{1, 16, 256, 80}, - std::vector{1, 16, 256, 80}}; - std::vector output_shape = {1, 16, 256, 160}; + constexpr int INPUT_NUM = 3; + std::array, INPUT_NUM> input_shapes = { + std::vector{1, 16, 256, 80}, std::vector{1, 16, 256, 80}, std::vector{1, 16, 256, 80}}; + std::vector output_shape = {1, 16, 256, 240}; auto data_type = kNumberTypeFloat32; auto tensor_type = schema::NodeType_ValueNode; std::vector inputs; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc index 2a57f68c0a..b74e50add2 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc @@ -23,9 +23,13 @@ #include "mindspore/lite/src/runtime/kernel/opencl/kernel/slice.h" namespace mindspore { -class TestSliceOpenCL : public mindspore::CommonTest { +class TestSliceOpenCLfp32 : public mindspore::CommonTest { public: - TestSliceOpenCL() {} + TestSliceOpenCLfp32() {} +}; +class TestSliceOpenCLfp16 : public mindspore::CommonTest { + public: + TestSliceOpenCLfp16() {} }; template @@ -36,7 +40,7 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou } } -TEST_F(TestSliceOpenCL, Sliceinput_dim4) { +TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) { MS_LOG(INFO) << "begin test"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); @@ -52,8 +56,8 @@ TEST_F(TestSliceOpenCL, Sliceinput_dim4) { // get the input from .bin size_t input_size, output_size; - std::string input_path = "./test_data/in_data.bin"; - std::string output_path = "./test_data/out_data.bin"; + std::string input_path = "./test_data/in_datafp32.bin"; + std::string output_path = "./test_data/out_datafp32.bin"; auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); @@ -86,7 +90,7 @@ TEST_F(TestSliceOpenCL, Sliceinput_dim4) { MS_LOG(INFO) << "new SliceParameter failed"; return; } - for (int i = 0; i < 4; i++) { + for (int i = 0; i < input_shape.size(); i++) { param->begin_[i] = begin[i]; param->size_[i] = size[i]; } @@ -145,4 +149,114 @@ TEST_F(TestSliceOpenCL, Sliceinput_dim4) { delete slice_kernel; delete sub_graph; } +TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) { + MS_LOG(INFO) << "begin test"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->SetFp16Enable(true); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + MS_LOG(INFO) << "Read tensors from .bin"; + std::vector input_shape = {1, 256, 256, 48}; + std::vector output_shape = {1, 255, 255, 15}; + std::vector begin = {0, 1, 1, 7}; + std::vector size = {1, 255, 255, 15}; + auto data_type = kNumberTypeFloat16; + auto tensor_type = schema::NodeType_ValueNode; + + // get the input from .bin + size_t input_size, output_size; + std::string input_path = "./test_data/in_data.bin"; + std::string output_path = "./test_data/out_data.bin"; + auto input_data = reinterpret_cast(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); + auto correct_data = reinterpret_cast(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); + + MS_LOG(INFO) << "construct tensors"; + lite::tensor::Tensor *tensor_data = + new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type); + if (tensor_data == nullptr) { + MS_LOG(INFO) << "init tensor failed"; + return; + } + auto *output_tensor = + new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type); + if (output_tensor == nullptr) { + delete tensor_data; + MS_LOG(INFO) << "init tensor failed"; + return; + } + std::vector inputs = {tensor_data}; + std::vector outputs = {output_tensor}; + + MS_LOG(INFO) << "setting SliceParameter"; + auto param = new (std::nothrow) SliceParameter(); + if (param == nullptr) { + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + MS_LOG(INFO) << "new SliceParameter failed"; + return; + } + for (int i = 0; i < 4; i++) { + param->begin_[i] = begin[i]; + param->size_[i] = size[i]; + } + + auto *slice_kernel = + new (std::nothrow) kernel::SliceOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (slice_kernel == nullptr) { + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + MS_LOG(INFO) << "new kernel::slice_kernel failed"; + return; + } + slice_kernel->Init(); + + // to do allocate memory for inputs and outputs + for (auto &input_tensor : inputs) { + input_tensor->MallocData(allocator); + } + + MS_LOG(INFO) << "initialize sub_graph"; + std::vector kernels{slice_kernel}; + auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete param; + delete slice_kernel; + MS_LOG(INFO) << "new kernel::SubGraphOpenCLKernel failed"; + return; + } + sub_graph->Init(); + + MS_LOG(INFO) << "init tensors"; + memcpy(inputs[0]->Data(), input_data, input_size); + + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + + auto *output_data_gpu = reinterpret_cast(output_tensor->Data()); + CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); + for (auto tensor : inputs) { + delete tensor; + } + for (auto tensor : outputs) { + delete tensor; + } + delete slice_kernel; + delete sub_graph; +} } // namespace mindspore