|
|
|
@ -22,6 +22,7 @@
|
|
|
|
|
#include "tools/converter/converter_flags.h"
|
|
|
|
|
#include "third_party/securec/include/securec.h"
|
|
|
|
|
#include "src/common/log_adapter.h"
|
|
|
|
|
#include "src/common/common.h"
|
|
|
|
|
#include "tools/common/tensor_util.h"
|
|
|
|
|
#include "include/errorcode.h"
|
|
|
|
|
#include "schema/inner/model_generated.h"
|
|
|
|
@ -39,7 +40,6 @@ namespace {
|
|
|
|
|
constexpr const float EPS = 1e-8;
|
|
|
|
|
constexpr const float EPS_DEFAULT_FLOAT = 1e-8;
|
|
|
|
|
constexpr const float POW_NUM = 0.5;
|
|
|
|
|
constexpr const int32_t NCHW_DIM_C = 1;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) {
|
|
|
|
@ -74,7 +74,9 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std:
|
|
|
|
|
MS_LOG(ERROR) << "new scaleParam failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
scaleParam->axis = NCHW_DIM_C;
|
|
|
|
|
int32_t axis =
|
|
|
|
|
(graph->allTensors.at(bnNode->inputIndex.at(1))->format == Format_NHWC) ? (int32_t)NHWC_C : (int32_t)NCHW_C;
|
|
|
|
|
scaleParam->axis = axis;
|
|
|
|
|
bnNode->primitive->value.value = scaleParam.release();
|
|
|
|
|
auto input0 = bnNode->inputIndex.at(0);
|
|
|
|
|
bnNode->inputIndex.clear();
|
|
|
|
|