|
|
|
@ -91,9 +91,9 @@ class IrExportBuilder {
|
|
|
|
|
void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto);
|
|
|
|
|
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
|
|
|
|
|
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
|
|
|
|
|
void SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs,
|
|
|
|
|
onnx::NodeProto *const node_proto);
|
|
|
|
|
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto);
|
|
|
|
|
void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
|
|
|
|
|
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto,
|
|
|
|
|
std::string suffix = "0");
|
|
|
|
|
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
|
|
|
|
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
|
|
|
|
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
|
|
|
|
@ -112,10 +112,10 @@ class IrExportBuilder {
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
onnx::ModelProto model_;
|
|
|
|
|
onnx::NodeProto *last_node_;
|
|
|
|
|
onnx::NodeProto *last_node_{nullptr};
|
|
|
|
|
std::list<FuncGraphPtr> todo_;
|
|
|
|
|
std::map<AnfNodePtr, size_t> node_index_map_;
|
|
|
|
|
size_t node_index_ = 0;
|
|
|
|
|
size_t node_index_{0};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
using IrExporterPtr = std::shared_ptr<IrExporter>;
|
|
|
|
@ -349,44 +349,34 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
|
|
|
|
|
onnx::NodeProto *const node_proto) {
|
|
|
|
|
onnx::NodeProto *const node_proto, std::string suffix) {
|
|
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
|
|
|
|
attr_proto->set_ref_attr_name("shape");
|
|
|
|
|
attr_proto->set_name("shape");
|
|
|
|
|
if (suffix.compare("0") != 0) {
|
|
|
|
|
attr_proto->set_name("shape" + suffix);
|
|
|
|
|
} else {
|
|
|
|
|
attr_proto->set_name("shape");
|
|
|
|
|
}
|
|
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
|
|
|
|
SetTensorProto(type, shape, tensor_proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs,
|
|
|
|
|
onnx::NodeProto *const node_proto) {
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
|
|
|
|
|
// Get shape of cnode
|
|
|
|
|
// 1. prim kPrimTupleGetItem need to get shape of input node according to the index
|
|
|
|
|
// 1. prim ArgMaxWithValue need to get shape from tuple element
|
|
|
|
|
// 2. some cnode doesn't has shape, such as LayerNorm
|
|
|
|
|
// 3. other cnodes have shape
|
|
|
|
|
if (node->IsApply(prim::kPrimTupleGetItem)) {
|
|
|
|
|
// Get index of tuple get_item
|
|
|
|
|
int index_pos = inputs.size() - 1;
|
|
|
|
|
if (!inputs[index_pos]->isa<ValueNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index is not ValueNode: " << index_pos;
|
|
|
|
|
}
|
|
|
|
|
auto value = inputs[index_pos]->cast<ValueNodePtr>()->value();
|
|
|
|
|
if (!value->isa<IntergerImm>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index type is not supported: " << value->type_name();
|
|
|
|
|
}
|
|
|
|
|
size_t index = GetValue<int>(value);
|
|
|
|
|
|
|
|
|
|
// Get type and shape of input node
|
|
|
|
|
auto tup_type = inputs[0]->Type();
|
|
|
|
|
if (!tup_type->isa<Tuple>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Input data of kPrimTupleGetItem cnode must be tuple: " << tup_type->type_name();
|
|
|
|
|
if (node->IsApply(prim::kPrimArgMaxWithValue)) {
|
|
|
|
|
auto type = node->Type();
|
|
|
|
|
auto shape = node->Shape();
|
|
|
|
|
if (!type->isa<Tuple>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name();
|
|
|
|
|
}
|
|
|
|
|
auto type = tup_type->cast<TuplePtr>()->elements()[index];
|
|
|
|
|
auto tup_shape = inputs[0]->Shape()->cast<abstract::TupleShapePtr>();
|
|
|
|
|
if (index >= tup_shape->shape().size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Index exceed upper limit: " << tup_shape->shape().size();
|
|
|
|
|
auto elements = type->cast<TuplePtr>()->elements();
|
|
|
|
|
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
|
|
|
|
|
for (size_t i = 0; i < elements.size(); i++) {
|
|
|
|
|
SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i));
|
|
|
|
|
}
|
|
|
|
|
auto shape = tup_shape->shape()[index];
|
|
|
|
|
SetShapeToNodeProto(type, shape, node_proto);
|
|
|
|
|
} else {
|
|
|
|
|
auto type = node->Type();
|
|
|
|
|
auto shape = node->Shape();
|
|
|
|
@ -422,7 +412,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g
|
|
|
|
|
std::string type_name = GetOpTypeName(op);
|
|
|
|
|
node_proto->set_op_type(type_name);
|
|
|
|
|
last_node_ = node_proto;
|
|
|
|
|
SetShapeToNodeProto(node, op_inputs, node_proto);
|
|
|
|
|
SetShapeToNodeProto(node, node_proto);
|
|
|
|
|
(void)std::for_each(input_names.begin(), input_names.end(),
|
|
|
|
|
[&node_proto](const string &name) { node_proto->add_input(name); });
|
|
|
|
|
|
|
|
|
|