|
|
|
@ -22,6 +22,8 @@
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
|
#include "tools/optimizer/common/gllo_utils.h"
|
|
|
|
|
#include "securec/include/securec.h"
|
|
|
|
|
#include "src/ops/batch_norm.h"
|
|
|
|
|
#include "src/ops/fused_batchnorm.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore::opt {
|
|
|
|
|
namespace {
|
|
|
|
@ -94,7 +96,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
|
|
|
|
|
auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
|
|
|
|
|
auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
|
|
|
|
|
auto bn_other_var = std::make_shared<SeqVar>();
|
|
|
|
|
return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});;
|
|
|
|
|
return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});
|
|
|
|
|
}
|
|
|
|
|
// BatchNorm weight Tensor definition:
|
|
|
|
|
// caffe
|
|
|
|
@ -106,7 +108,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
|
|
|
|
|
// estimated_mean --2
|
|
|
|
|
// estimated_variance --3
|
|
|
|
|
const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
|
|
|
|
|
float *trans_bias) const {
|
|
|
|
|
float *trans_bias) const {
|
|
|
|
|
MS_ASSERT(bn_node != nullptr);
|
|
|
|
|
AnfNodePtr bn_mean_node = nullptr;
|
|
|
|
|
AnfNodePtr bn_variance_node = nullptr;
|
|
|
|
@ -119,13 +121,19 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
|
|
|
|
|
bn_variance_node = bn_node->input(kCaffeBNVarIndex);
|
|
|
|
|
CheckIfNodeIsParam(bn_mean_node);
|
|
|
|
|
CheckIfNodeIsParam(bn_variance_node);
|
|
|
|
|
eps = primitiveT_value->GetPrimitiveT()->value.AsBatchNorm()->epsilon;
|
|
|
|
|
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value));
|
|
|
|
|
auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value);
|
|
|
|
|
MS_ASSERT(primc != nullptr);
|
|
|
|
|
eps = primc->GetEpsilon();
|
|
|
|
|
} else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) {
|
|
|
|
|
bn_scale_node = bn_node->input(kTFBNScaleIndex);
|
|
|
|
|
bn_bias_node = bn_node->input(kTFBNBiasIndex);
|
|
|
|
|
bn_mean_node = bn_node->input(kTFBNMeanIndex);
|
|
|
|
|
bn_variance_node = bn_node->input(kTFBNVarIndex);
|
|
|
|
|
eps = primitiveT_value->GetPrimitiveT()->value.AsFusedBatchNorm()->epsilon;
|
|
|
|
|
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value));
|
|
|
|
|
auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value);
|
|
|
|
|
MS_ASSERT(primc != nullptr);
|
|
|
|
|
eps = primc->GetEpsilon();
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op.";
|
|
|
|
|
}
|
|
|
|
|