|
|
|
@ -90,23 +90,29 @@ STATUS BatchNormConvertScalePass::DoFusion(MetaGraphT *graph, const std::string
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
auto bnPath = matchedPath.at(bnOpName);
|
|
|
|
|
status = GetTransParam(graph, bnPath);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "GetTransParam failed: " << status;
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
status = GenNewScaleTensor(graph, bnPath);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
|
|
|
|
|
delete[] transScale;
|
|
|
|
|
delete[] transBias;
|
|
|
|
|
transScale = nullptr;
|
|
|
|
|
transBias = nullptr;
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
status = ConvertBNToScale(graph, bnPath);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status;
|
|
|
|
|
delete[] transScale;
|
|
|
|
|
delete[] transBias;
|
|
|
|
|
transScale = nullptr;
|
|
|
|
|
transBias = nullptr;
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
delete[] transScale;
|
|
|
|
|
delete[] transBias;
|
|
|
|
|
transScale = nullptr;
|
|
|
|
|
transBias = nullptr;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std::shared_ptr<Path> &bnPath) {
|
|
|
|
@ -245,6 +251,10 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::sh
|
|
|
|
|
// cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps)
|
|
|
|
|
if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != 0) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s transScale error";
|
|
|
|
|
delete[] transScale;
|
|
|
|
|
delete[] transBias;
|
|
|
|
|
transScale = nullptr;
|
|
|
|
|
transBias = nullptr;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
// 1/sqrt(variance + eps)
|
|
|
|
@ -370,14 +380,5 @@ STATUS BatchNormConvertScalePass::GetBnEpsilon(MetaGraphT *graph) {
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BatchNormConvertScalePass::~BatchNormConvertScalePass() {
|
|
|
|
|
if (this->transScale != nullptr) {
|
|
|
|
|
delete (this->transScale);
|
|
|
|
|
}
|
|
|
|
|
if (this->transBias != nullptr) {
|
|
|
|
|
delete (this->transBias);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace lite
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|