|
|
|
@ -143,11 +143,18 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
|
|
|
|
|
if (type == PrimitiveType_Concat) {
|
|
|
|
|
auto origin_axis = node->primitive->value.AsConcat()->axis;
|
|
|
|
|
auto axis_map = GetNc2NhAxisMap();
|
|
|
|
|
MS_ASSERT(node->primitive->value.AsConcat() != nullptr);
|
|
|
|
|
if (node->primitive->value.AsConcat() == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_StridedSlice) {
|
|
|
|
|
auto attr = node->primitive->value.AsStridedSlice();
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsStridedSlice() is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto origin_begin = attr->begin;
|
|
|
|
|
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
|
|
|
|
|
auto origin_end = attr->end;
|
|
|
|
@ -158,14 +165,20 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
|
|
|
|
|
if (type == PrimitiveType_Split) {
|
|
|
|
|
auto origin_axis = node->primitive->value.AsSplit()->splitDim;
|
|
|
|
|
auto axis_map = GetNc2NhAxisMap();
|
|
|
|
|
MS_ASSERT(node->primitive->value.AsSplit != nullptr);
|
|
|
|
|
if (node->primitive->value.AsSplit() == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_Crop) {
|
|
|
|
|
auto origin_axis = node->primitive->value.AsCrop()->axis;
|
|
|
|
|
auto offsets = node->primitive->value.AsCrop()->offsets;
|
|
|
|
|
auto axis_map = GetNc2NhAxisMap();
|
|
|
|
|
MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
|
|
|
|
|
if (node->primitive->value.AsCrop() == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
|
|
|
|
|
// nchw->nhwc,offsets need pad 0;
|
|
|
|
|
if (axis_map[origin_axis] == 0) {
|
|
|
|
@ -181,13 +194,12 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni
|
|
|
|
|
MS_LOG(ERROR) << "Crop error";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
|
|
|
|
|
node->primitive->value.AsCrop()->offsets = offsets;
|
|
|
|
|
}
|
|
|
|
|
if (type == PrimitiveType_Slice) {
|
|
|
|
|
auto attr = node->primitive->value.AsSlice();
|
|
|
|
|
if (attr == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "attr is nullptr";
|
|
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto origin_begin = attr->begin;
|
|
|
|
|