From e6e0da0ebf9ecac031166241c77a4c3334ac89b6 Mon Sep 17 00:00:00 2001 From: wandongdong Date: Thu, 10 Sep 2020 07:15:13 -0700 Subject: [PATCH] fix opencl context set for fp16 --- mindspore/lite/src/lite_session.cc | 1 + .../kernel/opencl/subgraph_opencl_kernel.cc | 11 ++++++----- .../kernel/opencl/subgraph_opencl_kernel.h | 6 ++++-- mindspore/lite/src/scheduler.cc | 16 ++++++++-------- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index a2e8861f19..a04881baa7 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -307,6 +307,7 @@ int LiteSession::Init(Context *context) { #if SUPPORT_GPU if (context_->device_type_ == DT_GPU) { auto opencl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + opencl_runtime->SetFp16Enable(context_->float16_priority); opencl_runtime->Init(); MS_LOG(INFO) << "Init OpenCL runtime."; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc index 9e5bec0bb8..ec208024a3 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.cc @@ -95,8 +95,9 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector &in_te out_tensors->emplace_back(new_tensor); KernelKey desc{kGPU, kNumberTypeFloat32, schema::PrimitiveType_ToFormat}; - if (lite::opencl::OpenCLRuntime::GetInstance()->GetFp16Enable()) { + if (mem_type == OpenCLMemType::IMG && lite::opencl::OpenCLRuntime::GetInstance()->GetFp16Enable()) { desc.data_type = kNumberTypeFloat16; + new_tensor->set_data_type(kNumberTypeFloat16); } OpenCLToFormatParameter *parameter = new (std::nothrow) OpenCLToFormatParameter; MS_ASSERT(parameter); @@ -112,11 +113,11 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector &in_te out_parameters->emplace_back(parameter); LiteKernel *in_convert_op = nullptr; if (mem_type == OpenCLMemType::IMG) { - in_convert_op = - lite::GetOpenCLKernel({in_tensors[i]}, {new_tensor}, reinterpret_cast(parameter), nullptr, desc); + in_convert_op = lite::GetOpenCLKernel({in_tensors[i]}, {new_tensor}, reinterpret_cast(parameter), + context_, desc); } else { - in_convert_op = - lite::GetOpenCLKernel({new_tensor}, {in_tensors[i]}, reinterpret_cast(parameter), nullptr, desc); + in_convert_op = lite::GetOpenCLKernel({new_tensor}, {in_tensors[i]}, reinterpret_cast(parameter), + context_, desc); } MS_ASSERT(in_convert_op); if (in_convert_op == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h index a2f8a0a2b6..e7db0e6540 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h +++ b/mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h @@ -34,8 +34,10 @@ class SubGraphOpenCLKernel : public SubGraphKernel { explicit SubGraphOpenCLKernel(const std::vector inputs, const std::vector outputs, const std::vector inKernels, const std::vector outKernels, - const std::vector nodes) - : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, nullptr, nullptr) {} + const std::vector nodes, + const lite::Context *ctx = nullptr, + const mindspore::lite::PrimitiveC *primitive = nullptr) + : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx, primitive) {} ~SubGraphOpenCLKernel() override; int Init() override; diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index e8ee1cec96..030bdc061e 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -178,6 +178,12 @@ void Scheduler::ConstructSubgraphs(std::vector *kernels) { std::vector subgraph_kernels; size_t sub_cnt{0}; for (auto temp_kernels : sub_kernels_list) { + std::vector output_tensor = kernel::LiteKernelUtil::SubgraphOutputTensors(temp_kernels); + for (auto tensor : output_tensor) { + if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) { + tensor->set_data_type(kNumberTypeFloat32); + } + } kernel::KERNEL_ARCH arch = temp_kernels.front()->desc().arch; if (arch == kernel::KERNEL_ARCH::kCPU) { for (auto kernel : temp_kernels) { @@ -185,12 +191,6 @@ void Scheduler::ConstructSubgraphs(std::vector *kernels) { tensor->set_allocator(context_->allocator.get()); } } - std::vector output_tensor = kernel::LiteKernelUtil::SubgraphOutputTensors(temp_kernels); - for (auto tensor : output_tensor) { - if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) { - tensor->set_data_type(kNumberTypeFloat32); - } - } std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels)); } else { auto subgraph_kernel = CreateSubKernel(temp_kernels, arch); @@ -213,8 +213,8 @@ kernel::LiteKernel *Scheduler::CreateSubKernel(const std::vector output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels); std::vector input_kernels = kernel::LiteKernelUtil::SubgraphInputKernels(kernels); std::vector output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels); - sub_kernel = - new kernel::SubGraphOpenCLKernel(input_tensors, output_tensors, input_kernels, output_kernels, kernels); + sub_kernel = new kernel::SubGraphOpenCLKernel(input_tensors, output_tensors, input_kernels, output_kernels, kernels, + context_, nullptr); sub_kernel->Init(); } else if (arch == kernel::KERNEL_ARCH::kNPU) { MS_LOG(ERROR) << "NPU kernel is not supported";