!8717 [lite] fix onnx operator converter bug

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
pull/8717/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5d4e95ad95

@ -689,7 +689,7 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo
MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr"; MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
node->primitive->value.AsConcat()->axis = axis_map[origin_axis]; node->primitive->value.AsConcat()->axis = axis_map[origin_axis < 0 ? origin_axis + 4 : origin_axis];
} }
if (type == schema::PrimitiveType_Split) { if (type == schema::PrimitiveType_Split) {
MS_ASSERT(node->primitive->value.AsSplit() != nullptr); MS_ASSERT(node->primitive->value.AsSplit() != nullptr);

@ -16,6 +16,9 @@
#include "tools/converter/parser/onnx/onnx_tile_parser.h" #include "tools/converter/parser/onnx/onnx_tile_parser.h"
#include <memory> #include <memory>
#include <numeric>
#include <vector>
#include "tools/converter/parser/onnx/onnx_tensor_parser.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -36,7 +39,26 @@ STATUS OnnxTileParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
MS_LOG(ERROR) << "new op failed"; MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
const auto &onnx_tile_multiple = onnx_node.input(1);
int index = OnnxTensorParser::GetInstance()->GetTensorCache()->FindTensor(onnx_tile_multiple);
if (index == -1) {
MS_LOG(ERROR) << "can not find node: " << onnx_tile_multiple;
return RET_ERROR;
}
auto tile_attr = OnnxTensorParser::GetInstance()->GetTensorCache()->GetCachedTensor()[index];
if (tile_attr->data.data() == nullptr) {
MS_LOG(ERROR) << "power's attr pow can't be obtained.";
return RET_INVALID_OP_ATTR;
}
int element_size = std::accumulate(tile_attr->dims.begin(), tile_attr->dims.end(), 1, std::multiplies<int>());
std::vector<int> multiples;
std::vector<int> dims;
for (int i = 0; i < element_size; ++i) {
multiples.push_back(reinterpret_cast<int *>(tile_attr->data.data())[i]);
dims.push_back(i);
}
attr->multiples = multiples;
attr->dims = dims;
op->primitive->value.type = schema::PrimitiveType_Tile; op->primitive->value.type = schema::PrimitiveType_Tile;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
return RET_OK; return RET_OK;

Loading…
Cancel
Save