diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl index e2e1f5b0bb..df5eac49d3 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/biasadd.cl @@ -1,31 +1,23 @@ -#pragma OPENCL EXTENSION cl_arm_printf : enable - -#define SLICES 4 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define C4NUM 4 #define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) -#define FLT4 float4 -#define READ_FLT4 read_imagef -#define WRITE_FLT4 write_imagef __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __kernel void BiasAdd(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, - __global float *alpha, const int dim) { - int C = input_shape.w; // channel size - + __read_only image2d_t alpha, const int dim) { + int C = input_shape.w; // channel size int Y = get_global_id(0); // height id int X = get_global_id(1); // weight id - for (int num = 0; num < UP_DIV(C, SLICES); ++num) { - FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC - FLT4 tmp; + for (int num = 0; num < UP_DIV(C, C4NUM); ++num) { + FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, C4NUM) + num, Y)); // NHWC4: H WC + FLT4 tmp = in_c4; int index = 0; if (dim == 2) { - index = X * 4; + index = X; } else { - index = num * 4; + index = num; } - tmp.x = in_c4.x + alpha[index]; - tmp.y = in_c4.y + alpha[index + 1]; - tmp.z = in_c4.z + alpha[index + 2]; - tmp.w = in_c4.w + alpha[index + 3]; - WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC + tmp += READ_IMAGE(alpha, smp_zero, (int2)(index, 0)); + WRITE_IMAGE(output, (int2)(X * UP_DIV(C, C4NUM) + num, Y), tmp); // NHWC4: H WC } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl index 40fbb4cfe3..65166c588c 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl @@ -1,11 +1,10 @@ -#pragma OPENCL EXTENSION cl_arm_printf : enable - +#pragma OPENCL EXTENSION cl_khr_fp16 : enable #define SLICES 4 #define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; __kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, - __global float *alpha, const int dim) { + __read_only image2d_t alpha, const int dim) { int C = input_shape.w; // channel size int Y = get_global_id(0); // height id @@ -14,16 +13,17 @@ __kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC FLT4 tmp; if (dim == 1) { - tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * (*alpha); - tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * (*alpha); - tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * (*alpha); - tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * (*alpha); + FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(0, 0)); + tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * weight.x; + tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * weight.x; + tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * weight.x; + tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * weight.x; } else { - int index = num * 4; - tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha[index]; - tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha[index + 1]; - tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha[index + 2]; - tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha[index + 3]; + FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(num, 0)); + tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * weight.x; + tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * weight.y; + tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * weight.z; + tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * weight.w; } WRITE_IMAGE(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc index 151412194b..cd6c880f66 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.cc @@ -39,19 +39,24 @@ void BiasAddOpenCLKernel::InitBuffer() { int C = in_tensors_[1]->shape()[0]; int div_ci = UP_DIV(C, C4NUM); auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator(); - BiasAdd_ = reinterpret_cast(allocator->Malloc(div_ci * C4NUM * sizeof(FLOAT_t))); - BiasAdd_ = reinterpret_cast(allocator->MapBuffer(BiasAdd_, CL_MAP_WRITE, nullptr, true)); - memset(BiasAdd_, 0x00, div_ci * C4NUM * sizeof(FLOAT_t)); - auto origin_weight = reinterpret_cast(in_tensors_[1]->Data()); - for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) { - BiasAdd_[i] = origin_weight[i]; + size_t img_dtype = CL_FLOAT; + if (enable_fp16_) { + img_dtype = CL_HALF_FLOAT; } + std::vector img_size{size_t(div_ci), 1, img_dtype}; + BiasAdd_ = allocator->Malloc(div_ci * C4NUM * fp_size, img_size); + BiasAdd_ = allocator->MapBuffer(BiasAdd_, CL_MAP_WRITE, nullptr, true); + memset(BiasAdd_, 0x00, div_ci * C4NUM * fp_size); + memcpy(BiasAdd_, in_tensors_[1]->Data(), C * fp_size); allocator->UnmapBuffer(BiasAdd_); } int BiasAddOpenCLKernel::Init() { in_size_ = in_tensors_[0]->shape().size(); out_size_ = out_tensors_[0]->shape().size(); + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + enable_fp16_ = ocl_runtime->GetFp16Enable(); + fp_size = enable_fp16_ ? sizeof(float) / 2 : sizeof(float); if (in_size_ != 4 && in_size_ != 2) { MS_LOG(ERROR) << "BiasAdd only support dim=4 or 2, but your dim=" << in_size_; return RET_ERROR; @@ -67,7 +72,6 @@ int BiasAddOpenCLKernel::Init() { std::string source = biasadd_source; std::string program_name = "BiasAdd"; std::string kernel_name = "BiasAdd"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->LoadSource(program_name, source); ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h index 56535afddb..b3c4ba80c2 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/biasadd.h @@ -42,9 +42,11 @@ class BiasAddOpenCLKernel : public OpenCLKernel { private: cl::Kernel kernel_; - FLOAT_t *BiasAdd_; + void *BiasAdd_; int in_size_; int out_size_; + size_t fp_size; + bool enable_fp16_{false}; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc index 609ebc0dcc..e79349bf5f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc @@ -187,7 +187,7 @@ int Conv2dTransposeOpenCLKernel::Run() { int arg_cnt = 0; ocl_runtime->SetKernelArg(kernel_, arg_cnt++, in_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, padWeight_, lite::opencl::MemType::BUF); - ocl_runtime->SetKernelArg(kernel_, arg_cnt++, bias_, lite::opencl::MemType::BUF); + ocl_runtime->SetKernelArg(kernel_, arg_cnt++, bias_); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, out_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, kernel_size); ocl_runtime->SetKernelArg(kernel_, arg_cnt++, stride); diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc index 0356963ad3..809cd5b1b2 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc @@ -164,7 +164,7 @@ int MatMulOpenCLKernel::Run() { int arg_count = 0; ocl_runtime->SetKernelArg(kernel_, arg_count++, in_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_count++, padWeight_, lite::opencl::MemType::BUF); - ocl_runtime->SetKernelArg(kernel_, arg_count++, bias_, lite::opencl::MemType::BUF); + ocl_runtime->SetKernelArg(kernel_, arg_count++, bias_); ocl_runtime->SetKernelArg(kernel_, arg_count++, out_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_count++, sizeCI); ocl_runtime->SetKernelArg(kernel_, arg_count++, sizeCO); diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc index 78f7df0baf..8b09d42773 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc @@ -36,15 +36,16 @@ namespace mindspore::kernel { void PReluOpenCLKernel::InitBuffer() { int C = in_tensors_[1]->shape()[0]; int div_ci = UP_DIV(C, C4NUM); - std::cout << div_ci << std::endl; auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator(); - PReluWeight_ = reinterpret_cast(allocator->Malloc(div_ci * C4NUM * sizeof(FLOAT_t))); - PReluWeight_ = reinterpret_cast(allocator->MapBuffer(PReluWeight_, CL_MAP_WRITE, nullptr, true)); - memset(PReluWeight_, 0x00, div_ci * C4NUM * sizeof(FLOAT_t)); - auto origin_weight = reinterpret_cast(in_tensors_[1]->Data()); - for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) { - PReluWeight_[i] = origin_weight[i]; + size_t img_dtype = CL_FLOAT; + if (enable_fp16_) { + img_dtype = CL_HALF_FLOAT; } + std::vector img_size{size_t(div_ci), 1, img_dtype}; + PReluWeight_ = allocator->Malloc(div_ci * C4NUM * fp_size, img_size); + PReluWeight_ = allocator->MapBuffer(PReluWeight_, CL_MAP_WRITE, nullptr, true); + memset(PReluWeight_, 0x00, div_ci * C4NUM * fp_size); + memcpy(PReluWeight_, in_tensors_[1]->Data(), C * fp_size); allocator->UnmapBuffer(PReluWeight_); } @@ -61,14 +62,14 @@ int PReluOpenCLKernel::Init() { << C_Weight << " and your input channel size is " << C; return RET_ERROR; } - if (C_Weight != 1) { - InitBuffer(); - } std::set build_options; std::string source = prelu_source; std::string program_name = "PRelu"; std::string kernel_name = "PRelu"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + enable_fp16_ = ocl_runtime->GetFp16Enable(); + fp_size = enable_fp16_ ? sizeof(float) / 2 : sizeof(float); + InitBuffer(); ocl_runtime->LoadSource(program_name, source); ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); in_ori_format_ = in_tensors_[0]->GetFormat(); @@ -92,11 +93,7 @@ int PReluOpenCLKernel::Run() { ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); - if (in_tensors_[1]->shape()[0] == 1) { - ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast(in_tensors_[1]->Data())); - } else { - ocl_runtime->SetKernelArg(kernel_, arg_idx++, PReluWeight_); - } + ocl_runtime->SetKernelArg(kernel_, arg_idx++, PReluWeight_); ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast(in_tensors_[1]->shape()[0])); std::vector local = {1, 1}; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h index cc2429250a..d9cdb8137d 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h @@ -40,7 +40,9 @@ class PReluOpenCLKernel : public OpenCLKernel { private: cl::Kernel kernel_; - FLOAT_t *PReluWeight_; + void *PReluWeight_; + size_t fp_size; + bool enable_fp16_{false}; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc index 57960e4094..d46b62e7b4 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc @@ -35,21 +35,22 @@ void LoadDataBiasAdd(void *dst, size_t dst_size, const std::string &file_path) { if (file_path.empty()) { memset(dst, 0x00, dst_size); } else { - auto src_data = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); + auto src_data = mindspore::lite::ReadFile(file_path.c_str(), &dst_size); memcpy(dst, src_data, dst_size); } } +template void CompareOutBiasAdd(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { - auto *output_data = reinterpret_cast(output_tensor->Data()); size_t output_size = output_tensor->ElementsNum(); - auto expect_data = reinterpret_cast(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); + auto output_data = reinterpret_cast(output_tensor->Data()); + auto expect_data = reinterpret_cast(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); constexpr float atol = 0.0002; for (int i = 0; i < output_tensor->ElementsNum(); ++i) { if (std::fabs(output_data[i] - expect_data[i]) > atol) { - printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); - printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); - printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]); + printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]); + printf("error at idx[%d] expect=%f output=%f\n", i, expect_data[i], output_data[i]); + printf("error at idx[%d] expect=%f output=%f\n\n\n", i, expect_data[i], output_data[i]); return; } } @@ -58,8 +59,10 @@ void CompareOutBiasAdd(lite::tensor::Tensor *output_tensor, const std::string &s printf("compare success!\n\n\n"); } -void printf_tensor_BiasAdd(mindspore::lite::tensor::Tensor *in_data, int size) { - auto input_data = reinterpret_cast(in_data->Data()); +template +void printf_tensor_BiasAdd(const std::string log, mindspore::lite::tensor::Tensor *in_data, int size) { + MS_LOG(INFO) << log; + auto input_data = reinterpret_cast(in_data->Data()); for (int i = 0; i < size; ++i) { printf("%f ", input_data[i]); } @@ -67,15 +70,6 @@ void printf_tensor_BiasAdd(mindspore::lite::tensor::Tensor *in_data, int size) { MS_LOG(INFO) << "Print tensor done"; } -void printf_float_BiasAdd(float *data, int num = 0) { - float *temp = data; - for (int i = 0; i < num; ++i) { - std::cout << *temp << " "; - temp++; - } - std::cout << std::endl; -} - TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { std::string in_file = "/data/local/tmp/in_data.bin"; std::string weight_file = "/data/local/tmp/weight_data.bin"; @@ -83,29 +77,34 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { MS_LOG(INFO) << "BiasAdd Begin test:"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); - auto allocator = ocl_runtime->GetAllocator(); - - MS_LOG(INFO) << "BiasAdd init tensors."; - + auto data_type = kNumberTypeFloat16; + ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); std::vector input_shape = {1, 9}; std::vector output_shape = {1, 9}; - auto data_type = kNumberTypeFloat32; + auto tensor_type = schema::NodeType_ValueNode; - auto *input_tensor = - new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NC, tensor_type); + schema::Format type; + int weight_shape = 0; + if (input_shape.size() == 4) { + weight_shape = input_shape[3]; + type = schema::Format_NHWC; + } else { + weight_shape = input_shape[1]; + type = schema::Format_NC; + } + auto *input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, type, tensor_type); if (input_tensor == nullptr) { MS_LOG(ERROR) << "new input tensor error!"; return; } - auto *output_tensor = - new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NC, tensor_type); + auto *output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, type, tensor_type); if (output_tensor == nullptr) { MS_LOG(ERROR) << "new output tensor error!"; delete input_tensor; return; } auto *weight_tensor = new (std::nothrow) - lite::tensor::Tensor(data_type, std::vector{input_shape[1]}, schema::Format_NHWC, tensor_type); + lite::tensor::Tensor(data_type, std::vector{weight_shape}, schema::Format_NHWC, tensor_type); if (weight_tensor == nullptr) { MS_LOG(ERROR) << "new weight tensor error!"; delete output_tensor; @@ -114,14 +113,18 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { } std::vector inputs{input_tensor, weight_tensor}; std::vector outputs{output_tensor}; + auto allocator = ocl_runtime->GetAllocator(); inputs[0]->MallocData(allocator); inputs[1]->MallocData(allocator); LoadDataBiasAdd(input_tensor->Data(), input_tensor->Size(), in_file); - MS_LOG(INFO) << "BiasAdd==================input data================"; - printf_tensor_BiasAdd(inputs[0], input_tensor->ElementsNum()); LoadDataBiasAdd(weight_tensor->Data(), weight_tensor->Size(), weight_file); - MS_LOG(INFO) << "BiasAdd==================weight data================"; - printf_tensor_BiasAdd(inputs[1], weight_tensor->ElementsNum()); + if (ocl_runtime->GetFp16Enable()) { + printf_tensor_BiasAdd("BiasAdd:FP16--input data", inputs[0], input_tensor->ElementsNum()); + printf_tensor_BiasAdd("BiasAdd:FP16--weight data", inputs[1], weight_tensor->ElementsNum()); + } else { + printf_tensor_BiasAdd("BiasAdd:FP32--input data", inputs[0], input_tensor->ElementsNum()); + printf_tensor_BiasAdd("BiasAdd:FP32--weight data", inputs[1], weight_tensor->ElementsNum()); + } auto *param = new (std::nothrow) OpParameter(); if (param == nullptr) { @@ -189,9 +192,13 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { return; } - MS_LOG(INFO) << "BiasAdd==================output data================"; - printf_tensor_BiasAdd(outputs[0], output_tensor->ElementsNum()); - CompareOutBiasAdd(output_tensor, standard_answer_file); + if (ocl_runtime->GetFp16Enable()) { + printf_tensor_BiasAdd("BiasAdd:FP16--output data", outputs[0], output_tensor->ElementsNum()); + CompareOutBiasAdd(output_tensor, standard_answer_file); + } else { + printf_tensor_BiasAdd("BiasAdd:FP32--output data", outputs[0], output_tensor->ElementsNum()); + CompareOutBiasAdd(output_tensor, standard_answer_file); + } delete input_tensor; delete weight_tensor; delete output_tensor; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc index 64fc23c43c..8ee30bd194 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc @@ -37,15 +37,16 @@ void LoadDataPRelu(void *dst, size_t dst_size, const std::string &file_path) { if (file_path.empty()) { memset(dst, 0x00, dst_size); } else { - auto src_data = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); + auto src_data = mindspore::lite::ReadFile(file_path.c_str(), &dst_size); memcpy(dst, src_data, dst_size); } } +template void CompareOutPRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { - auto *output_data = reinterpret_cast(output_tensor->Data()); + auto *output_data = reinterpret_cast(output_tensor->Data()); size_t output_size = output_tensor->Size(); - auto expect_data = reinterpret_cast(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); + auto expect_data = reinterpret_cast(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); constexpr float atol = 0.0002; for (int i = 0; i < output_tensor->ElementsNum(); ++i) { if (std::fabs(output_data[i] - expect_data[i]) > atol) { @@ -60,6 +61,17 @@ void CompareOutPRelu(lite::tensor::Tensor *output_tensor, const std::string &sta printf("compare success!\n\n\n"); } +template +void printf_tensor_Prelu(const std::string &log, mindspore::lite::tensor::Tensor *in_data, int size) { + MS_LOG(INFO) << log; + auto input_data = reinterpret_cast(in_data->Data()); + for (int i = 0; i < size; ++i) { + printf("%f ", input_data[i]); + } + printf("\n"); + MS_LOG(INFO) << "Print tensor done"; +} + TEST_F(TestPReluOpenCL, PReluFp32_dim4) { std::string in_file = "/data/local/tmp/in_data.bin"; std::string weight_file = "/data/local/tmp/weight_data.bin"; @@ -71,16 +83,14 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { MS_LOG(INFO) << "Init tensors."; std::vector input_shape = {1, 4, 3, 9}; - - auto data_type = kNumberTypeFloat32; + auto data_type = kNumberTypeFloat16; + ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); auto tensor_type = schema::NodeType_ValueNode; - auto input_tensor = - new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type); + auto input_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type); if (input_tensor == nullptr) { MS_LOG(ERROR) << "new input_tensor error!"; return; } - auto output_tensor = new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type); if (output_tensor == nullptr) { @@ -88,9 +98,8 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { delete input_tensor; return; } - - auto weight_tensor = - new (std::nothrow) lite::tensor::Tensor(data_type, std::vector{9}, schema::Format_NHWC, tensor_type); + auto weight_tensor = new (std::nothrow) + lite::tensor::Tensor(data_type, std::vector{input_shape[3]}, schema::Format_NHWC, tensor_type); if (weight_tensor == nullptr) { MS_LOG(ERROR) << "new weight_tensor error"; delete input_tensor; @@ -99,20 +108,20 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { } std::vector inputs{input_tensor, weight_tensor}; std::vector outputs{output_tensor}; - - // freamework to do!!! allocate memory by hand inputs[0]->MallocData(allocator); inputs[1]->MallocData(allocator); MS_LOG(INFO) << "initialize input data"; LoadDataPRelu(input_tensor->Data(), input_tensor->Size(), in_file); LoadDataPRelu(weight_tensor->Data(), weight_tensor->Size(), weight_file); - auto weight_data = reinterpret_cast(weight_tensor->Data()); - PrintData("Weight data", weight_data, inputs[1]->ElementsNum()); - auto *input_data = reinterpret_cast(inputs[0]->Data()); - PrintData("PRelu input data", input_data, inputs[0]->ElementsNum()); - std::cout << inputs[0]->ElementsNum() << std::endl; - std::cout << "--------------------------------------------" << std::endl; + if (ocl_runtime->GetFp16Enable()) { + printf_tensor_Prelu("PRELU:FP16--input data", input_tensor, inputs[0]->ElementsNum()); + printf_tensor_Prelu("PRELU:FP16--weight data", weight_tensor, weight_tensor->ElementsNum()); + } else { + printf_tensor_Prelu("PRELU:FP32--input data", input_tensor, inputs[0]->ElementsNum()); + printf_tensor_Prelu("PRELU:FP32--weight data", weight_tensor, inputs[1]->ElementsNum()); + } + auto param = new (std::nothrow) PReluParameter(); if (param == nullptr) { MS_LOG(ERROR) << "new PreluParameter error"; @@ -173,10 +182,13 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { return; } - MS_LOG(INFO) << "PRelu==================output data================"; - auto *output_data = reinterpret_cast(outputs[0]->Data()); - PrintData("output_data", output_data, outputs[0]->ElementsC4Num()); - CompareOutPRelu(output_tensor, standard_answer_file); + if (ocl_runtime->GetFp16Enable()) { + printf_tensor_Prelu("PRelu:FP16--output_data", output_tensor, outputs[0]->ElementsNum()); + CompareOutPRelu(output_tensor, standard_answer_file); + } else { + printf_tensor_Prelu("PRelu:FP32--output_data", output_tensor, outputs[0]->ElementsNum()); + CompareOutPRelu(output_tensor, standard_answer_file); + } delete input_tensor; delete output_tensor; delete weight_tensor;