From b98e2b314033284f18c4a9006f58680a657af115 Mon Sep 17 00:00:00 2001
From: chenzupeng
Date: Wed, 19 Aug 2020 21:22:54 +0800
Subject: [PATCH] fix memory release bug in testcase

---
 .../kernel/opencl/kernel/conv2d_transpose.cc  |  4 +-
 .../runtime/kernel/opencl/kernel/matmul.cc    |  5 +-
 .../runtime/kernel/opencl/kernel/reshape.cc   |  1 -
 .../runtime/kernel/opencl/kernel/softmax.cc   |  6 +-
 .../runtime/kernel/opencl/kernel/transpose.cc |  4 +-
 mindspore/lite/test/CMakeLists.txt            |  1 +
 .../kernel/opencl/conv2d_transpose_tests.cc   | 85 ++++++++++++++-----
 .../src/runtime/kernel/opencl/matmul_tests.cc | 59 +++++++++----
 .../runtime/kernel/opencl/to_format_tests.cc  | 58 ++++++++-----
 .../runtime/kernel/opencl/transpose_tests.cc  | 52 ++++++++----
 10 files changed, 192 insertions(+), 83 deletions(-)

diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
index 1cc4c37770..bb9276e671 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
@@ -192,9 +192,7 @@ kernel::LiteKernel *OpenCLConv2dTransposeKernelCreator(const std::vector
   auto ret = kernel->Init();
-  if (0 != ret) {
-    // MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str()
-    //               << ", type: " << lite::EnumNameOpT(opDef.attr_type());
+  if (ret != RET_OK) {
     delete kernel;
     return nullptr;
   }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
index aa5d48bf07..8b8dac11b7 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
@@ -40,7 +40,6 @@ int MatMulOpenCLKernel::Init() {
   ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
 #else
   std::set<std::string> build_options;
-  // build_options.emplace("-DPOOL_AVG");
 #ifdef ENABLE_FP16
   std::string source = matmul_source_fp16;
 #else
@@ -169,9 +168,7 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector
   auto ret = kernel->Init();
-  if (0 != ret) {
-    // MS_LOG(ERROR) << "Init kernel failed, name: " << opDef.name()->str()
-    //               << ", type: " << lite::EnumNameOpT(opDef.attr_type());
+  if (ret != RET_OK) {
     delete kernel;
     return nullptr;
   }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc
index bf4aa835b3..064caffaf6 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/reshape.cc
@@ -83,7 +83,6 @@ int ReshapeOpenCLKernel::Run() {
   int c = shapex[3];
   int c4 = UP_DIV(c, C4NUM);
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
-  // local size should less than MAX_GROUP_SIZE
   std::vector<size_t> local = {};
   std::vector<size_t> global = {(size_t)h, (size_t)w, (size_t)c4};
   cl_int4 size = {h, w, c4, 1};
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc
index a5de65607b..1fc02fc03f 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc
@@ -91,7 +91,9 @@ int SoftmaxOpenCLKernel::Init() {
   std::string source = softmax_source_fp32;
   runtime_ = lite::opencl::OpenCLRuntime::GetInstance();
   // framework not set this param yet! just use default.
-  parameter_->axis_ = 1;
+  if (parameter_->axis_ == -1) {
+    parameter_->axis_ = 1;
+  }
   if (in_tensors_[0]->shape().size() == 4 && parameter_->axis_ == 3) {
     // support 4d tensor
     onexone_flag_ = false;
@@ -180,7 +182,7 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector
   auto ret = kernel->Init();
-  if (0 != ret) {
+  if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init `Softmax` kernel failed!";
     delete kernel;
     return nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
index a2200d2c80..a42f39195a 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.cc
@@ -64,7 +64,6 @@ int TransposeOpenCLKernel::Init() {
     MS_LOG(ERROR) << "input H * W % 4 != 0 not support!";
     return RET_ERROR;
   }
-  // Transpose::InferShape just set output->SetFormat(input->GetFormat()); -^-!
   ori_format_ = schema::Format_NCHW;
   out_tensors_[0]->SetFormat(schema::Format_NCHW);
   if (!is_image_out_) {
@@ -100,7 +99,6 @@ int TransposeOpenCLKernel::Run() {
   int c4 = UP_DIV(c, 4);
   int hw4 = UP_DIV(h * w, 4);
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
-  // local size should less than MAX_GROUP_SIZE
   std::vector<size_t> local = {16, 16};
   std::vector<size_t> global = {UP_ROUND(hw4, local[0]), UP_ROUND(c4, local[1])};
@@ -126,7 +124,7 @@ kernel::LiteKernel *OpenCLTransposeKernelCreator(const std::vector
   auto ret = kernel->Init();
-  if (0 != ret) {
+  if (ret != RET_OK) {
     delete kernel;
     return nullptr;
   }
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index f191e13636..ca2a5b67a7 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -152,6 +152,7 @@ if (SUPPORT_GPU)
         ${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/kernel/caffe_prelu.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/kernel/prelu.cc
+        ${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc
         )
 endif()
 ### minddata lite
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
index b90b97795a..288144b06d 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
@@ -30,7 +30,6 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest {
 };
 
 TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
-  // setbuf(stdout, NULL);
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   ocl_runtime->Init();
   auto allocator = ocl_runtime->GetAllocator();
@@ -48,27 +47,67 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
   size_t input_size;
   std::string input_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_input.bin";
   auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
+  if (input_data == nullptr) {
+    MS_LOG(ERROR) << "input_data load error.";
+    return;
+  }
   size_t weight_size;
   std::string weight_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_weight.bin";
   auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
+  if (weight_data == nullptr) {
+    MS_LOG(ERROR) << "weight_data load error.";
+    return;
+  }
   size_t bias_size;
   std::string bias_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_bias.bin";
   auto bias_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(bias_path.c_str(), &bias_size));
+  if (bias_data == nullptr) {
+    MS_LOG(ERROR) << "bias_data load error.";
+    return;
+  }
+  std::vector<int> input_shape = {n, h, w, ci};
+  auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape);
+  auto tensor_x = tensor_x_ptr.get();
+  if (tensor_x == nullptr) {
+    MS_LOG(ERROR) << "tensor_x create error.";
+    return;
+  }
-  lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, ci});
-
-  lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, kh, kw, ci});
+  std::vector<int> weight_shape = {co, kh, kw, ci};
+  auto tensor_w_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), weight_shape);
+  auto tensor_w = tensor_w_ptr.get();
+  if (tensor_w == nullptr) {
+    MS_LOG(ERROR) << "tensor_w create error.";
+    return;
+  }
   tensor_w->SetData(weight_data);
-  lite::tensor::Tensor *tensor_bias = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co});
+  std::vector<int> bias_shape = {co};
+  auto tensor_bias_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), bias_shape);
+  auto tensor_bias = tensor_bias_ptr.get();
+  if (tensor_bias == nullptr) {
+    MS_LOG(ERROR) << "tensor_bias create error.";
+    return;
+  }
   tensor_bias->SetData(bias_data);
-  lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, oh, ow, co});
+  std::vector<int> out_shape = {1, oh, ow, co};
+  auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
+  auto tensor_out = tensor_out_ptr.get();
+  if (tensor_out == nullptr) {
+    MS_LOG(ERROR) << "tensor_out create error.";
+    return;
+  }
   std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w, tensor_bias};
   std::vector<lite::tensor::Tensor *> outputs{tensor_out};
-  ConvParameter *opParameter = new ConvParameter();
+  auto opParameter_ptr = std::make_unique<ConvParameter>();
+  auto opParameter = opParameter_ptr.get();
+  if (opParameter == nullptr) {
+    MS_LOG(ERROR) << "opParameter create error.";
+    return;
+  }
   opParameter->kernel_h_ = kh;
   opParameter->kernel_w_ = kw;
   opParameter->stride_h_ = 2;
@@ -77,23 +116,39 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
   opParameter->pad_w_ = pad;
   opParameter->input_channel_ = ci;
   opParameter->output_channel_ = co;
-  auto *arith_kernel =
-    new kernel::Conv2dTransposeOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
+  auto arith_kernel_ptr = std::make_unique<kernel::Conv2dTransposeOpenCLKernel>(
+    reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
+  auto arith_kernel = arith_kernel_ptr.get();
+  if (arith_kernel == nullptr) {
+    MS_LOG(ERROR) << "arith_kernel create error.";
+    return;
+  }
   arith_kernel->Init();
   inputs[0]->MallocData(allocator);
   std::vector<kernel::LiteKernel *> kernels{arith_kernel};
-  auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels);
+  std::vector<lite::tensor::Tensor *> inputs_g{tensor_x};
+  auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels);
+  auto pGraph = pGraph_ptr.get();
+  if (pGraph == nullptr) {
+    MS_LOG(ERROR) << "pGraph create error.";
+    return;
+  }
+
   pGraph->Init();
   memcpy(inputs[0]->Data(), input_data, input_size);
   pGraph->Run();
-  printf("==================output data=================\n");
+  std::cout << "==================output data=================" << std::endl;
   float *output_data = reinterpret_cast<float *>(tensor_out->Data());
   std::cout << std::endl;
   size_t output_size;
   std::string output_path = "./test_data/conv2d_transpose/conv2d_transpose_fp32_output.bin";
   auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
+  if (correct_data == nullptr) {
+    MS_LOG(ERROR) << "correct_data create error.";
+    return;
+  }
   int size_n = oh * ow * co;
   size_n = size_n > 100 ? 100 : size_n;
   for (int i = 0; i < size_n; i++) {
@@ -108,14 +163,6 @@ TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) {
   CompareOutputData(output_data, correct_data, oh * ow * co, 0.00001);
   MS_LOG(INFO) << "Test Conv2dTransposeFp32 passed";
-  for (auto tensor : inputs) {
-    delete tensor;
-  }
-  for (auto tensor : outputs) {
-    delete tensor;
-  }
-  delete arith_kernel;
-  delete pGraph;
   lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 }  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
index 82bbb90128..27ba92f018 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
@@ -36,25 +36,61 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
   int co = 1001;
   std::string input_path = "./test_data/matmul/matmul_fp32_input.bin";
   auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
-
+  if (input_data == nullptr) {
+    MS_LOG(ERROR) << "input_data load error.";
+    return;
+  }
   size_t weight_size;
   std::string weight_path = "./test_data/matmul/matmul_fp32_weight.bin";
   auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
-
-  lite::tensor::Tensor *tensor_x = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, ci});
+  if (weight_data == nullptr) {
+    MS_LOG(ERROR) << "weight_data load error.";
+    return;
+  }
+  std::vector<int> input_shape = {1, 1, 1, ci};
+  auto tensor_x_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape);
+  auto tensor_x = tensor_x_ptr.get();
+  if (tensor_x == nullptr) {
+    MS_LOG(ERROR) << "tensor_x create error.";
+    return;
+  }
   tensor_x->SetData(input_data);
-  lite::tensor::Tensor *tensor_w = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {co, 1, 1, ci});
+  std::vector<int> w_shape = {co, 1, 1, ci};
+  auto tensor_w_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), w_shape);
+  auto tensor_w = tensor_w_ptr.get();
+  if (tensor_w == nullptr) {
+    MS_LOG(ERROR) << "tensor_w create error.";
+    return;
+  }
   tensor_w->SetData(weight_data);
-  lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, 1, 1, co});
+  std::vector<int> out_shape = {1, 1, 1, co};
+  auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
+  auto tensor_out = tensor_out_ptr.get();
+  if (tensor_out == nullptr) {
+    MS_LOG(ERROR) << "tensor_out create error.";
+    return;
+  }
   std::vector<lite::tensor::Tensor *> inputs{tensor_x, tensor_w};
   std::vector<lite::tensor::Tensor *> outputs{tensor_out};
-  auto *arith_kernel = new kernel::MatMulOpenCLKernel(nullptr, inputs, outputs, false);
+  auto arith_kernel_ptr = std::make_unique<kernel::MatMulOpenCLKernel>(nullptr, inputs, outputs, false);
+  auto arith_kernel = arith_kernel_ptr.get();
+  if (arith_kernel == nullptr) {
+    MS_LOG(ERROR) << "arith_kernel create error.";
+    return;
+  }
   arith_kernel->Init();
   std::vector<kernel::LiteKernel *> kernels{arith_kernel};
-  auto *pGraph = new kernel::SubGraphOpenCLKernel({tensor_x}, outputs, kernels, kernels, kernels);
+
+  std::vector<lite::tensor::Tensor *> inputs_g{tensor_x};
+  auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels);
+  auto pGraph = pGraph_ptr.get();
+  if (pGraph == nullptr) {
+    MS_LOG(ERROR) << "pGraph create error.";
+    return;
+  }
   pGraph->Init();
   pGraph->Run();
@@ -71,19 +107,10 @@ TEST_F(TestMatMulOpenCL, MatMulFp32) {
   }
   std::cout << std::endl;
-  // compare
   CompareOutputData(output_data, correct_data, co, 0.00001);
   MS_LOG(INFO) << "TestMatMulFp32 passed";
-  for (auto tensor : inputs) {
-    delete tensor;
-  }
-  for (auto tensor : outputs) {
-    delete tensor;
-  }
-  delete arith_kernel;
-  delete pGraph;
   lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 }  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc
index e8566508d0..bbe9d94fa6 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc
@@ -20,7 +20,7 @@
 #include "mindspore/lite/src/common/file_utils.h"
 #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
 #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
-#include "mindspore/lite/src/runtime/kernel/opencl/kernel/transpose.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h"
 
 namespace mindspore {
 class TestToFormatOpenCL : public mindspore::CommonTest {
@@ -28,8 +28,8 @@ class TestToFormatOpenCL : public mindspore::CommonTest {
   TestToFormatOpenCL() {}
 };
 
-TEST_F(TestToFormatOpenCL, TransposeFp32) {
-  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+TEST_F(TestToFormatOpenCL, ToFormatNHWC2NCHW) {
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   ocl_runtime->Init();
   auto allocator = ocl_runtime->GetAllocator();
   int h = 64;
@@ -38,20 +38,44 @@ TEST_F(TestToFormatOpenCL, TransposeFp32) {
   size_t input_size;
   std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
   auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
-
-  lite::tensor::Tensor *tensor_x =
-    new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, c}, schema::Format_NHWC4);
-
-  lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, c, h, w});
+  if (input_data == nullptr) {
+    MS_LOG(ERROR) << "input_data load error.";
+    return;
+  }
+  std::vector<int> input_shape = {1, h, w, c};
+  auto tensor_x_ptr =
+    std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC4);
+  auto tensor_x = tensor_x_ptr.get();
+  if (tensor_x == nullptr) {
+    MS_LOG(ERROR) << "tensor_x create error.";
+    return;
+  }
+  std::vector<int> out_shape = {1, c, h, w};
+  auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
+  auto tensor_out = tensor_out_ptr.get();
+  if (tensor_out == nullptr) {
+    MS_LOG(ERROR) << "tensor_out create error.";
+    return;
+  }
   std::vector<lite::tensor::Tensor *> inputs{tensor_x};
   std::vector<lite::tensor::Tensor *> outputs{tensor_out};
-  auto *arith_kernel = new kernel::TransposeOpenCLKernel(nullptr, inputs, outputs);
+  auto arith_kernel_ptr = std::make_unique<kernel::ToFormatOpenCLKernel>(nullptr, inputs, outputs);
+  auto arith_kernel = arith_kernel_ptr.get();
+  if (arith_kernel == nullptr) {
+    MS_LOG(ERROR) << "arith_kernel create error.";
+    return;
+  }
   arith_kernel->Init();
   inputs[0]->MallocData(allocator);
   std::vector<kernel::LiteKernel *> kernels{arith_kernel};
-  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs, outputs, kernels, kernels, kernels);
+  auto pGraph = pGraph_ptr.get();
+  if (pGraph == nullptr) {
+    MS_LOG(ERROR) << "pGraph create error.";
+    return;
+  }
   pGraph->Init();
   memcpy(inputs[0]->Data(), input_data, input_size);
   pGraph->Run();
@@ -59,6 +83,10 @@ TEST_F(TestToFormatOpenCL, TransposeFp32) {
   size_t output_size;
   std::string output_path = "./test_data/transpose/transpose_fp32_output.bin";
   auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
+  if (correct_data == nullptr) {
+    MS_LOG(ERROR) << "correct_data create error.";
+    return;
+  }
   printf("==================output data=================\n");
   float *output_data = reinterpret_cast<float *>(tensor_out->Data());
   std::cout << std::endl;
@@ -74,15 +102,7 @@ TEST_F(TestToFormatOpenCL, TransposeFp32) {
   // compare
   CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
-  MS_LOG(INFO) << "TestMatMulFp32 passed";
-  for (auto tensor : inputs) {
-    delete tensor;
-  }
-  for (auto tensor : outputs) {
-    delete tensor;
-  }
-  delete arith_kernel;
-  delete pGraph;
+  MS_LOG(INFO) << "Test TransposeFp32 passed";
   lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 }  // namespace mindspore
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
index 53eefc7e76..5a5882da21 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
@@ -38,20 +38,44 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
   size_t input_size;
   std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
   auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
-
-  lite::tensor::Tensor *tensor_x =
-    new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, h, w, c}, schema::Format_NHWC4);
-
-  lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(TypeId(kNumberTypeFloat32), {1, c, h, w});
+  if (input_data == nullptr) {
+    MS_LOG(ERROR) << "input_data load error.";
+    return;
+  }
+  std::vector<int> input_shape = {1, h, w, c};
+  auto tensor_x_ptr =
+    std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC4);
+  auto tensor_x = tensor_x_ptr.get();
+  if (tensor_x == nullptr) {
+    MS_LOG(ERROR) << "tensor_x create error.";
+    return;
+  }
+  std::vector<int> out_shape = {1, c, h, w};
+  auto tensor_out_ptr = std::make_unique<lite::tensor::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
+  auto tensor_out = tensor_out_ptr.get();
+  if (tensor_out == nullptr) {
+    MS_LOG(ERROR) << "tensor_out create error.";
+    return;
+  }
   std::vector<lite::tensor::Tensor *> inputs{tensor_x};
   std::vector<lite::tensor::Tensor *> outputs{tensor_out};
-  auto *arith_kernel = new kernel::TransposeOpenCLKernel(nullptr, inputs, outputs);
+  auto arith_kernel_ptr = std::make_unique<kernel::TransposeOpenCLKernel>(nullptr, inputs, outputs);
+  auto arith_kernel = arith_kernel_ptr.get();
+  if (arith_kernel == nullptr) {
+    MS_LOG(ERROR) << "arith_kernel create error.";
+    return;
+  }
   arith_kernel->Init();
   inputs[0]->MallocData(allocator);
   std::vector<kernel::LiteKernel *> kernels{arith_kernel};
-  auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs, outputs, kernels, kernels, kernels);
+  auto pGraph = pGraph_ptr.get();
+  if (pGraph == nullptr) {
+    MS_LOG(ERROR) << "pGraph create error.";
+    return;
+  }
   pGraph->Init();
   memcpy(inputs[0]->Data(), input_data, input_size);
   pGraph->Run();
@@ -59,6 +83,10 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
   size_t output_size;
   std::string output_path = "./test_data/transpose/transpose_fp32_output.bin";
   auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
+  if (correct_data == nullptr) {
+    MS_LOG(ERROR) << "correct_data create error.";
+    return;
+  }
   printf("==================output data=================\n");
   float *output_data = reinterpret_cast<float *>(tensor_out->Data());
   std::cout << std::endl;
@@ -74,15 +102,7 @@ TEST_F(TestTransposeOpenCL, TransposeFp32) {
   // compare
   CompareOutputData(output_data, correct_data, h * w * c, 0.00001);
-  MS_LOG(INFO) << "TestMatMulFp32 passed";
-  for (auto tensor : inputs) {
-    delete tensor;
-  }
-  for (auto tensor : outputs) {
-    delete tensor;
-  }
-  delete arith_kernel;
-  delete pGraph;
+  MS_LOG(INFO) << "Test TransposeFp32 passed";
   lite::opencl::OpenCLRuntime::DeleteInstance();
 }
 }  // namespace mindspore
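
The recurring change in the test files above is an ownership pattern: every tensor, parameter, kernel, and subgraph gets an owning std::unique_ptr created with std::make_unique, while the std::vector arguments handed to the kernels keep only raw, non-owning pointers obtained with .get(). The standalone sketch below is a minimal illustration of that pattern under assumptions: FakeTensor and FakeKernel are invented stand-ins, not the MindSpore Lite types. It shows why the nullptr-check paths added in the patch can simply return without leaking, and why the trailing delete loops could be removed.

#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-ins for the tensor and kernel types used by the tests.
struct FakeTensor {
  std::vector<int> shape;
};
struct FakeKernel {
  explicit FakeKernel(std::vector<FakeTensor *> io) : io_(std::move(io)) {}
  std::vector<FakeTensor *> io_;  // non-owning observers, mirrors the inputs/outputs vectors
};

int main() {
  // Owning handles: released automatically on every return path.
  auto tensor_x_ptr = std::make_unique<FakeTensor>(FakeTensor{{1, 2, 2, 3}});
  auto tensor_out_ptr = std::make_unique<FakeTensor>(FakeTensor{{1, 3, 2, 2}});
  FakeTensor *tensor_x = tensor_x_ptr.get();    // raw views passed around, never deleted manually
  FakeTensor *tensor_out = tensor_out_ptr.get();
  if (tensor_x == nullptr || tensor_out == nullptr) {
    // Mirrors the checks in the patch; std::make_unique would actually throw
    // std::bad_alloc rather than return null, so this branch is defensive only.
    std::cerr << "tensor create error." << std::endl;
    return 1;  // early exit leaks nothing, unlike the raw-`new` version being replaced
  }
  auto kernel_ptr = std::make_unique<FakeKernel>(std::vector<FakeTensor *>{tensor_x, tensor_out});
  std::cout << "kernel wired to " << kernel_ptr->io_.size() << " tensors" << std::endl;
  return 0;  // unique_ptr destructors release everything here as well
}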