diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc index 34296e4ef6..e17e3faa41 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.cc @@ -22,6 +22,7 @@ #include "tools/converter/converter_flags.h" #include "third_party/securec/include/securec.h" #include "src/common/log_adapter.h" +#include "src/common/common.h" #include "tools/common/tensor_util.h" #include "include/errorcode.h" #include "schema/inner/model_generated.h" @@ -39,7 +40,6 @@ namespace { constexpr const float EPS = 1e-8; constexpr const float EPS_DEFAULT_FLOAT = 1e-8; constexpr const float POW_NUM = 0.5; -constexpr const int32_t NCHW_DIM_C = 1; } // namespace STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) { @@ -74,7 +74,9 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std: MS_LOG(ERROR) << "new scaleParam failed"; return RET_ERROR; } - scaleParam->axis = NCHW_DIM_C; + int32_t axis = + (graph->allTensors.at(bnNode->inputIndex.at(1))->format == Format_NHWC) ? (int32_t)NHWC_C : (int32_t)NCHW_C; + scaleParam->axis = axis; bnNode->primitive->value.value = scaleParam.release(); auto input0 = bnNode->inputIndex.at(0); bnNode->inputIndex.clear();