|
|
|
@ -39,29 +39,18 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
attr->format = schema::Format::Format_NCHW;
|
|
|
|
|
std::vector<onnx::TensorProto> params;
|
|
|
|
|
for (int i = 0; i < onnx_node.input_size(); ++i) {
|
|
|
|
|
const auto &input_name = onnx_node.input(i);
|
|
|
|
|
for (const auto &it : onnx_graph.initializer()) {
|
|
|
|
|
if (it.name() == input_name) {
|
|
|
|
|
params.emplace_back(it);
|
|
|
|
|
break;
|
|
|
|
|
attr->format = schema::Format_NCHW;
|
|
|
|
|
std::vector<int64_t> shape;
|
|
|
|
|
shape.clear();
|
|
|
|
|
for (const auto &onnx_node_attr : onnx_node.attribute()) {
|
|
|
|
|
const auto &attribute_name = onnx_node_attr.name();
|
|
|
|
|
if (attribute_name == "shape") {
|
|
|
|
|
for (int i = 0; i < onnx_node_attr.ints_size(); ++i) {
|
|
|
|
|
shape.push_back(static_cast<int64_t>(onnx_node_attr.ints(i)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (params.empty()) {
|
|
|
|
|
MS_LOG(DEBUG) << "shape from another op other than const initializer";
|
|
|
|
|
} else {
|
|
|
|
|
if (params.size() != 1) {
|
|
|
|
|
MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < params[0].int64_data_size(); ++i) {
|
|
|
|
|
attr->shape.emplace_back(params[0].int64_data(i));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
attr->shape = shape;
|
|
|
|
|
|
|
|
|
|
op->primitive->value.type = schema::PrimitiveType_Reshape;
|
|
|
|
|
op->primitive->value.value = attr.release();
|
|
|
|
|