|
|
|
@ -339,8 +339,7 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
|
|
|
|
|
}
|
|
|
|
|
// when merge is removed, this if is removed automatically
|
|
|
|
|
if (kernel->Type() == schema::PrimitiveType_Merge) {
|
|
|
|
|
MS_ASSERT(kernel->in_kernels().size() == 2);
|
|
|
|
|
return (is_kernel_finish[kernel->in_kernels().at(0)] || is_kernel_finish[kernel->in_kernels().at(1)]);
|
|
|
|
|
return MergeOpIsReady(kernel, is_kernel_finish);
|
|
|
|
|
} else {
|
|
|
|
|
return std::all_of(kernel_inputs.begin(), kernel_inputs.end(),
|
|
|
|
|
[&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; });
|
|
|
|
@ -370,6 +369,28 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) {
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
bool Scheduler::MergeOpIsReady(const kernel::LiteKernel *kernel,
|
|
|
|
|
std::map<const kernel::LiteKernel *, bool> is_kernel_finish) {
|
|
|
|
|
std::map<const lite::Tensor *, bool> merge_in_tensors_map;
|
|
|
|
|
for (auto merge_in_tensor : kernel->in_tensors()) {
|
|
|
|
|
merge_in_tensors_map[merge_in_tensor] = false;
|
|
|
|
|
if (merge_in_tensor->category() == Tensor::CONST_TENSOR || merge_in_tensor->category() == Tensor::CONST_SCALAR) {
|
|
|
|
|
merge_in_tensors_map[merge_in_tensor] = true;
|
|
|
|
|
}
|
|
|
|
|
for (auto merge_in_kernel : kernel->in_kernels()) {
|
|
|
|
|
for (auto tensor : merge_in_kernel->out_tensors()) {
|
|
|
|
|
if (tensor == merge_in_tensor && is_kernel_finish[merge_in_kernel]) {
|
|
|
|
|
merge_in_tensors_map[merge_in_tensor] = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto kernel_in_tensors_num = kernel->in_tensors().size();
|
|
|
|
|
return std::all_of(kernel->in_tensors().begin(), kernel->in_tensors().begin() + kernel_in_tensors_num / 2,
|
|
|
|
|
[&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; }) ||
|
|
|
|
|
std::all_of(kernel->in_tensors().begin() + kernel_in_tensors_num / 2, kernel->in_tensors().end(),
|
|
|
|
|
[&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels,
|
|
|
|
|
kernel::SubGraphType type) {
|
|
|
|
|