|
|
|
@ -62,6 +62,7 @@ class LiteKernel {
|
|
|
|
|
const lite::Primitive *primitive)
|
|
|
|
|
: opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive),
|
|
|
|
|
context_(ctx) {
|
|
|
|
|
opParameter->thread_num_ = ctx->thread_num_;
|
|
|
|
|
this->in_kernel_.clear();
|
|
|
|
|
this->out_kernel_.clear();
|
|
|
|
|
}
|
|
|
|
@ -69,12 +70,13 @@ class LiteKernel {
|
|
|
|
|
virtual ~LiteKernel() { delete opParameter; }
|
|
|
|
|
|
|
|
|
|
virtual int Prepare() {
|
|
|
|
|
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
(const_cast<lite::Primitive *>(primitive_))->InferShape(inputs_, outputs_);
|
|
|
|
|
if (need_reinit) {
|
|
|
|
|
Init();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (need_reinit) {
|
|
|
|
|
Init();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &outputs = this->GetOutputs();
|
|
|
|
|
for (auto *output : outputs) {
|
|
|
|
|
MS_ASSERT(output != nullptr);
|
|
|
|
@ -126,6 +128,13 @@ class LiteKernel {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
bool InferShapeDone() {
|
|
|
|
|
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
KernelKey desc;
|
|
|
|
|
std::string name;
|
|
|
|
|
OpParameter *opParameter = nullptr;
|
|
|
|
|