fix shape and add testcase

pull/2220/head
leopz 5 years ago
parent b9e59f9de4
commit a67132f6b7

@ -91,9 +91,9 @@ class IrExportBuilder {
void SetParamToTensorProto(const ParameterPtr &param, onnx::TensorProto *const tensor_proto);
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs,
onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto,
std::string suffix = "0");
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
@ -112,10 +112,10 @@ class IrExportBuilder {
private:
onnx::ModelProto model_;
onnx::NodeProto *last_node_;
onnx::NodeProto *last_node_{nullptr};
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_;
size_t node_index_ = 0;
size_t node_index_{0};
};
using IrExporterPtr = std::shared_ptr<IrExporter>;
@ -349,44 +349,34 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
}
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::NodeProto *const node_proto) {
onnx::NodeProto *const node_proto, std::string suffix) {
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_ref_attr_name("shape");
attr_proto->set_name("shape");
if (suffix.compare("0") != 0) {
attr_proto->set_name("shape" + suffix);
} else {
attr_proto->set_name("shape");
}
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetTensorProto(type, shape, tensor_proto);
}
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs,
onnx::NodeProto *const node_proto) {
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
// Get shape of cnode
// 1. prim kPrimTupleGetItem need to get shape of input node according to the index
// 1. prim ArgMaxWithValue need to get shape from tuple element
// 2. some cnode doesn't has shape, such as LayerNorm
// 3. other cnodes have shape
if (node->IsApply(prim::kPrimTupleGetItem)) {
// Get index of tuple get_item
int index_pos = inputs.size() - 1;
if (!inputs[index_pos]->isa<ValueNode>()) {
MS_LOG(EXCEPTION) << "Index is not ValueNode: " << index_pos;
}
auto value = inputs[index_pos]->cast<ValueNodePtr>()->value();
if (!value->isa<IntergerImm>()) {
MS_LOG(EXCEPTION) << "Index type is not supported: " << value->type_name();
}
size_t index = GetValue<int>(value);
// Get type and shape of input node
auto tup_type = inputs[0]->Type();
if (!tup_type->isa<Tuple>()) {
MS_LOG(EXCEPTION) << "Input data of kPrimTupleGetItem cnode must be tuple: " << tup_type->type_name();
if (node->IsApply(prim::kPrimArgMaxWithValue)) {
auto type = node->Type();
auto shape = node->Shape();
if (!type->isa<Tuple>()) {
MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name();
}
auto type = tup_type->cast<TuplePtr>()->elements()[index];
auto tup_shape = inputs[0]->Shape()->cast<abstract::TupleShapePtr>();
if (index >= tup_shape->shape().size()) {
MS_LOG(EXCEPTION) << "Index exceed upper limit: " << tup_shape->shape().size();
auto elements = type->cast<TuplePtr>()->elements();
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
for (size_t i = 0; i < elements.size(); i++) {
SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i));
}
auto shape = tup_shape->shape()[index];
SetShapeToNodeProto(type, shape, node_proto);
} else {
auto type = node->Type();
auto shape = node->Shape();
@ -422,7 +412,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g
std::string type_name = GetOpTypeName(op);
node_proto->set_op_type(type_name);
last_node_ = node_proto;
SetShapeToNodeProto(node, op_inputs, node_proto);
SetShapeToNodeProto(node, node_proto);
(void)std::for_each(input_names.begin(), input_names.end(),
[&node_proto](const string &name) { node_proto->add_input(name); });

@ -147,6 +147,7 @@ const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
// Maths
const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");

@ -156,6 +156,7 @@ extern const PrimitivePtr kPrimAddN;
extern const PrimitivePtr KPrimTransData;
extern const PrimitivePtr kPrimNMSWithMask;
extern const PrimitivePtr kPrimPad;
extern const PrimitivePtr kPrimArgMaxWithValue;
// Maths
extern const PrimitivePtr kPrimTensorAdd;

@ -320,6 +320,13 @@ def test_export():
export(net, input_data, file_name="./me_export.pb", file_format="GEIR")
@non_graph_engine
def test_binary_export():
net = MYNET()
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
export(net, input_data, file_name="./me_binary_export.pb", file_format="BINARY")
def teardown_module():
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt']
for item in files:

Loading…
Cancel
Save