|
|
|
@ -41,7 +41,7 @@ std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
|
|
|
|
|
auto tensorT = tmp_meta_graph->allTensors.at(input_index).get();
|
|
|
|
|
auto tensor_shape = tensorT->dims;
|
|
|
|
|
auto lite_tensor =
|
|
|
|
|
new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
|
|
|
|
|
new (std::nothrow) Tensor(TypeId(tensorT->dataType), tensor_shape, tensorT->format, tensorT->nodeType);
|
|
|
|
|
if (lite_tensor == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "lite tensor is nullptr";
|
|
|
|
|
return input_tensors;
|
|
|
|
@ -106,7 +106,7 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
|
|
|
|
|
mindspore::lite::PrimitiveC *primitive) {
|
|
|
|
|
MS_ASSERT(nullptr != lite_primitive);
|
|
|
|
|
auto data_type = inputs.front()->data_type();
|
|
|
|
|
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()};
|
|
|
|
|
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType) primitive->Type()};
|
|
|
|
|
lite::Context context;
|
|
|
|
|
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
|
|
|
|
if (creator != nullptr) {
|
|
|
|
@ -115,6 +115,44 @@ kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tens
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Replaces a constant-folded node with Parameter node(s) created from its computed output tensors.
//
// func_graph      Graph the nodes belong to; its manager performs the replacement in the multi-output case.
// any_node        Consumer node whose input slot `replace_index` is rewired in the single-output case.
// input_node      The folded node; its full scoped name is used to name the new parameter(s).
// output_tensors  Tensors holding the folded node's results.
// replace_index   Input index on `any_node` to overwrite (used only when there is exactly one output tensor).
//
// Returns lite::RET_OK on success. Returns lite::RET_ERROR when an argument or the graph manager is null, when a
// parameter cannot be created, or when an output of a multi-output node is not consumed by exactly one
// TupleGetItem node.
//
// Note: an empty `output_tensors` takes the multi-output path; its loop does not execute, so lite::RET_OK is
// returned and the graph is left untouched.
lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &any_node, const AnfNodePtr &input_node,
                          std::vector<Tensor *> output_tensors, size_t replace_index) {
  // Checked at runtime instead of only asserted, so a null argument produces an error code rather than a
  // null dereference (assumption: MS_ASSERT may be compiled out in some build configurations).
  if (func_graph == nullptr || any_node == nullptr || input_node == nullptr) {
    MS_LOG(ERROR) << "ReplaceCNode received a nullptr argument";
    return lite::RET_ERROR;
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    MS_LOG(ERROR) << "func_graph manager is nullptr";
    return lite::RET_ERROR;
  }

  // Single output: the new parameter is wired directly into the consumer's input slot.
  if (output_tensors.size() == 1) {
    auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
    if (new_parameter == nullptr) {
      MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    new_parameter->set_name(input_node->fullname_with_scope());
    any_node->set_input(replace_index, new_parameter);
    return lite::RET_OK;
  }

  // Multiple outputs: every output index must be consumed by exactly one TupleGetItem node, and that
  // TupleGetItem node is what gets replaced by the new parameter.
  for (size_t k = 0; k < output_tensors.size(); k++) {
    auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, input_node, k);
    // Guard against the helper yielding no list before it is dereferenced below.
    if (used_node_list == nullptr) {
      MS_LOG(ERROR) << "get used node list failed, name: " << input_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    if (used_node_list->size() != 1) {
      MS_LOG(ERROR) << " output must tuple_getitem";
      return lite::RET_ERROR;
    }
    auto tuple_node = used_node_list->at(0).first;
    if (GetCNodeType(tuple_node) != schema::PrimitiveType_TupleGetItem) {
      MS_LOG(ERROR) << " multi out tensor must connect tuple-getitem: " << input_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    auto new_parameter = CreateNewParamter(func_graph, output_tensors.at(k));
    if (new_parameter == nullptr) {
      MS_LOG(ERROR) << "CreateNewParamter failed, name: " << input_node->fullname_with_scope();
      return lite::RET_ERROR;
    }
    // The output index is appended so each parameter derived from the same node gets a distinct name.
    new_parameter->set_name(input_node->fullname_with_scope() + "_const_" + std::to_string(k));
    manager->Replace(tuple_node, new_parameter);
  }
  return lite::RET_OK;
}
|
|
|
|
|
} // namespace
|
|
|
|
|
void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) {
|
|
|
|
|
if (input_tensor != nullptr) {
|
|
|
|
@ -140,64 +178,66 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
|
|
|
|
}
|
|
|
|
|
auto any_node = node->cast<CNodePtr>();
|
|
|
|
|
CheckIfCNodeIsNull(any_node);
|
|
|
|
|
bool changed = false;
|
|
|
|
|
for (size_t i = 1; i < any_node->inputs().size(); i++) {
|
|
|
|
|
auto input_node = any_node->input(i);
|
|
|
|
|
if (input_node->isa<CNode>() && CheckIsAllInputsParam(input_node)) {
|
|
|
|
|
auto input_cnode = input_node->cast<CNodePtr>();
|
|
|
|
|
auto input_tensors = GetCNodeInputTensors(input_cnode);
|
|
|
|
|
if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
|
|
|
|
|
FreeTensors(&input_tensors, nullptr);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
|
|
|
|
|
auto output_nums = GetOutputTensorNum(input_cnode);
|
|
|
|
|
std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
|
|
|
|
|
auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
|
|
|
|
|
if (lite_primitive == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "lite_primitive is nullptr";
|
|
|
|
|
FreeTensors(&input_tensors, &output_tensors);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// here, input_tensor's format need to be transposed nhwc according to fmkType,
|
|
|
|
|
// but for the time being, we only transpose the tensor with 0/1/2/3D.
|
|
|
|
|
// Others should be added in future.
|
|
|
|
|
for (size_t j = 0; j < input_tensors.size(); ++j) {
|
|
|
|
|
input_tensors[j]->SetFormat(schema::Format_NHWC);
|
|
|
|
|
if (input_tensors[j]->shape().size() == 4) {
|
|
|
|
|
MS_LOG(INFO) << "init input_tensor format to nhwc";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
lite_primitive->InferShape(input_tensors, output_tensors);
|
|
|
|
|
auto parameter = kernel::PopulateParameter(lite_primitive.get());
|
|
|
|
|
if (parameter == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
|
|
|
|
|
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
|
|
|
|
|
if (lite_kernel == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
|
|
|
|
FreeTensors(&input_tensors, &output_tensors);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto ret = lite_kernel->Run();
|
|
|
|
|
if (0 != ret) {
|
|
|
|
|
FreeTensors(&input_tensors, &output_tensors);
|
|
|
|
|
MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto new_parameter = CreateNewParamter(func_graph, output_tensors.front());
|
|
|
|
|
if (new_parameter == nullptr) {
|
|
|
|
|
FreeTensors(&input_tensors, &output_tensors);
|
|
|
|
|
MS_LOG(ERROR) << "CreateNewParamter failed, name: " << lite_kernel->name();
|
|
|
|
|
return nullptr;
|
|
|
|
|
if (!input_node->isa<CNode>() || !CheckIsAllInputsParam(input_node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto input_cnode = input_node->cast<CNodePtr>();
|
|
|
|
|
auto input_tensors = GetCNodeInputTensors(input_cnode);
|
|
|
|
|
if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) {
|
|
|
|
|
FreeTensors(&input_tensors, nullptr);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
changed = true;
|
|
|
|
|
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
|
|
|
|
|
auto output_nums = GetOutputTensorNum(input_cnode);
|
|
|
|
|
std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
|
|
|
|
|
auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
|
|
|
|
|
if (lite_primitive == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "lite_primitive is nullptr";
|
|
|
|
|
FreeTensors(&input_tensors, &output_tensors);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// here, input_tensor's format need to be transposed nhwc according to fmkType,
|
|
|
|
|
// but for the time being, we only transpose the tensor with 0/1/2/3D.
|
|
|
|
|
// Others should be added in future.
|
|
|
|
|
for (size_t j = 0; j < input_tensors.size(); ++j) {
|
|
|
|
|
input_tensors[j]->SetFormat(schema::Format_NHWC);
|
|
|
|
|
if (input_tensors[j]->shape().size() == 4) {
|
|
|
|
|
MS_LOG(INFO) << "init input_tensor format to nhwc";
|
|
|
|
|
}
|
|
|
|
|
new_parameter->set_name(input_node->fullname_with_scope());
|
|
|
|
|
any_node->set_input(i, new_parameter);
|
|
|
|
|
}
|
|
|
|
|
lite_primitive->InferShape(input_tensors, output_tensors);
|
|
|
|
|
auto parameter = kernel::PopulateParameter(lite_primitive.get());
|
|
|
|
|
if (parameter == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
|
|
|
|
|
<< schema::EnumNamePrimitiveType((schema::PrimitiveType) (lite_primitive->Type()));
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get());
|
|
|
|
|
if (lite_kernel == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
|
|
|
|
FreeTensors(&input_tensors, &output_tensors);
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto ret = lite_kernel->Run();
|
|
|
|
|
if (0 != ret) {
|
|
|
|
|
FreeTensors(&input_tensors, &output_tensors);
|
|
|
|
|
MS_LOG(ERROR) << "run kernel failed, name: " << lite_kernel->name();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// replace cnode by new param
|
|
|
|
|
if (ReplaceCNode(func_graph, any_node, input_node, output_tensors, i) != lite::RET_OK) {
|
|
|
|
|
FreeTensors(&input_tensors, &output_tensors);
|
|
|
|
|
delete (lite_kernel);
|
|
|
|
|
MS_LOG(ERROR) << "constant_folding replace cnode failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
FreeTensors(&input_tensors, &output_tensors);
|
|
|
|
|
delete (lite_kernel);
|
|
|
|
|
}
|
|
|
|
|
return any_node;
|
|
|
|
|
return changed ? any_node : nullptr;
|
|
|
|
|
}
|
|
|
|
|
} // namespace mindspore::opt
|
|
|
|
|