|
|
|
@ -69,9 +69,9 @@ STATUS BatchNormConvertScalePass::ConvertBNToScale(MetaGraphT *graph, const std:
|
|
|
|
|
MS_ASSERT(graph != nullptr);
|
|
|
|
|
MS_ASSERT(bnNode != nullptr);
|
|
|
|
|
bnNode->primitive->value.type = schema::PrimitiveType_Scale;
|
|
|
|
|
std::unique_ptr<ScaleT> scaleParam(new ScaleT());
|
|
|
|
|
std::unique_ptr<ScaleT> scaleParam(new (std::nothrow) ScaleT());
|
|
|
|
|
if (scaleParam == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new transposeParam failed";
|
|
|
|
|
MS_LOG(ERROR) << "new scaleParam failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
scaleParam->axis = NCHW_DIM_C;
|
|
|
|
@ -104,7 +104,7 @@ STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std
|
|
|
|
|
newScaleWeightTensor->data.resize(weightShapeSize * sizeof(float));
|
|
|
|
|
auto ret = memcpy_s(newScaleWeightTensor->data.data(), weightShapeSize * sizeof(float), transScale,
|
|
|
|
|
weightShapeSize * sizeof(float));
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy error: " << ret;
|
|
|
|
|
delete[] transScale;
|
|
|
|
|
delete[] transBias;
|
|
|
|
@ -127,7 +127,7 @@ STATUS BatchNormConvertScalePass::GenNewScaleTensor(MetaGraphT *graph, const std
|
|
|
|
|
newScaleBiasTensor->data.resize(weightShapeSize * sizeof(float));
|
|
|
|
|
ret = memcpy_s(newScaleBiasTensor->data.data(), weightShapeSize * sizeof(float), transBias,
|
|
|
|
|
weightShapeSize * sizeof(float));
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy error: " << ret;
|
|
|
|
|
delete[] transScale;
|
|
|
|
|
delete[] transBias;
|
|
|
|
@ -166,9 +166,17 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::un
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
this->transScale = new (std::nothrow) float[bnChannel];
|
|
|
|
|
if (this->transScale == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new transScale failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
this->transBias = new (std::nothrow) float[bnChannel];
|
|
|
|
|
if (this->transBias == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new transBias failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
// cal transScale, tf : scale/sqrt(variance + eps); caffe : 1/sqrt(variance + eps)
|
|
|
|
|
if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != 0) {
|
|
|
|
|
if (memcpy_s(transScale, bnChannel * sizeof(float), varianceData, bnChannel * sizeof(float)) != EOK) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s transScale error";
|
|
|
|
|
delete[] transScale;
|
|
|
|
|
delete[] transBias;
|
|
|
|
@ -180,6 +188,10 @@ STATUS BatchNormConvertScalePass::GetTransParam(MetaGraphT *graph, const std::un
|
|
|
|
|
for (uint32_t i = 0; i < bnChannel; i++) {
|
|
|
|
|
float tmp = transScale[i] + eps;
|
|
|
|
|
tmp = pow(tmp, POW_NUM);
|
|
|
|
|
if (tmp <= 0.0f) {
|
|
|
|
|
MS_LOG(ERROR) << "divisor 'tmp' cannot be 0";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
transScale[i] = 1 / tmp;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -278,6 +290,7 @@ STATUS BatchNormConvertScalePass::GetBnWeightTensors(MetaGraphT *graph, BNWeight
|
|
|
|
|
STATUS BatchNormConvertScalePass::GetBnEpsilon(const std::unique_ptr<CNodeT> &bnNode) {
|
|
|
|
|
MS_ASSERT(graph != nullptr);
|
|
|
|
|
MS_ASSERT(bnNode != nullptr);
|
|
|
|
|
MS_ASSERT(bnNode->primitive != nullptr);
|
|
|
|
|
if (bnNode->primitive->value.type == schema::PrimitiveType_FusedBatchNorm) {
|
|
|
|
|
eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon;
|
|
|
|
|
} else if (bnNode->primitive->value.type == schema::PrimitiveType_BatchNorm) {
|
|
|
|
|