diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 7ff6a9bdc1..449d483d92 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -26,9 +26,6 @@ #include "src/common/utils.h" #include "src/common/graph_util.h" #include "src/kernel_registry.h" -#if SUPPORT_GPU -#include "src/runtime/opencl/opencl_runtime.h" -#endif namespace mindspore { namespace lite { @@ -343,7 +340,7 @@ int LiteSession::Init(Context *context) { } #if SUPPORT_GPU if (context_->device_type_ == DT_GPU) { - auto opencl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto opencl_runtime = ocl_runtime_wrap_.GetInstance(); opencl_runtime->SetFp16Enable(context_->float16_priority); if (opencl_runtime->Init() != RET_OK) { context_->device_type_ = DT_CPU; @@ -394,11 +391,6 @@ LiteSession::~LiteSession() { for (auto *kernel : kernels_) { delete kernel; } -#if SUPPORT_GPU - if (context_->device_type_ == DT_GPU) { - lite::opencl::OpenCLRuntime::DeleteInstance(); - } -#endif delete this->context_; delete this->executor; this->executor = nullptr; diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h index 7046086623..b0f3ca6f52 100644 --- a/mindspore/lite/src/lite_session.h +++ b/mindspore/lite/src/lite_session.h @@ -30,6 +30,9 @@ #include "schema/model_generated.h" #include "src/executor.h" #include "src/tensor.h" +#if SUPPORT_GPU +#include "src/runtime/opencl/opencl_runtime.h" +#endif namespace mindspore { namespace lite { @@ -108,6 +111,9 @@ class LiteSession : public session::LiteSession { std::unordered_map output_tensor_map_; Executor *executor = nullptr; std::atomic is_running_ = false; +#if SUPPORT_GPU + opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_; +#endif }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc index c7b94cdbae..b600cd149a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc @@ -54,7 +54,7 @@ int GatherOpenCLKernel::Init() { auto indices_tensor = in_tensors_.at(1); int indices_num = indices_tensor->ElementsNum(); bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32; - auto allocator = lite::opencl::OpenCLRuntime::GetInstance()->GetAllocator(); + auto allocator = ocl_runtime_->GetAllocator(); if (!isIndicesInt32) { indices_data_ = reinterpret_cast(allocator->Malloc(sizeof(int32_t) * indices_num)); if (indices_data_ == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h index 7e10112f7c..a3f36c5c66 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h @@ -38,15 +38,10 @@ class OpenCLKernel : public LiteKernel { explicit OpenCLKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs) : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) { - ocl_runtime_ = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); } - ~OpenCLKernel() { - if (ocl_runtime_ != nullptr) { - lite::opencl::OpenCLRuntime::DeleteInstance(); - ocl_runtime_ = nullptr; - } - } + ~OpenCLKernel() {} virtual int Init() { return RET_ERROR; } virtual int Prepare() { return RET_ERROR; } @@ -69,7 +64,8 @@ class OpenCLKernel : public LiteKernel { schema::Format 
in_ori_format_{schema::Format::Format_NHWC}; schema::Format out_ori_format_{schema::Format::Format_NHWC4}; schema::Format op_format_{schema::Format::Format_NHWC4}; - lite::opencl::OpenCLRuntime *ocl_runtime_{nullptr}; + lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_; + lite::opencl::OpenCLRuntime *ocl_runtime_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc index 6d2e4ab367..ba6dcc579a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc @@ -17,7 +17,6 @@ #include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h" #include #include "src/runtime/opencl/opencl_executor.h" -#include "src/runtime/opencl/opencl_runtime.h" #include "src/runtime/kernel/opencl/utils.h" #include "include/errorcode.h" #include "src/common/utils.h" @@ -161,7 +160,6 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector &in_te } int SubGraphOpenCLKernel::Init() { - ocl_runtime_ = lite::opencl::OpenCLRuntime::GetInstance(); allocator_ = ocl_runtime_->GetAllocator(); MS_LOG(DEBUG) << "input num=" << in_tensors_.size() << ", output num=" << out_tensors_.size(); for (const auto tensor : in_tensors_) { @@ -308,10 +306,6 @@ int SubGraphOpenCLKernel::UnInit() { nodes_.clear(); in_convert_ops_.clear(); out_convert_ops_.clear(); - if (ocl_runtime_ != nullptr) { - lite::opencl::OpenCLRuntime::DeleteInstance(); - ocl_runtime_ = nullptr; - } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h index e6a083c692..6d233e2db4 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h +++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h @@ -36,7 +36,9 @@ class SubGraphOpenCLKernel : public SubGraphKernel { const std::vector outKernels, const std::vector nodes, const lite::InnerContext *ctx = nullptr, const mindspore::lite::PrimitiveC *primitive = nullptr) - : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx, primitive) {} + : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx, primitive) { + ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); + } ~SubGraphOpenCLKernel() override; int Init() override; @@ -64,6 +66,7 @@ class SubGraphOpenCLKernel : public SubGraphKernel { std::vector out_parameters_; std::vector in_convert_ops_; std::vector out_convert_ops_; + lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_; lite::opencl::OpenCLRuntime *ocl_runtime_{nullptr}; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.cc b/mindspore/lite/src/runtime/kernel/opencl/utils.cc index 1046afb83e..9936c2ac42 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/utils.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.cc @@ -15,9 +15,11 @@ */ #include "src/runtime/kernel/opencl/utils.h" +#include #include #include #include "src/kernel_registry.h" +#include "src/runtime/opencl/opencl_runtime.h" using mindspore::lite::KernelRegistrar; @@ -221,4 +223,64 @@ std::string CLErrorCode(cl_int error_code) { return "Unknown OpenCL error code"; } } + +void Write2File(void *mem, const std::string &file_name, int size) { + std::fstream os; + os.open(file_name, std::ios::out | std::ios::binary); + os.write(static_cast(mem), size); + os.close(); +} + +void PrintTensor(lite::Tensor *tensor, int 
num, const std::string &out_file) { + if (tensor->data_c() == nullptr) { + return; + } + auto runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); + runtime->SyncCommandQueue(); + + auto allocator = runtime->GetAllocator(); + auto origin_data = tensor->data_c(); + allocator->MapBuffer(origin_data, CL_MAP_READ, nullptr, true); + tensor->SetData(origin_data); + + auto Height = tensor->shape().size() == 4 ? tensor->Height() : 1; + auto Width = tensor->shape().size() == 4 ? tensor->Width() : 1; + auto SLICES = UP_DIV(tensor->Channel(), C4NUM); + auto alignment = runtime->GetImagePitchAlignment(); + auto dtype_size = tensor->data_type() == kNumberTypeFloat16 ? sizeof(cl_half4) : sizeof(cl_float4); + auto row_pitch = (Width * SLICES + alignment - 1) / alignment * alignment * dtype_size; + auto row_size = Width * SLICES * dtype_size; + std::cout << "tensor->GetFormat() =" << tensor->GetFormat() << "\n"; + std::cout << "Height =" << Height << "\n"; + std::cout << "Width =" << Width << "\n"; + std::cout << "SLICES =" << SLICES << "\n"; + std::cout << "image_alignment =" << alignment << "\n"; + std::cout << "dtype_size =" << dtype_size << "\n"; + std::cout << "row_pitch =" << row_pitch << "\n"; + std::cout << "row_size =" << row_size << "\n"; + std::cout << "tensor->Size() =" << tensor->Size() << "\n"; + std::vector<char> data(tensor->Size()); + for (int i = 0; i < Height; ++i) { + memcpy(static_cast<char *>(data.data()) + i * row_size, static_cast<char *>(origin_data) + i * row_pitch, row_size); + } + + std::cout << "shape=("; + for (auto x : tensor->shape()) { + printf("%3d,", x); + } + printf("): "); + + for (size_t i = 0; i < num && i < tensor->ElementsNum(); ++i) { + if (tensor->data_type() == kNumberTypeFloat16) + printf("%zu %6.3f | ", i, (reinterpret_cast<float16_t *>(data.data()))[i]); + else + printf("%zu %6.3f | ", i, (reinterpret_cast<float *>(data.data()))[i]); + } + printf("\n"); + + if (!out_file.empty()) { + Write2File(data.data(), out_file, tensor->Size()); + } + allocator->UnmapBuffer(origin_data); +} } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.h b/mindspore/lite/src/runtime/kernel/opencl/utils.h index f163cd4687..92cddb32ed 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/utils.h +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.h @@ -44,6 +44,10 @@ std::vector GetCommonLocalSize(const std::vector &global, int ma std::string CLErrorCode(cl_int error_code); +void Write2File(void *mem, const std::string &file_name, int size); + +void PrintTensor(lite::Tensor *tensor, int num = 10, const std::string &out_file = ""); + template void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane, int channel, const std::function &to_dtype) { int c4 = UP_DIV(channel, C4NUM); diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.h b/mindspore/lite/src/runtime/opencl/opencl_executor.h index b12a793238..9ada5741a8 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_executor.h +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.h @@ -27,11 +27,7 @@ namespace mindspore::lite::opencl { class OpenCLExecutor : Executor { public: - OpenCLExecutor() : Executor() { - auto ocl_runtime = OpenCLRuntime::GetInstance(); - allocator_ = ocl_runtime->GetAllocator(); - OpenCLRuntime::DeleteInstance(); - } + OpenCLExecutor() : Executor() { allocator_ = ocl_runtime.GetInstance()->GetAllocator(); } int Prepare(const std::vector &kernels); @@ -42,6 +38,7 @@ class OpenCLExecutor : Executor { protected: InnerContext *context = nullptr; OpenCLAllocator *allocator_; + 
OpenCLRuntimeWrapper ocl_runtime; }; } // namespace mindspore::lite::opencl #endif diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc index 0d87c7f12a..10b24867b8 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc @@ -393,11 +393,16 @@ int OpenCLRuntime::RunKernel(const cl::Kernel &kernel, const std::vector cl::Event event; cl_int ret = CL_SUCCESS; ret = command_queue->enqueueNDRangeKernel(kernel, cl::NullRange, global_range, local_range, nullptr, &event); - if (ret != CL_SUCCESS) { MS_LOG(ERROR) << "Kernel execute failed:" << CLErrorCode(ret); return RET_ERROR; } + static int cnt = 0; + const int flush_period = 10; + if (cnt % flush_period == 0) { + command_queue->flush(); + } + cnt++; MS_LOG(DEBUG) << "RunKernel success!"; #if MS_OPENCL_PROFILE event.wait(); diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.h b/mindspore/lite/src/runtime/opencl/opencl_runtime.h index c413e0406d..e04229771c 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_runtime.h +++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.h @@ -37,11 +37,10 @@ struct GpuInfo { int model_num = 0; float opencl_version = 0; }; - +class OpenCLRuntimeWrapper; class OpenCLRuntime { public: - static OpenCLRuntime *GetInstance(); - static void DeleteInstance(); + friend OpenCLRuntimeWrapper; ~OpenCLRuntime(); OpenCLRuntime(const OpenCLRuntime &) = delete; @@ -138,6 +137,8 @@ class OpenCLRuntime { int GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id); private: + static OpenCLRuntime *GetInstance(); + static void DeleteInstance(); OpenCLRuntime(); GpuInfo ParseGpuInfo(std::string device_name, std::string device_version); @@ -169,5 +170,16 @@ class OpenCLRuntime { void *handle_{nullptr}; }; +class OpenCLRuntimeWrapper { + public: + OpenCLRuntimeWrapper() { ocl_runtime_ = OpenCLRuntime::GetInstance(); } + ~OpenCLRuntimeWrapper() { OpenCLRuntime::DeleteInstance(); } + explicit OpenCLRuntimeWrapper(const OpenCLRuntime &) = delete; + OpenCLRuntimeWrapper &operator=(const OpenCLRuntime &) = delete; + OpenCLRuntime *GetInstance() { return ocl_runtime_; } + + private: + OpenCLRuntime *ocl_runtime_{nullptr}; +}; } // namespace mindspore::lite::opencl #endif // MINDSPORE_LITE_SRC_OPENCL_RUNTIME_H_ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc index 6778ed9f42..135a6f7611 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc @@ -82,7 +82,7 @@ TEST_F(TestActivationOpenCL, ReluFp_dim4) { std::string in_file = "/data/local/tmp/in_data.bin"; std::string out_file = "/data/local/tmp/relu.bin"; MS_LOG(INFO) << "Relu Begin test!"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); auto data_type = kNumberTypeFloat16; @@ -184,14 +184,13 @@ TEST_F(TestActivationOpenCL, ReluFp_dim4) { delete input_tensor; delete output_tensor; delete sub_graph; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestActivationOpenCL, Relu6Fp_dim4) { std::string in_file = "/data/local/tmp/in_data.bin"; std::string out_file = "/data/local/tmp/relu6.bin"; MS_LOG(INFO) << "Relu6 Begin test!"; - auto ocl_runtime = 
lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); auto data_type = kNumberTypeFloat16; ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); bool enable_fp16 = ocl_runtime->GetFp16Enable(); @@ -296,14 +295,13 @@ TEST_F(TestActivationOpenCL, Relu6Fp_dim4) { delete input_tensor; delete output_tensor; delete sub_graph; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestActivationOpenCL, SigmoidFp_dim4) { std::string in_file = "/data/local/tmp/in_data.bin"; std::string out_file = "/data/local/tmp/sigmoid.bin"; MS_LOG(INFO) << "Sigmoid Begin test!"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto data_type = kNumberTypeFloat32; ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); @@ -408,14 +406,13 @@ TEST_F(TestActivationOpenCL, SigmoidFp_dim4) { delete input_tensor; delete output_tensor; delete sub_graph; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestActivationOpenCL, LeakyReluFp_dim4) { std::string in_file = "/data/local/tmp/in_data.bin"; std::string out_file = "/data/local/tmp/leaky_relu.bin"; MS_LOG(INFO) << "Leaky relu Begin test!"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto data_type = kNumberTypeFloat16; // need modify ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); @@ -519,14 +516,13 @@ TEST_F(TestActivationOpenCL, LeakyReluFp_dim4) { delete param; delete input_tensor; delete output_tensor; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestActivationOpenCLTanh, TanhFp_dim4) { std::string in_file = "/data/local/tmp/test_data/in_tanhfp16.bin"; std::string out_file = "/data/local/tmp/test_data/out_tanhfp16.bin"; MS_LOG(INFO) << "Tanh Begin test!"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto data_type = kNumberTypeFloat16; ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); @@ -627,7 +623,6 @@ TEST_F(TestActivationOpenCLTanh, TanhFp_dim4) { printf_tensor("Tanh:FP32--output data---", outputs[0]); CompareRes(output_tensor, out_file); } - lite::opencl::OpenCLRuntime::DeleteInstance(); input_tensor->SetData(nullptr); delete input_tensor; output_tensor->SetData(nullptr); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc index 9c246a83bb..016f45dbf6 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc @@ -43,7 +43,7 @@ void CompareOutputData1(T *input_data1, T *output_data, T *correct_data, int siz TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->SetFp16Enable(true); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -125,7 +125,6 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(input_data1, output_data_gpu, correctOutput, 
output_tensor->ElementsNum(), 0.000001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; @@ -139,7 +138,7 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) { TEST_F(TestArithmeticSelfOpenCLCI, ArithmeticSelfRound) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); float input_data1[] = {0.75f, 0.06f, 0.74f, 0.30f, 0.9f, 0.59f, 0.03f, 0.37f, @@ -216,7 +215,6 @@ TEST_F(TestArithmeticSelfOpenCLCI, ArithmeticSelfRound) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(input_data1, output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc index 1408a50d8f..d052b3b2ed 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc @@ -68,7 +68,7 @@ static void LogData(void *data, const int size, const std::string prefix) { template static void TestCase(const std::vector &shape_a, const std::vector &shape_b) { bool is_log_data = false; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); auto allocator = ocl_runtime->GetAllocator(); bool is_bias_add = shape_b.empty(); @@ -212,7 +212,6 @@ static void TestCase(const std::vector &shape_a, const std::vector &sh for (auto tensor : outputs) { delete tensor; } - lite::opencl::OpenCLRuntime::DeleteInstance(); } class TestArithmeticOpenCL : public mindspore::CommonTest { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc index 4911a5fef2..532c7fc9e5 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc @@ -53,7 +53,7 @@ void InitAvgPoolingParam(PoolingParameter *param) { } void RunTestCaseAvgPooling(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? 
sizeof(float16_t) : sizeof(float); ocl_runtime->SetFp16Enable(enable_fp16); @@ -125,7 +125,6 @@ void RunTestCaseAvgPooling(const std::vector &shape, void *input_data, void } MS_LOG(INFO) << "Test AvgPool2d passed"; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestAvgPoolingOpenCL, AvgPoolingFp32) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc index 19fa65862a..59ce5d59ef 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc @@ -38,7 +38,7 @@ class TestBatchnormOpenCLCI : public mindspore::CommonTest { TEST_F(TestBatchnormOpenCLCI, Batchnormfp32CI) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -142,7 +142,6 @@ TEST_F(TestBatchnormOpenCLCI, Batchnormfp32CI) { auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; @@ -156,7 +155,7 @@ TEST_F(TestBatchnormOpenCLCI, Batchnormfp32CI) { TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) { MS_LOG(INFO) << "begin test"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->SetFp16Enable(true); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -262,7 +261,6 @@ TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) { auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.01); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; @@ -276,7 +274,7 @@ TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) { TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -381,7 +379,6 @@ TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) { auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc index 87e8c54cf6..d9cc475a76 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/biasadd_tests.cc @@ -75,7 +75,7 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { std::string weight_file = "/data/local/tmp/weight_data.bin"; std::string standard_answer_file = "/data/local/tmp/biasadd.bin"; MS_LOG(INFO) << "BiasAdd Begin test:"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto 
data_type = kNumberTypeFloat16; // need modify ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); @@ -200,6 +200,5 @@ TEST_F(TestBiasAddOpenCL, BiasAddFp32_dim4) { delete output_tensor; delete sub_graph; delete param; - lite::opencl::OpenCLRuntime::DeleteInstance(); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc index 2f9b7a1a6b..c5b3300590 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc @@ -38,7 +38,7 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou TEST_F(TestCastSelfOpenCL, Castfp32tofp16) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -113,7 +113,6 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; @@ -127,7 +126,7 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) { TEST_F(TestCastSelfOpenCL, Castfp16tofp32) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -201,7 +200,6 @@ TEST_F(TestCastSelfOpenCL, Castfp16tofp32) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc index 2de645340e..b2339186ab 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc @@ -47,7 +47,7 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou TEST_F(TestConcatOpenCLCI, ConcatFp32_2inputforCI) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -134,7 +134,6 @@ TEST_F(TestConcatOpenCLCI, ConcatFp32_2inputforCI) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.00001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; @@ -148,7 +147,7 @@ TEST_F(TestConcatOpenCLCI, ConcatFp32_2inputforCI) { TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis1) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->SetFp16Enable(true); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); 
@@ -264,7 +263,6 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis1) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; @@ -278,7 +276,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis1) { TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -385,7 +383,6 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.00001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc index 3f7eca1639..cc4b507882 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc @@ -32,7 +32,7 @@ class TestConv2dTransposeOpenCL : public mindspore::CommonTest { void RunTestCaseConv2dTranspose(const std::vector &shape, void *input_data, void *weight_data, void *bias_data, void *output_data, bool enable_fp16) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? 
sizeof(float16_t) : sizeof(float); ocl_runtime->SetFp16Enable(enable_fp16); @@ -134,7 +134,6 @@ void RunTestCaseConv2dTranspose(const std::vector &shape, void *input_data, for (auto t : outputs) { t->SetData(nullptr); } - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestConv2dTransposeOpenCL, Conv2dTransposeFp32) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/convolution_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/convolution_tests.cc index 3d35ec23e1..c0ac7d6fe2 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/convolution_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/convolution_tests.cc @@ -157,7 +157,7 @@ void TEST_MAIN(const std::string &attr, Format input_format, Format output_forma ¶m->dilation_h_, ¶m->dilation_w_); MS_LOG(DEBUG) << "initialize OpenCLRuntime and OpenCLAllocator"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); auto allocator = ocl_runtime->GetAllocator(); @@ -201,7 +201,6 @@ void TEST_MAIN(const std::string &attr, Format input_format, Format output_forma input.SetData(nullptr); output.SetData(nullptr); delete sub_graph; - lite::opencl::OpenCLRuntime::DeleteInstance(); } void TEST_MAIN(const std::string &attr, Format input_format, Format output_format, const TypeId data_type, diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc index 7903cefa26..f9b92731ef 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc @@ -33,7 +33,7 @@ class TestConvolutionDwOpenCL : public mindspore::CommonTest { template void DepthWiseTestMain(ConvParameter *conv_param, T2 *input_data, T1 *weight_data, T2 *gnd_data, schema::Format format, TypeId dtype = kNumberTypeFloat32, bool is_compare = true, T2 err_max = 1e-5) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); if (dtype == kNumberTypeFloat16) { @@ -167,7 +167,6 @@ void DepthWiseTestMain(ConvParameter *conv_param, T2 *input_data, T1 *weight_dat inputs[1]->SetData(nullptr); inputs[2]->SetData(nullptr); delete[] packed_input; - lite::opencl::OpenCLRuntime::DeleteInstance(); inputs[0]->SetData(nullptr); outputs[0]->SetData(nullptr); return; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc index 41b5f3f7bc..d649f3aa10 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/gather_tests.cc @@ -32,7 +32,7 @@ void test_main_gather(void *input_data, void *correct_data, const std::vector &indices, GatherParameter *param, TypeId data_type, schema::Format format) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc index 
579e2e25d1..151e20e39d 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc @@ -31,7 +31,7 @@ class TestMatMulOpenCL : public mindspore::CommonTest { void RunTestCaseMatMul(const std::vector &shape, void *input_data, void *weight_data, void *output_data, bool enable_fp16, int dims) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float); ocl_runtime->SetFp16Enable(enable_fp16); @@ -123,7 +123,6 @@ void RunTestCaseMatMul(const std::vector &shape, void *input_data, void *we t->SetData(nullptr); } MS_LOG(INFO) << "TestMatMul passed"; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestMatMulOpenCL, MatMul2DFp32) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc index 183600ee7c..c89fd1bdbd 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc @@ -53,7 +53,7 @@ void InitMaxPoolingParam(PoolingParameter *param) { } void RunTestCaseMaxPooling(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float); ocl_runtime->SetFp16Enable(enable_fp16); @@ -124,7 +124,6 @@ void RunTestCaseMaxPooling(const std::vector &shape, void *input_data, void } MS_LOG(INFO) << "Test MaxPool2d passed"; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestMaxPoolingOpenCL, MaxPoolingFp32) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc index d759e85452..97a7db6dba 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc @@ -77,7 +77,7 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { std::string weight_file = "/data/local/tmp/weight_data.bin"; std::string standard_answer_file = "/data/local/tmp/caffe_prelu.bin"; MS_LOG(INFO) << "-------------------->> Begin test PRelu!"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -194,6 +194,5 @@ TEST_F(TestPReluOpenCL, PReluFp32_dim4) { delete weight_tensor; delete param; delete sub_graph; - lite::opencl::OpenCLRuntime::DeleteInstance(); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc index eb2dd818f1..cedd0367e4 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reduce_tests.cc @@ -31,7 +31,7 @@ class TestReduceOpenCL : public mindspore::CommonTest { void RunTestCaseReduce(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16, int reduce_mode) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); 
ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float); ocl_runtime->SetFp16Enable(enable_fp16); @@ -103,7 +103,6 @@ void RunTestCaseReduce(const std::vector &shape, void *input_data, void *ou } MS_LOG(INFO) << "Test Reduce passed"; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestReduceOpenCL, ReduceMeanFp32) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc index d9f518522a..9204ec22fe 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/reshape_tests.cc @@ -31,7 +31,7 @@ class TestReshapeOpenCL : public mindspore::CommonTest { void RunTestCaseReshape(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16, bool is_output_2d) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float); ocl_runtime->SetFp16Enable(enable_fp16); @@ -99,7 +99,6 @@ void RunTestCaseReshape(const std::vector &shape, void *input_data, void *o } MS_LOG(INFO) << "Test Reshape passed"; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestReshapeOpenCL, ReshapeFp32) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc index 1875aec574..96ff554b09 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc @@ -68,7 +68,7 @@ static void LogData(void *data, const int size, const std::string prefix) { template static void TestCase(const std::vector &shape_a, const std::vector &shape_b) { bool is_log_data = false; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); auto allocator = ocl_runtime->GetAllocator(); bool is_broadcast = shape_b.empty(); @@ -232,7 +232,6 @@ static void TestCase(const std::vector &shape_a, const std::vector &sh for (auto tensor : outputs) { delete tensor; } - lite::opencl::OpenCLRuntime::DeleteInstance(); } class TestScaleOpenCL : public mindspore::CommonTest { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc index a9560bcd35..0fb44f3332 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc @@ -42,7 +42,7 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou TEST_F(TestSliceOpenCLfp32, Slicefp32CI) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -139,7 +139,6 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32CI) { auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; @@ -153,7 +152,7 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32CI) { TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) { MS_LOG(INFO) << " begin 
test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -248,7 +247,6 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) { auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; @@ -262,7 +260,7 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) { TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) { MS_LOG(INFO) << " begin test "; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->SetFp16Enable(true); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); @@ -358,7 +356,6 @@ TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) { sub_graph->Run(); auto *output_data_gpu = reinterpret_cast(output_tensor->data_c()); CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001); - lite::opencl::OpenCLRuntime::DeleteInstance(); for (auto tensor : inputs) { tensor->SetData(nullptr); delete tensor; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc index 6457ab2108..b82806ad20 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc @@ -30,7 +30,7 @@ class TestSoftmaxOpenCL : public mindspore::CommonTest { }; void RunTestCaseSoftmax(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? 
sizeof(float16_t) : sizeof(float); ocl_runtime->SetFp16Enable(enable_fp16); @@ -103,7 +103,6 @@ void RunTestCaseSoftmax(const std::vector &shape, void *input_data, void *o } MS_LOG(INFO) << "Test Softmax passed"; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestSoftmaxOpenCL, Softmax2DFp32) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc index 3b03679586..11280707de 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc @@ -29,7 +29,7 @@ class TestToFormatOpenCL : public mindspore::CommonTest { }; TEST_F(TestToFormatOpenCL, ToFormatNHWC2NCHW) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); auto allocator = ocl_runtime->GetAllocator(); int h = 64; @@ -102,6 +102,5 @@ TEST_F(TestToFormatOpenCL, ToFormatNHWC2NCHW) { // compare CompareOutputData(output_data, correct_data, h * w * c, 0.00001); MS_LOG(INFO) << "Test TransposeFp32 passed"; - lite::opencl::OpenCLRuntime::DeleteInstance(); } } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc index 04c22cb6dd..311c0b0ce6 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc @@ -30,7 +30,7 @@ class TestTransposeOpenCL : public mindspore::CommonTest { }; void RunTestTranspose(const std::vector &shape, void *input_data, void *output_data, bool enable_fp16) { - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); ocl_runtime->Init(); size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float); ocl_runtime->SetFp16Enable(enable_fp16); @@ -103,7 +103,6 @@ void RunTestTranspose(const std::vector &shape, void *input_data, void *out } MS_LOG(INFO) << "Test TransposeFp32 passed"; - lite::opencl::OpenCLRuntime::DeleteInstance(); } TEST_F(TestTransposeOpenCL, TransposeNHWC2NCHWFp32) {
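
Note on the usage pattern this patch converges on (a minimal sketch, not part of the diff): call sites now hold a lite::opencl::OpenCLRuntimeWrapper and cache the raw OpenCLRuntime * it hands out, instead of pairing OpenCLRuntime::GetInstance() with DeleteInstance() by hand. The wrapper acquires the singleton in its constructor and releases it in its destructor, so the cached pointer stays valid for as long as the wrapper lives. The class below is hypothetical and only illustrates that assumed pattern; the wrapper API, the header path, and GetAllocator() are the only parts taken from the diff.

#include "src/runtime/opencl/opencl_runtime.h"

namespace mindspore::kernel {

// Hypothetical user of the new wrapper, mirroring OpenCLKernel and SubGraphOpenCLKernel above.
class ExampleOpenCLUser {
 public:
  // The wrapper member is declared before the raw pointer it backs, so it is constructed
  // first and destroyed last; no explicit DeleteInstance() call is needed anywhere.
  ExampleOpenCLUser() { ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); }

  bool HasAllocator() {
    // Same access pattern as GatherOpenCLKernel::Init() in this patch.
    return ocl_runtime_->GetAllocator() != nullptr;
  }

 private:
  lite::opencl::OpenCLRuntimeWrapper ocl_runtime_wrap_;
  lite::opencl::OpenCLRuntime *ocl_runtime_{nullptr};
};

}  // namespace mindspore::kernel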