|
|
|
@ -27,19 +27,19 @@ using namespace std;
|
|
|
|
|
|
|
|
|
|
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat(void) {
|
|
|
|
|
getNewShapeFuncMap = {
|
|
|
|
|
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}};
|
|
|
|
|
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)},
|
|
|
|
|
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}};
|
|
|
|
|
|
|
|
|
|
mapOfDtypeAndC0 = {
|
|
|
|
|
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32},
|
|
|
|
|
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16},
|
|
|
|
|
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16},
|
|
|
|
|
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}};
|
|
|
|
|
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32},
|
|
|
|
|
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16},
|
|
|
|
|
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16},
|
|
|
|
|
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
|
|
|
|
@ -97,9 +97,9 @@ bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(vector<int64_t>& newS
|
|
|
|
|
/* sizeOfOriginalVec - 1 mean the last value of original vec
|
|
|
|
|
* sizeOfOriginalVec - 2 mean the second last value of original vec */
|
|
|
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] =
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16);
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16);
|
|
|
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] =
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]);
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]);
|
|
|
|
|
newShape.push_back(SHAPE_NUMBER_16);
|
|
|
|
|
newShape.push_back(axisValue[AXIS_C0]);
|
|
|
|
|
} else {
|
|
|
|
@ -163,10 +163,10 @@ bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(vector<int64_t>& newS
|
|
|
|
|
/* sizeOfOriginalVec - 1 mean the last value of original vec
|
|
|
|
|
* sizeOfOriginalVec - 2 mean the second last value of original vec */
|
|
|
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] =
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16);
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16);
|
|
|
|
|
|
|
|
|
|
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] =
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]);
|
|
|
|
|
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]);
|
|
|
|
|
newShape.push_back(SHAPE_NUMBER_16);
|
|
|
|
|
newShape.push_back(axisValue[AXIS_C0]);
|
|
|
|
|
return true;
|
|
|
|
@ -177,7 +177,7 @@ bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& s
|
|
|
|
|
shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape;
|
|
|
|
|
if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) {
|
|
|
|
|
GELOGE(GRAPH_FAILED, "Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat,
|
|
|
|
|
shapeAndFormatInfo.newFormat);
|
|
|
|
|
shapeAndFormatInfo.newFormat);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -223,8 +223,8 @@ bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& s
|
|
|
|
|
c0 = SHAPE_DIM_VALUE_C04;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool status = axisutil_object->GetAxisValueByOriginFormat(
|
|
|
|
|
shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape, c0, axisValue, ndValue);
|
|
|
|
|
bool status = axisutil_object->GetAxisValueByOriginFormat(shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape,
|
|
|
|
|
c0, axisValue, ndValue);
|
|
|
|
|
if (status != true && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) {
|
|
|
|
|
delete axisutil_object;
|
|
|
|
|
return true;
|
|
|
|
@ -238,5 +238,5 @@ bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& s
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace transformer
|
|
|
|
|
} // namespace common
|
|
|
|
|
} // namespace transformer
|
|
|
|
|
} // namespace common
|