|
|
|
@ -98,29 +98,28 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
|
|
|
|
|
// activation
|
|
|
|
|
auto input_quant_params = primitive->GetInputQuantParams();
|
|
|
|
|
auto node_type = (schema::PrimitiveType)primitive->Type();
|
|
|
|
|
if (input_quant_params.empty()) {
|
|
|
|
|
MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < input_quant_params.size(); i++) {
|
|
|
|
|
if (i >= dst_node->inputIndex.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
|
|
|
|
|
<< " quant_params; but only " << dst_node->inputIndex.size() << " input";
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
auto activate_index = dst_node->inputIndex[i];
|
|
|
|
|
auto tensor_input = meta_graph->allTensors[activate_index].get();
|
|
|
|
|
if (tensor_input->quantParams.empty()) {
|
|
|
|
|
for (auto input_quant_param : input_quant_params[i]) {
|
|
|
|
|
std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
|
|
|
|
|
std::make_unique<schema::QuantParamT>(input_quant_param);
|
|
|
|
|
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
|
|
|
|
|
<< " zp: " << input_quant_param_ptr->zeroPoint;
|
|
|
|
|
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
|
|
|
|
|
if (!input_quant_params.empty()) {
|
|
|
|
|
for (size_t i = 0; i < input_quant_params.size(); i++) {
|
|
|
|
|
if (i >= dst_node->inputIndex.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
|
|
|
|
|
<< " quant_params; but only " << dst_node->inputIndex.size() << " input";
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
auto activate_index = dst_node->inputIndex[i];
|
|
|
|
|
auto tensor_input = meta_graph->allTensors[activate_index].get();
|
|
|
|
|
if (tensor_input->quantParams.empty()) {
|
|
|
|
|
for (auto input_quant_param : input_quant_params[i]) {
|
|
|
|
|
std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
|
|
|
|
|
std::make_unique<schema::QuantParamT>(input_quant_param);
|
|
|
|
|
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
|
|
|
|
|
<< " zp: " << input_quant_param_ptr->zeroPoint;
|
|
|
|
|
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// output
|
|
|
|
|
auto output_index = dst_node->outputIndex[0];
|
|
|
|
|
auto tensor_output = meta_graph->allTensors[output_index].get();
|
|
|
|
@ -171,7 +170,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
|
|
|
|
schema::CNodeT *return_node) {
|
|
|
|
|
schema::CNodeT *return_node) {
|
|
|
|
|
MS_ASSERT(nullptr != meta_graph);
|
|
|
|
|
MS_ASSERT(nullptr != return_node);
|
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
|
|
|
@ -210,9 +209,9 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee
|
|
|
|
|
if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem ||
|
|
|
|
|
primitive_c->Type() == schema::PrimitiveType_MakeTuple
|
|
|
|
|
#ifdef SUPPORT_TRAIN
|
|
|
|
|
|| primitive_c->Type() == schema::PrimitiveType_Depend
|
|
|
|
|
|| primitive_c->Type() == schema::PrimitiveType_Depend
|
|
|
|
|
#endif
|
|
|
|
|
) {
|
|
|
|
|
) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
RemoveIfMakeTuple(cnode);
|
|
|
|
@ -403,8 +402,7 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
|
|
|
|
|
if (value_track->isa<Int32Imm>()) {
|
|
|
|
|
shape.push_back((GetValue<int>(value_track)));
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is "
|
|
|
|
|
<< value_track->ToString() << ".";
|
|
|
|
|
MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " << value_track->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (shape.size()) {
|
|
|
|
@ -417,10 +415,10 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
|
|
|
|
|
node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size();
|
|
|
|
|
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
|
|
|
|
|
meta_graphT->allTensors.emplace_back(std::move(paramTensor));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
} else if (value->isa<Number>()) {
|
|
|
|
|
MS_LOG(INFO) << "Value is a number.";
|
|
|
|
|