[MSLITE] batchnorm to scale bug

pull/11693/head
ling 4 years ago
parent 3708624a25
commit f7cbeb1fe5

@ -22,6 +22,7 @@
#include "tools/converter/converter_flags.h"
#include "third_party/securec/include/securec.h"
#include "src/common/log_adapter.h"
#include "src/common/common.h"
#include "tools/common/tensor_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
@ -39,7 +40,6 @@ namespace {
constexpr const float EPS = 1e-8;
constexpr const float EPS_DEFAULT_FLOAT = 1e-8;
constexpr const float POW_NUM = 0.5;
constexpr const int32_t NCHW_DIM_C = 1;
} // namespace
STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) {
@ -74,7 +74,9 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std:
MS_LOG(ERROR) << "new scaleParam failed";
return RET_ERROR;
}
scaleParam->axis = NCHW_DIM_C;
int32_t axis =
(graph->allTensors.at(bnNode->inputIndex.at(1))->format == Format_NHWC) ? (int32_t)NHWC_C : (int32_t)NCHW_C;
scaleParam->axis = axis;
bnNode->primitive->value.value = scaleParam.release();
auto input0 = bnNode->inputIndex.at(0);
bnNode->inputIndex.clear();

Loading…
Cancel
Save