!6011 fix opencl context set for fp16

Merge pull request !6011 from wandongdong/master
pull/6011/MERGE
commit 3c4c53bdf3, committed by mindspore-ci-bot via Gitee

@@ -307,6 +307,7 @@ int LiteSession::Init(Context *context) {
 #if SUPPORT_GPU
   if (context_->device_type_ == DT_GPU) {
     auto opencl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+    opencl_runtime->SetFp16Enable(context_->float16_priority);
     opencl_runtime->Init();
     MS_LOG(INFO) << "Init OpenCL runtime.";
   }
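The ordering here is the point of the hunk: the OpenCL runtime fixes its precision when Init() compiles kernels, so the fp16 preference must be registered first. A minimal sketch of how a caller opts in, assuming the Context fields and session factory of this MindSpore Lite generation (names taken from the diff and the contemporary API, not verified against this exact revision):

// Sketch only.
lite::Context context;
context.device_type_ = lite::DT_GPU;   // route the graph to the OpenCL backend
context.float16_priority = true;       // request fp16 kernels where available
auto *session = session::LiteSession::CreateSession(&context);
// The patched Init() above forwards float16_priority to the OpenCL runtime
// before opencl_runtime->Init() builds any kernels.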

@@ -95,8 +95,9 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::Tensor *> &in_te
     out_tensors->emplace_back(new_tensor);
     KernelKey desc{kGPU, kNumberTypeFloat32, schema::PrimitiveType_ToFormat};
-    if (lite::opencl::OpenCLRuntime::GetInstance()->GetFp16Enable()) {
+    if (mem_type == OpenCLMemType::IMG && lite::opencl::OpenCLRuntime::GetInstance()->GetFp16Enable()) {
       desc.data_type = kNumberTypeFloat16;
+      new_tensor->set_data_type(kNumberTypeFloat16);
     }
     OpenCLToFormatParameter *parameter = new (std::nothrow) OpenCLToFormatParameter;
     MS_ASSERT(parameter);
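Two coupled fixes: the fp16 path is now restricted to image-memory conversions, and the intermediate tensor's declared type is kept in sync with what the ToFormat kernel actually writes. The rule, restated compactly (illustrative, not source code):

bool fp16_out = (mem_type == OpenCLMemType::IMG) &&
                lite::opencl::OpenCLRuntime::GetInstance()->GetFp16Enable();
desc.data_type = fp16_out ? kNumberTypeFloat16 : kNumberTypeFloat32;
new_tensor->set_data_type(desc.data_type);  // tensor metadata matches the kernel output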
@@ -112,11 +113,11 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::Tensor *> &in_te
     out_parameters->emplace_back(parameter);
     LiteKernel *in_convert_op = nullptr;
     if (mem_type == OpenCLMemType::IMG) {
-      in_convert_op =
-        lite::GetOpenCLKernel({in_tensors[i]}, {new_tensor}, reinterpret_cast<OpParameter *>(parameter), nullptr, desc);
+      in_convert_op = lite::GetOpenCLKernel({in_tensors[i]}, {new_tensor}, reinterpret_cast<OpParameter *>(parameter),
+                                            context_, desc);
     } else {
-      in_convert_op =
-        lite::GetOpenCLKernel({new_tensor}, {in_tensors[i]}, reinterpret_cast<OpParameter *>(parameter), nullptr, desc);
+      in_convert_op = lite::GetOpenCLKernel({new_tensor}, {in_tensors[i]}, reinterpret_cast<OpParameter *>(parameter),
+                                            context_, desc);
     }
     MS_ASSERT(in_convert_op);
     if (in_convert_op == nullptr) {
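The only change in this hunk is the fourth argument to lite::GetOpenCLKernel: the conversion kernels used to receive a null context and so never saw the session's fp16 setting or allocator. The factory's shape, as inferred from the call site (the signature is an assumption, not quoted from the source):

// LiteKernel *GetOpenCLKernel(const std::vector<lite::Tensor *> &in_tensors,
//                             const std::vector<lite::Tensor *> &out_tensors,
//                             OpParameter *parameter, const lite::Context *ctx,
//                             const kernel::KernelKey &desc);
// Passing context_ instead of nullptr lets each ToFormat kernel inherit the
// session context it runs under.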

@@ -34,8 +34,10 @@ class SubGraphOpenCLKernel : public SubGraphKernel {
   explicit SubGraphOpenCLKernel(const std::vector<lite::Tensor *> inputs, const std::vector<lite::Tensor *> outputs,
                                 const std::vector<kernel::LiteKernel *> inKernels,
                                 const std::vector<kernel::LiteKernel *> outKernels,
-                                const std::vector<kernel::LiteKernel *> nodes)
-      : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, nullptr, nullptr) {}
+                                const std::vector<kernel::LiteKernel *> nodes,
+                                const lite::Context *ctx = nullptr,
+                                const mindspore::lite::PrimitiveC *primitive = nullptr)
+      : SubGraphKernel(inputs, outputs, inKernels, outKernels, nodes, ctx, primitive) {}
   ~SubGraphOpenCLKernel() override;
   int Init() override;
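Defaulting the two new parameters to nullptr keeps every existing call site compiling while giving the scheduler a way to hand the subgraph its context. Both forms are valid after the change (sketch; the argument names are placeholders):

auto *legacy = new kernel::SubGraphOpenCLKernel(ins, outs, in_kernels, out_kernels, nodes);
// ctx and primitive default to nullptr, matching pre-patch behavior
auto *with_ctx = new kernel::SubGraphOpenCLKernel(ins, outs, in_kernels, out_kernels, nodes, context);
// the scheduler's updated path (last hunk below) uses this form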

@@ -178,6 +178,12 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) {
   std::vector<kernel::LiteKernel *> subgraph_kernels;
   size_t sub_cnt{0};
   for (auto temp_kernels : sub_kernels_list) {
+    std::vector<Tensor *> output_tensor = kernel::LiteKernelUtil::SubgraphOutputTensors(temp_kernels);
+    for (auto tensor : output_tensor) {
+      if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) {
+        tensor->set_data_type(kNumberTypeFloat32);
+      }
+    }
     kernel::KERNEL_ARCH arch = temp_kernels.front()->desc().arch;
     if (arch == kernel::KERNEL_ARCH::kCPU) {
       for (auto kernel : temp_kernels) {
@@ -185,12 +191,6 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) {
           tensor->set_allocator(context_->allocator.get());
         }
       }
-      std::vector<Tensor *> output_tensor = kernel::LiteKernelUtil::SubgraphOutputTensors(temp_kernels);
-      for (auto tensor : output_tensor) {
-        if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) {
-          tensor->set_data_type(kNumberTypeFloat32);
-        }
-      }
       std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels));
     } else {
       auto subgraph_kernel = CreateSubKernel(temp_kernels, arch);
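Read together with the previous hunk, this is a move rather than a deletion: the fp16-to-fp32 promotion of subgraph output tensors leaves the kCPU-only branch and now runs at the top of the loop for every subgraph. The consequence, as inferred from the diff:

// Before: only CPU subgraphs re-typed fp16 output tensors to fp32.
// After:  GPU (OpenCL) subgraphs do too, so a subgraph computing in fp16
//         still presents fp32 tensors at its boundary, which is what
//         adjacent CPU kernels and user-facing outputs expect.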
@@ -213,8 +213,8 @@ kernel::LiteKernel *Scheduler::CreateSubKernel(const std::vector<kernel::LiteKer
     std::vector<Tensor *> output_tensors = kernel::LiteKernelUtil::SubgraphOutputTensors(kernels);
     std::vector<kernel::LiteKernel *> input_kernels = kernel::LiteKernelUtil::SubgraphInputKernels(kernels);
     std::vector<kernel::LiteKernel *> output_kernels = kernel::LiteKernelUtil::SubgraphOutputKernels(kernels);
-    sub_kernel =
-      new kernel::SubGraphOpenCLKernel(input_tensors, output_tensors, input_kernels, output_kernels, kernels);
+    sub_kernel = new kernel::SubGraphOpenCLKernel(input_tensors, output_tensors, input_kernels, output_kernels, kernels,
+                                                  context_, nullptr);
     sub_kernel->Init();
   } else if (arch == kernel::KERNEL_ARCH::kNPU) {
     MS_LOG(ERROR) << "NPU kernel is not supported";
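Taken together, the hunks thread the session context, and with it the fp16 preference, through every layer that creates OpenCL kernels. The resulting flow in call order (summary of this patch, not additional code):

// 1. LiteSession::Init                   -> opencl_runtime->SetFp16Enable(context_->float16_priority)
// 2. Scheduler::CreateSubKernel          -> new SubGraphOpenCLKernel(..., kernels, context_, nullptr)
// 3. SubGraphOpenCLKernel::GenToFormatOp -> GetOpenCLKernel(..., context_, desc)
// 4. Scheduler::ConstructSubgraphs       -> fp16 subgraph outputs re-typed to fp32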
