!335 Synchronize latest Ascend software suite 19 Nov 2020
From: @nicholas_yhr Reviewed-by: @youui,@liujunzhu Signed-off-by: @liujunzhupull/335/MERGE
commit
9153665631
@ -0,0 +1,144 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file axis_util.h
|
||||
* \brief get the axis value
|
||||
*/
|
||||
#ifndef COMMON_UTILS_TRANSFER_AXIS_UTIL_H_
|
||||
#define COMMON_UTILS_TRANSFER_AXIS_UTIL_H_
|
||||
|
||||
#include <memory.h>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "external/graph/ge_error_codes.h"
|
||||
#include "external/graph/types.h"
|
||||
#include "framework/common/debug/ge_log.h"
|
||||
|
||||
namespace common {
|
||||
namespace transformer {
|
||||
|
||||
const int32_t DIM_DEFAULT_SIZE = 4;
|
||||
const uint32_t NCHW_DIMENSION_NUM = 4;
|
||||
|
||||
const int32_t AXIS_NCHW_DIM_N = 0;
|
||||
const int32_t AXIS_NCHW_DIM_C = 1;
|
||||
const int32_t AXIS_NCHW_DIM_H = 2;
|
||||
const int32_t AXIS_NCHW_DIM_W = 3;
|
||||
|
||||
const int32_t AXIS_NHWC_DIM_N = 0;
|
||||
const int32_t AXIS_NHWC_DIM_H = 1;
|
||||
const int32_t AXIS_NHWC_DIM_W = 2;
|
||||
const int32_t AXIS_NHWC_DIM_C = 3;
|
||||
|
||||
const int32_t AXIS_NC1HWC0_DIM_N = 0;
|
||||
const int32_t AXIS_NC1HWC0_DIM_C1 = 1;
|
||||
const int32_t AXIS_NC1HWC0_DIM_C0 = 4;
|
||||
const int32_t AXIS_NC1HWC0_DIM_H = 2;
|
||||
const int32_t AXIS_NC1HWC0_DIM_W = 3;
|
||||
|
||||
const int32_t AXIS_HWCN_DIM_H = 0;
|
||||
const int32_t AXIS_HWCN_DIM_W = 1;
|
||||
const int32_t AXIS_HWCN_DIM_C = 2;
|
||||
const int32_t AXIS_HWCN_DIM_N = 3;
|
||||
|
||||
const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0;
|
||||
const int32_t AXIS_C1HWNCoC0_DIM_H = 1;
|
||||
const int32_t AXIS_C1HWNCoC0_DIM_W = 2;
|
||||
const int32_t AXIS_C1HWNCoC0_DIM_N = 3;
|
||||
const int32_t AXIS_C1HWNCoC0_DIM_Co = 4;
|
||||
const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5;
|
||||
|
||||
#define CHECK_NOTNULL(val) \
|
||||
do { \
|
||||
if ((val) == nullptr) { \
|
||||
GELOGE(GRAPH_FAILED, "[ERROR]Parameter[%s] must not be null.", #val); \
|
||||
return false; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CHECK(cond, log_func, return_expr) \
|
||||
do { \
|
||||
if (cond) { \
|
||||
log_func; \
|
||||
return_expr; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
enum AxisValueType {
|
||||
AXIS_N = 0,
|
||||
AXIS_C = 1,
|
||||
AXIS_H = 2,
|
||||
AXIS_W = 3,
|
||||
AXIS_C1 = 4,
|
||||
AXIS_C0 = 5,
|
||||
AXIS_Co = 6,
|
||||
AXIS_D = 7,
|
||||
AXIS_BOTTOM = 8
|
||||
};
|
||||
|
||||
int64_t DivisionCeiling(int64_t dividend, int64_t divisor);
|
||||
|
||||
/* Axis value is arranged as {N,C,H,W,C1,C0,...} */
|
||||
/* The first parameter is old shape's dimension,
|
||||
* second is c0 and third is axis value. */
|
||||
using GetAxisValueInfoByFormat =
|
||||
std::function<bool(const std::vector<int64_t>&, const uint32_t&, std::vector<int64_t>&, std::vector<int64_t>&)>;
|
||||
|
||||
using GetAxisValueInfoByFormatPtr = std::shared_ptr<GetAxisValueInfoByFormat>;
|
||||
|
||||
class AxisUtil {
|
||||
public:
|
||||
AxisUtil();
|
||||
~AxisUtil(){};
|
||||
bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector<int64_t>& dimVec, const uint32_t& c0,
|
||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
|
||||
bool HasAxisValueFunc(const ge::Format& format);
|
||||
|
||||
private:
|
||||
static bool CheckParams(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
|
||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
|
||||
|
||||
static bool GetAxisValueByNCHW(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
|
||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
|
||||
|
||||
static bool GetAxisValueByNHWC(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
|
||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
|
||||
|
||||
static bool GetAxisValueByNC1HWC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
|
||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
|
||||
|
||||
static bool GetAxisValueByFz(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
|
||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
|
||||
|
||||
static bool GetAxisValueByHWCN(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
|
||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
|
||||
|
||||
static bool GetAxisValueByND(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
|
||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
|
||||
|
||||
static bool GetAxisValueByC1HWNCoC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
|
||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
|
||||
|
||||
/* map of GetAxisValueInfoByFormat, get axis value by different original
|
||||
* formats. */
|
||||
std::map<ge::Format, GetAxisValueInfoByFormatPtr> getAxisValueFuncMap;
|
||||
};
|
||||
} // namespace transformer
|
||||
} // namespace common
|
||||
|
||||
#endif // COMMON_UTILS_TRANSFER_AXIS_UTIL_H_
|
@ -0,0 +1,122 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file transfer_shape_according_to_format.h
|
||||
* \brief set shape according to original format and current format
|
||||
*/
|
||||
#ifndef COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
|
||||
#define COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
|
||||
|
||||
#include "transformer/inc/axis_util.h"
|
||||
|
||||
#include <memory.h>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "graph/types.h"
|
||||
#include "graph/utils/op_desc_utils.h"
|
||||
|
||||
namespace common {
|
||||
namespace transformer {
|
||||
|
||||
enum OpImplType {
|
||||
EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op
|
||||
EN_IMPL_CUSTOM_TIK, // custom tik op
|
||||
EN_IMPL_CUSTOM_TBE, // custom tbe op
|
||||
EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op
|
||||
EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op
|
||||
EN_IMPL_HW_TIK, // Huawei built-in tik op
|
||||
EN_IMPL_HW_TBE, // Huawei built-in tbe op
|
||||
EN_IMPL_RL, // RL op
|
||||
EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op
|
||||
EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op
|
||||
EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op
|
||||
EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op
|
||||
EN_RESERVED // reserved value
|
||||
};
|
||||
|
||||
const uint32_t SHAPE_NUMBER_16 = 16;
|
||||
const uint32_t SHAPE_NUMBER_32 = 32;
|
||||
const uint32_t SHAPE_DIM_VALUE_C04 = 4;
|
||||
const uint32_t NI = 16;
|
||||
const uint32_t MINUS_VALUE_ONE = 1;
|
||||
const uint32_t MINUS_VALUE_TWO = 2;
|
||||
const uint32_t SIZE_OF_CN = 2;
|
||||
const uint32_t MINIMUM_NZ_SHAPE_DIM_NUM = 2;
|
||||
|
||||
/* The first parameter is axis value, second is new shape and third is
|
||||
* op implementation type. */
|
||||
using GetNewShapeByAxisValueAndFormat =
|
||||
std::function<bool(vector<int64_t> &, const int64_t &, vector<int64_t> &, vector<int64_t> &)>;
|
||||
|
||||
using GetNewShapeByAxisValueAndFormatPtr = std::shared_ptr<GetNewShapeByAxisValueAndFormat>;
|
||||
|
||||
struct ShapeAndFormatInfo {
|
||||
const std::vector<int64_t> &oldShape;
|
||||
std::vector<int64_t> &newShape;
|
||||
const ge::Format &oldFormat;
|
||||
const ge::Format &newFormat;
|
||||
const ge::DataType ¤tDataType;
|
||||
const int64_t &opImplType;
|
||||
};
|
||||
|
||||
using ShapeAndFormat = struct ShapeAndFormatInfo;
|
||||
|
||||
class ShapeTransferAccordingToFormat {
|
||||
public:
|
||||
ShapeTransferAccordingToFormat();
|
||||
|
||||
~ShapeTransferAccordingToFormat(){};
|
||||
|
||||
ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat &) = delete;
|
||||
|
||||
ShapeTransferAccordingToFormat &operator=(const ShapeTransferAccordingToFormat &) = delete;
|
||||
|
||||
bool GetShapeAccordingToFormat(ShapeAndFormat &inputAndOutputInfo, int64_t *c = nullptr);
|
||||
|
||||
/* ----------Below is the function of getting new shape---------------------- */
|
||||
static bool GetNCHWShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
|
||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
|
||||
|
||||
static bool GetNHWCShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
|
||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
|
||||
|
||||
static bool GetNC1HWC0ShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
|
||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
|
||||
|
||||
static bool GetFzShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
|
||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
|
||||
|
||||
static bool GetHWCNShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
|
||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
|
||||
|
||||
static bool GetC1HWNCoC0ShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
|
||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
|
||||
|
||||
static bool GetNzShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
|
||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
|
||||
|
||||
private:
|
||||
/* map of GetAxisValueInfoByFormat, get axis value by different original
|
||||
* formats. */
|
||||
std::map<ge::Format, GetNewShapeByAxisValueAndFormatPtr> getNewShapeFuncMap;
|
||||
std::map<ge::DataType, uint32_t> mapOfDtypeAndC0;
|
||||
};
|
||||
} // namespace transformer
|
||||
} // namespace common
|
||||
|
||||
#endif // COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
|
@ -0,0 +1,198 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file axis_util.cpp
|
||||
* \brief get the axis value
|
||||
*/
|
||||
#include "transformer/inc/axis_util.h"
|
||||
#include "graph/types.h"
|
||||
|
||||
namespace common {
|
||||
namespace transformer {
|
||||
using namespace ge;
|
||||
using namespace std;
|
||||
|
||||
AxisUtil::AxisUtil() {
|
||||
getAxisValueFuncMap = {{FORMAT_NCHW, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNCHW)},
|
||||
{FORMAT_NHWC, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNHWC)},
|
||||
{FORMAT_NC1HWC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNC1HWC0)},
|
||||
{FORMAT_HWCN, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByHWCN)},
|
||||
{FORMAT_ND, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByND)},
|
||||
{FORMAT_C1HWNCoC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByC1HWNCoC0)}};
|
||||
}
|
||||
|
||||
int64_t DivisionCeiling(int64_t dividend, int64_t divisor) {
|
||||
if (divisor == 0) {
|
||||
return 0;
|
||||
} else {
|
||||
return (dividend + divisor - 1) / divisor;
|
||||
}
|
||||
}
|
||||
|
||||
bool AxisUtil::GetAxisValueByOriginFormat(const Format &format, const vector<int64_t> &dimVec, const uint32_t &c0,
|
||||
vector<int64_t> &axisValue, vector<int64_t> &ndValue) {
|
||||
auto iterGetAxisFunc = getAxisValueFuncMap.find(format);
|
||||
if (iterGetAxisFunc == getAxisValueFuncMap.end()) {
|
||||
GELOGI("Can not get axis value of old format %u!", format);
|
||||
return false;
|
||||
}
|
||||
GetAxisValueInfoByFormatPtr getAxisFunc = iterGetAxisFunc->second;
|
||||
CHECK_NOTNULL(getAxisFunc);
|
||||
return (*getAxisFunc)(dimVec, c0, axisValue, ndValue);
|
||||
}
|
||||
|
||||
bool AxisUtil::HasAxisValueFunc(const Format &format) {
|
||||
auto iterGetAxisFunc = getAxisValueFuncMap.find(format);
|
||||
if (iterGetAxisFunc == getAxisValueFuncMap.end()) {
|
||||
GELOGI("Can not get axis value of format %u!", format);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AxisUtil::CheckParams(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
|
||||
vector<int64_t> &ndValue) {
|
||||
ndValue = originalDimVec;
|
||||
auto dimSize = originalDimVec.size();
|
||||
if (dimSize < DIM_DEFAULT_SIZE) {
|
||||
/* Before this funcion, we should call function PadDimensionTo4. */
|
||||
GELOGI("Dimension size %zu is invalid.", dimSize);
|
||||
return false;
|
||||
}
|
||||
if (c0 == 0) {
|
||||
GELOGE(GRAPH_FAILED, "[ERROR]c0 is zero!");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AxisUtil::GetAxisValueByND(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
|
||||
vector<int64_t> &ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
|
||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
|
||||
ndValue = originalDimVec;
|
||||
/* To differentiate the input datatype of int8 and others */
|
||||
axisValue[AXIS_C0] = c0;
|
||||
if (originalDimVec.size() == NCHW_DIMENSION_NUM) {
|
||||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N];
|
||||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C];
|
||||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H];
|
||||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W];
|
||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0);
|
||||
axisValue[AXIS_Co] = c0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AxisUtil::GetAxisValueByNCHW(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
|
||||
vector<int64_t> &ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
|
||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
|
||||
/* C0 Must be set for case ND or 2D-NCHW to NZ */
|
||||
axisValue[AXIS_C0] = c0;
|
||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED,"[ERROR]Parameter is invalid!"),
|
||||
return false);
|
||||
|
||||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N];
|
||||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C];
|
||||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H];
|
||||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W];
|
||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0);
|
||||
axisValue[AXIS_Co] = c0;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AxisUtil::GetAxisValueByNHWC(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
|
||||
vector<int64_t> &ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
|
||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
|
||||
/* C0 Must be set for case ND or 2D-NHWC to NZ */
|
||||
axisValue[AXIS_C0] = c0;
|
||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"),
|
||||
return false);
|
||||
|
||||
axisValue[AXIS_N] = originalDimVec[AXIS_NHWC_DIM_N];
|
||||
axisValue[AXIS_C] = originalDimVec[AXIS_NHWC_DIM_C];
|
||||
axisValue[AXIS_H] = originalDimVec[AXIS_NHWC_DIM_H];
|
||||
axisValue[AXIS_W] = originalDimVec[AXIS_NHWC_DIM_W];
|
||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NHWC_DIM_C], (int64_t)c0);
|
||||
axisValue[AXIS_Co] = c0;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AxisUtil::GetAxisValueByNC1HWC0(const vector<int64_t> &originalDimVec, const uint32_t &c0,
|
||||
vector<int64_t> &axisValue, vector<int64_t> &ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
|
||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
|
||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED,"[ERROR]Parameter is invalid!"),
|
||||
return false);
|
||||
|
||||
auto dimSize = originalDimVec.size();
|
||||
if (dimSize == DIM_DEFAULT_SIZE + 1) {
|
||||
axisValue[AXIS_C1] = originalDimVec[AXIS_NC1HWC0_DIM_C1];
|
||||
axisValue[AXIS_C0] = originalDimVec[AXIS_NC1HWC0_DIM_C0];
|
||||
axisValue[AXIS_C] = axisValue[AXIS_C1] * axisValue[AXIS_C0];
|
||||
} else {
|
||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0);
|
||||
axisValue[AXIS_C0] = c0;
|
||||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C];
|
||||
}
|
||||
|
||||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N];
|
||||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H];
|
||||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W];
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AxisUtil::GetAxisValueByHWCN(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
|
||||
vector<int64_t> &ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
|
||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
|
||||
/* C0 Must be set for case ND or 2D-NHWC to NZ */
|
||||
axisValue[AXIS_C0] = c0;
|
||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"),
|
||||
return false);
|
||||
|
||||
axisValue[AXIS_N] = originalDimVec[AXIS_HWCN_DIM_N];
|
||||
axisValue[AXIS_C] = originalDimVec[AXIS_HWCN_DIM_C];
|
||||
axisValue[AXIS_H] = originalDimVec[AXIS_HWCN_DIM_H];
|
||||
axisValue[AXIS_W] = originalDimVec[AXIS_HWCN_DIM_W];
|
||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_HWCN_DIM_C], (int64_t)c0);
|
||||
axisValue[AXIS_Co] = c0;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AxisUtil::GetAxisValueByC1HWNCoC0(const vector<int64_t> &originalDimVec, const uint32_t &c0,
|
||||
vector<int64_t> &axisValue, vector<int64_t> &ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
|
||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
|
||||
/* C0 Must be set for case ND or 2D-NHWC to NZ */
|
||||
axisValue[AXIS_C0] = c0;
|
||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"),
|
||||
return false);
|
||||
|
||||
axisValue[AXIS_N] = originalDimVec[AXIS_C1HWNCoC0_DIM_N];
|
||||
axisValue[AXIS_C] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1] * c0;
|
||||
axisValue[AXIS_H] = originalDimVec[AXIS_C1HWNCoC0_DIM_H];
|
||||
axisValue[AXIS_W] = originalDimVec[AXIS_C1HWNCoC0_DIM_W];
|
||||
axisValue[AXIS_C1] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1];
|
||||
axisValue[AXIS_Co] = originalDimVec[AXIS_C1HWNCoC0_DIM_Co];
|
||||
return true;
|
||||
}
|
||||
} // namespace transformer
|
||||
} // namespace common
|
@ -0,0 +1,242 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file transfer_shape_according_to_format.cpp
|
||||
* \brief set shape according to original format and current format
|
||||
*/
|
||||
#include "transformer/inc/transfer_shape_according_to_format.h"
|
||||
|
||||
namespace common {
|
||||
namespace transformer {
|
||||
using namespace ge;
|
||||
using namespace std;
|
||||
|
||||
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat(void) {
|
||||
getNewShapeFuncMap = {
|
||||
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)},
|
||||
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)},
|
||||
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)},
|
||||
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)},
|
||||
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)},
|
||||
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)},
|
||||
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}};
|
||||
|
||||
mapOfDtypeAndC0 = {
|
||||
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32},
|
||||
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16},
|
||||
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16},
|
||||
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}};
|
||||
}
|
||||
|
||||
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
|
||||
const vector<int64_t>& axisValue,
|
||||
const vector<int64_t>& ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
|
||||
/* axisValue is initialized as a size 6 vector. */
|
||||
newShape.push_back(axisValue[AXIS_N]);
|
||||
newShape.push_back(axisValue[AXIS_C]);
|
||||
newShape.push_back(axisValue[AXIS_H]);
|
||||
newShape.push_back(axisValue[AXIS_W]);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShapeTransferAccordingToFormat::GetNHWCShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
|
||||
const vector<int64_t>& axisValue,
|
||||
const vector<int64_t>& ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
|
||||
/* axisValue is initialized as a size 6 vector. */
|
||||
newShape.push_back(axisValue[AXIS_N]);
|
||||
newShape.push_back(axisValue[AXIS_H]);
|
||||
newShape.push_back(axisValue[AXIS_W]);
|
||||
newShape.push_back(axisValue[AXIS_C]);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShapeTransferAccordingToFormat::GetNC1HWC0ShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
|
||||
const vector<int64_t>& axisValue,
|
||||
const vector<int64_t>& ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
|
||||
/* axisValue is initialized as a size 6 vector. */
|
||||
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) {
|
||||
newShape.push_back(axisValue[AXIS_N]);
|
||||
newShape.push_back(axisValue[AXIS_C1]);
|
||||
newShape.push_back(axisValue[AXIS_H]);
|
||||
newShape.push_back(axisValue[AXIS_W]);
|
||||
newShape.push_back(axisValue[AXIS_C0]);
|
||||
} else {
|
||||
newShape.push_back(axisValue[AXIS_N]);
|
||||
newShape.push_back(axisValue[AXIS_C]);
|
||||
newShape.push_back(axisValue[AXIS_H]);
|
||||
newShape.push_back(axisValue[AXIS_W]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
|
||||
const vector<int64_t>& axisValue,
|
||||
const vector<int64_t>& ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
|
||||
/* axisValue is initialized as a size 6 vector. */
|
||||
if (ndValue.size() == SIZE_OF_CN) {
|
||||
auto sizeOfOriginalVec = ndValue.size();
|
||||
newShape = ndValue;
|
||||
/* sizeOfOriginalVec - 1 mean the last value of original vec
|
||||
* sizeOfOriginalVec - 2 mean the second last value of original vec */
|
||||
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] =
|
||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16);
|
||||
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] =
|
||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]);
|
||||
newShape.push_back(SHAPE_NUMBER_16);
|
||||
newShape.push_back(axisValue[AXIS_C0]);
|
||||
} else {
|
||||
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) {
|
||||
int64_t hwc1 = axisValue[AXIS_C1] * axisValue[AXIS_H] * axisValue[AXIS_W];
|
||||
newShape.push_back(hwc1);
|
||||
newShape.push_back(DivisionCeiling(axisValue[AXIS_N], NI));
|
||||
newShape.push_back(NI);
|
||||
newShape.push_back(axisValue[AXIS_C0]);
|
||||
} else {
|
||||
newShape.push_back(axisValue[AXIS_N]);
|
||||
newShape.push_back(axisValue[AXIS_C]);
|
||||
newShape.push_back(axisValue[AXIS_H]);
|
||||
newShape.push_back(axisValue[AXIS_W]);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShapeTransferAccordingToFormat::GetHWCNShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
|
||||
const vector<int64_t>& axisValue,
|
||||
const vector<int64_t>& ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
|
||||
/* axisValue is initialized as a size 6 vector. */
|
||||
newShape.push_back(axisValue[AXIS_H]);
|
||||
newShape.push_back(axisValue[AXIS_W]);
|
||||
newShape.push_back(axisValue[AXIS_C]);
|
||||
newShape.push_back(axisValue[AXIS_N]);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShapeTransferAccordingToFormat::GetC1HWNCoC0ShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
|
||||
const vector<int64_t>& axisValue,
|
||||
const vector<int64_t>& ndValue) {
|
||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
|
||||
/* axisValue is initialized as a size 6 vector. */
|
||||
newShape.push_back(axisValue[AXIS_C1]);
|
||||
newShape.push_back(axisValue[AXIS_H]);
|
||||
newShape.push_back(axisValue[AXIS_W]);
|
||||
newShape.push_back(axisValue[AXIS_N]);
|
||||
newShape.push_back(axisValue[AXIS_Co]);
|
||||
newShape.push_back(axisValue[AXIS_C0]);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
|
||||
const vector<int64_t>& axisValue,
|
||||
const vector<int64_t>& ndValue) {
|
||||
CHECK(ndValue.empty(), GELOGD("ndValue is empty!"), return true);
|
||||
CHECK(axisValue.empty() || axisValue.size() <= AXIS_C0,
|
||||
GELOGD("AxisValue is empty or its size %zu <= AXIS_C0[%u]", axisValue.size(), AXIS_C0), return true);
|
||||
uint32_t sizeOfOriginalVec = ndValue.size();
|
||||
if (sizeOfOriginalVec < MINIMUM_NZ_SHAPE_DIM_NUM) {
|
||||
GELOGD("ndValue's dim num is less than 2!");
|
||||
return true;
|
||||
}
|
||||
/* axisValue is initialized as a size 6 vector. */
|
||||
newShape = ndValue;
|
||||
|
||||
/* sizeOfOriginalVec - 1 mean the last value of original vec
|
||||
* sizeOfOriginalVec - 2 mean the second last value of original vec */
|
||||
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] =
|
||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16);
|
||||
|
||||
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] =
|
||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]);
|
||||
newShape.push_back(SHAPE_NUMBER_16);
|
||||
newShape.push_back(axisValue[AXIS_C0]);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& shapeAndFormatInfo, int64_t* c) {
|
||||
/* The default new shape is old shape */
|
||||
shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape;
|
||||
if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) {
|
||||
GELOGE(GRAPH_FAILED, "Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat,
|
||||
shapeAndFormatInfo.newFormat);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (shapeAndFormatInfo.currentDataType >= ge::DT_UNDEFINED) {
|
||||
GELOGE(GRAPH_FAILED, "currentDataType %u is invalid!", shapeAndFormatInfo.currentDataType);
|
||||
return false;
|
||||
}
|
||||
AxisUtil* axisutil_object = new AxisUtil();
|
||||
if (!axisutil_object->HasAxisValueFunc(shapeAndFormatInfo.oldFormat)) {
|
||||
delete axisutil_object;
|
||||
return true;
|
||||
}
|
||||
|
||||
auto iterGetNewShapeFunc = getNewShapeFuncMap.find(shapeAndFormatInfo.newFormat);
|
||||
if (iterGetNewShapeFunc == getNewShapeFuncMap.end()) {
|
||||
GELOGD("Can not get new shape of new format %u!", shapeAndFormatInfo.newFormat);
|
||||
delete axisutil_object;
|
||||
return true;
|
||||
}
|
||||
GELOGD("Original format %u, new format %u", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat);
|
||||
GetNewShapeByAxisValueAndFormatPtr getNewShapeFunc = iterGetNewShapeFunc->second;
|
||||
CHECK_NOTNULL(getNewShapeFunc);
|
||||
std::vector<int64_t> axisValue;
|
||||
for (uint32_t i = 0; i < AXIS_BOTTOM; i++) {
|
||||
axisValue.push_back(1);
|
||||
}
|
||||
std::vector<int64_t> ndValue;
|
||||
uint32_t c0;
|
||||
if (mapOfDtypeAndC0.empty()) {
|
||||
c0 = SHAPE_NUMBER_16;
|
||||
} else {
|
||||
auto iterGetC0 = mapOfDtypeAndC0.find(shapeAndFormatInfo.currentDataType);
|
||||
if (iterGetC0 == mapOfDtypeAndC0.end()) {
|
||||
GELOGE(GRAPH_FAILED, "Dtype is not support.");
|
||||
delete axisutil_object;
|
||||
return true;
|
||||
}
|
||||
c0 = iterGetC0->second;
|
||||
}
|
||||
|
||||
// The value of C0 should be 4 while format is 5HD-4 or FRAZ-4
|
||||
if (shapeAndFormatInfo.newFormat == ge::FORMAT_NC1HWC0_C04) {
|
||||
c0 = SHAPE_DIM_VALUE_C04;
|
||||
}
|
||||
|
||||
bool status = axisutil_object->GetAxisValueByOriginFormat(
|
||||
shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape, c0, axisValue, ndValue);
|
||||
if (status != true && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) {
|
||||
delete axisutil_object;
|
||||
return true;
|
||||
}
|
||||
delete axisutil_object;
|
||||
|
||||
shapeAndFormatInfo.newShape.clear();
|
||||
(*getNewShapeFunc)(shapeAndFormatInfo.newShape, shapeAndFormatInfo.opImplType, axisValue, ndValue);
|
||||
if (c != nullptr) {
|
||||
*c = axisValue[AXIS_C];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace transformer
|
||||
} // namespace common
|
@ -0,0 +1,50 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_
|
||||
#define COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
#include "external/graph/types.h"
|
||||
#include "graph/op_desc.h"
|
||||
#include "graph/ge_tensor.h"
|
||||
#include "transformer/inc/transfer_shape_according_to_format.h"
|
||||
|
||||
namespace ge {
|
||||
class NodeShapeTransUtils {
|
||||
public:
|
||||
bool CatchFormatAndShape();
|
||||
bool UpdateFormatAndShape();
|
||||
|
||||
explicit NodeShapeTransUtils(OpDescPtr op_desc) : op_desc_(op_desc) {}
|
||||
|
||||
~NodeShapeTransUtils() {}
|
||||
|
||||
private:
|
||||
std::map<std::string, Format> map_format_in_;
|
||||
std::map<std::string, Format> map_ori_format_in_;
|
||||
std::map<std::string, DataType> map_dtype_in_;
|
||||
std::map<std::string, Format> map_format_out_;
|
||||
std::map<std::string, Format> map_ori_format_out_;
|
||||
std::map<std::string, DataType> map_dtype_out_;
|
||||
std::map<std::string, uint32_t> inputs_;
|
||||
std::map<std::string, uint32_t> outputs_;
|
||||
|
||||
OpDescPtr op_desc_;
|
||||
};
|
||||
} // namespace ge
|
||||
#endif // COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue