fix shape and add testcase

pull/2220/head
leopz 5 years ago
parent b9e59f9de4
commit a67132f6b7

@ -91,9 +91,9 @@ class IrExportBuilder {
void SetParamToTensorProto(const ParameterPtr &param, onnx::TensorProto *const tensor_proto); void SetParamToTensorProto(const ParameterPtr &param, onnx::TensorProto *const tensor_proto);
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs, void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
onnx::NodeProto *const node_proto); void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto,
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto); std::string suffix = "0");
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
@ -112,10 +112,10 @@ class IrExportBuilder {
private: private:
onnx::ModelProto model_; onnx::ModelProto model_;
onnx::NodeProto *last_node_; onnx::NodeProto *last_node_{nullptr};
std::list<FuncGraphPtr> todo_; std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_; std::map<AnfNodePtr, size_t> node_index_map_;
size_t node_index_ = 0; size_t node_index_{0};
}; };
using IrExporterPtr = std::shared_ptr<IrExporter>; using IrExporterPtr = std::shared_ptr<IrExporter>;
@ -349,44 +349,34 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
} }
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::NodeProto *const node_proto) { onnx::NodeProto *const node_proto, std::string suffix) {
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_ref_attr_name("shape"); attr_proto->set_ref_attr_name("shape");
attr_proto->set_name("shape"); if (suffix.compare("0") != 0) {
attr_proto->set_name("shape" + suffix);
} else {
attr_proto->set_name("shape");
}
onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetTensorProto(type, shape, tensor_proto); SetTensorProto(type, shape, tensor_proto);
} }
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, const std::vector<AnfNodePtr> &inputs, void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
onnx::NodeProto *const node_proto) {
// Get shape of cnode // Get shape of cnode
// 1. prim kPrimTupleGetItem need to get shape of input node according to the index // 1. prim ArgMaxWithValue need to get shape from tuple element
// 2. some cnode doesn't has shape, such as LayerNorm // 2. some cnode doesn't has shape, such as LayerNorm
// 3. other cnodes have shape // 3. other cnodes have shape
if (node->IsApply(prim::kPrimTupleGetItem)) { if (node->IsApply(prim::kPrimArgMaxWithValue)) {
// Get index of tuple get_item auto type = node->Type();
int index_pos = inputs.size() - 1; auto shape = node->Shape();
if (!inputs[index_pos]->isa<ValueNode>()) { if (!type->isa<Tuple>()) {
MS_LOG(EXCEPTION) << "Index is not ValueNode: " << index_pos; MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name();
}
auto value = inputs[index_pos]->cast<ValueNodePtr>()->value();
if (!value->isa<IntergerImm>()) {
MS_LOG(EXCEPTION) << "Index type is not supported: " << value->type_name();
}
size_t index = GetValue<int>(value);
// Get type and shape of input node
auto tup_type = inputs[0]->Type();
if (!tup_type->isa<Tuple>()) {
MS_LOG(EXCEPTION) << "Input data of kPrimTupleGetItem cnode must be tuple: " << tup_type->type_name();
} }
auto type = tup_type->cast<TuplePtr>()->elements()[index]; auto elements = type->cast<TuplePtr>()->elements();
auto tup_shape = inputs[0]->Shape()->cast<abstract::TupleShapePtr>(); auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
if (index >= tup_shape->shape().size()) { for (size_t i = 0; i < elements.size(); i++) {
MS_LOG(EXCEPTION) << "Index exceed upper limit: " << tup_shape->shape().size(); SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i));
} }
auto shape = tup_shape->shape()[index];
SetShapeToNodeProto(type, shape, node_proto);
} else { } else {
auto type = node->Type(); auto type = node->Type();
auto shape = node->Shape(); auto shape = node->Shape();
@ -422,7 +412,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const g
std::string type_name = GetOpTypeName(op); std::string type_name = GetOpTypeName(op);
node_proto->set_op_type(type_name); node_proto->set_op_type(type_name);
last_node_ = node_proto; last_node_ = node_proto;
SetShapeToNodeProto(node, op_inputs, node_proto); SetShapeToNodeProto(node, node_proto);
(void)std::for_each(input_names.begin(), input_names.end(), (void)std::for_each(input_names.begin(), input_names.end(),
[&node_proto](const string &name) { node_proto->add_input(name); }); [&node_proto](const string &name) { node_proto->add_input(name); });

@ -147,6 +147,7 @@ const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData"); const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask"); const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad"); const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
// Maths // Maths
const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd"); const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");

@ -156,6 +156,7 @@ extern const PrimitivePtr kPrimAddN;
extern const PrimitivePtr KPrimTransData; extern const PrimitivePtr KPrimTransData;
extern const PrimitivePtr kPrimNMSWithMask; extern const PrimitivePtr kPrimNMSWithMask;
extern const PrimitivePtr kPrimPad; extern const PrimitivePtr kPrimPad;
extern const PrimitivePtr kPrimArgMaxWithValue;
// Maths // Maths
extern const PrimitivePtr kPrimTensorAdd; extern const PrimitivePtr kPrimTensorAdd;

@ -320,6 +320,13 @@ def test_export():
export(net, input_data, file_name="./me_export.pb", file_format="GEIR") export(net, input_data, file_name="./me_export.pb", file_format="GEIR")
@non_graph_engine
def test_binary_export():
net = MYNET()
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
export(net, input_data, file_name="./me_binary_export.pb", file_format="BINARY")
def teardown_module(): def teardown_module():
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt'] files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt']
for item in files: for item in files:

Loading…
Cancel
Save