diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc index 9681d7a8e9..033afe9124 100644 --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -689,7 +689,7 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptrprimitive->value.AsConcat() is nullptr"; return RET_NULL_PTR; } - node->primitive->value.AsConcat()->axis = axis_map[origin_axis]; + node->primitive->value.AsConcat()->axis = axis_map[origin_axis < 0 ? origin_axis + 4 : origin_axis]; } if (type == schema::PrimitiveType_Split) { MS_ASSERT(node->primitive->value.AsSplit() != nullptr); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc index 09785373bd..eaa93acba2 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_tile_parser.cc @@ -16,6 +16,9 @@ #include "tools/converter/parser/onnx/onnx_tile_parser.h" #include +#include +#include +#include "tools/converter/parser/onnx/onnx_tensor_parser.h" namespace mindspore { namespace lite { @@ -36,7 +39,26 @@ STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - + const auto &onnx_tile_multiple = onnx_node.input(1); + int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_tile_multiple); + if (index == -1) { + MS_LOG(ERROR) << "can not find node: " << onnx_tile_multiple; + return RET_ERROR; + } + auto tile_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index]; + if (tile_attr->data.data() == nullptr) { + MS_LOG(ERROR) << "power's attr pow can't be obtained."; + return RET_INVALID_OP_ATTR; + } + int element_size = std::accumulate(tile_attr->dims.begin(), tile_attr->dims.end(), 1, std::multiplies()); + std::vector multiples; + std::vector dims; + for (int i = 0; i < element_size; ++i) { + multiples.push_back(reinterpret_cast(tile_attr->data.data())[i]); + dims.push_back(i); + } + attr->multiples = multiples; + attr->dims = dims; op->primitive->value.type = schema::PrimitiveType_Tile; op->primitive->value.value = attr.release(); return RET_OK;