From bf82e8d0037322db5582939e8d927b08a44395fe Mon Sep 17 00:00:00 2001 From: chenzupeng Date: Mon, 10 Aug 2020 21:18:04 +0800 Subject: [PATCH] fix bug in opencl subgraph output tensor --- mindspore/lite/src/lite_kernel.h | 9 +++++++-- mindspore/lite/src/runtime/opencl/opencl_executor.cc | 3 ++- mindspore/lite/src/scheduler.cc | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 147cd98407..4f715d4651 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -61,7 +61,9 @@ class LiteKernel { const std::vector &outputs, const lite::Context *ctx, const lite::Primitive *primitive) : opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), context_(ctx) { - opParameter->thread_num_ = ctx->thread_num_; + if (opParameter && ctx) { + opParameter->thread_num_ = ctx->thread_num_; + } this->in_kernel_.clear(); this->out_kernel_.clear(); } @@ -100,7 +102,10 @@ class LiteKernel { schema::PrimitiveType type() { return (schema::PrimitiveType)this->opParameter->type_; } - std::string type_str() { return schema::EnumNamePrimitiveType((schema::PrimitiveType)this->opParameter->type_); } + std::string type_str() { + return this->opParameter ? schema::EnumNamePrimitiveType((schema::PrimitiveType)this->opParameter->type_) + : "ERROR:undefined primitive!"; + } void SetInputs(const std::vector &inputs) { this->inputs_ = inputs; } diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.cc b/mindspore/lite/src/runtime/opencl/opencl_executor.cc index a63d77e1da..fec99cd4c0 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_executor.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.cc @@ -56,6 +56,7 @@ int OpenCLExecutor::Run(std::vector &inputs, std::vectorMallocData(allocator); } + output->set_allocator(allocator); } session::CallBackParam callbackParam; callbackParam.name_callback_param = kernel->Name(); @@ -91,7 +92,7 @@ int OpenCLExecutor::Run(std::vector &inputs, std::vectorGetFormat() != schema::Format_NHWC) { - TransformTensorLayout(outTensor, outTensor->GetFormat(), schema::Format_NHWC, false); + TransformTensorLayout(outTensor, outTensor->GetFormat(), schema::Format_NHWC, false); } } return RET_OK; diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index ce1e9f6d3c..426c656b06 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -157,7 +157,7 @@ kernel::LiteKernel *Scheduler::CreateSubKernel(const std::vectorGetInputs()) { + for (auto tensor : tail_kernel->GetOutputs()) { if (tensor->Data() == nullptr) { output_tensors.emplace_back(tensor); }