[MSLITE] Fix bug in batchnorm_convert_scale_pass

pull/12385/head
wang_shaocong 4 years ago
parent 94054172ce
commit fa2e454af3

@ -40,6 +40,7 @@ namespace {
constexpr const float EPS = 1e-8;
constexpr const float EPS_DEFAULT_FLOAT = 1e-8;
constexpr const float POW_NUM = 0.5;
constexpr uint32_t kQuadrupleNum = 4;
} // namespace
STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) {
@ -52,6 +53,11 @@ STATUS BatchNormConvertScalePass::Run(MetaGraphT *graph) {
continue;
}
auto input_index = node->inputIndex.at(0);
if (graph->allTensors.at(input_index)->dims.empty()) {
MS_LOG(WARNING) << "The shape of input tensor is uncertain.";
return RET_OK;
}
auto status = GenNewScaleTensor(graph, node);
if (status != RET_OK) {
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
@ -75,9 +81,13 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std:
return RET_ERROR;
}
// after fusion bn must NHWC
auto input0 = bnNode->inputIndex.at(0);
if (graph->allTensors.at(input0)->dims.size() == kQuadrupleNum) {
scaleParam->axis = -1;
} else {
scaleParam->axis = 1;
}
bnNode->primitive->value.value = scaleParam.release();
auto input0 = bnNode->inputIndex.at(0);
bnNode->inputIndex.clear();
bnNode->inputIndex.push_back(input0);
graph->allTensors.emplace_back(std::move(newScaleWeightTensor));

Loading…
Cancel
Save