|
|
@ -589,8 +589,10 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
op->SetInput("SavedVariance", Output("SavedVariance"));
|
|
|
|
op->SetInput("SavedVariance", Output("SavedVariance"));
|
|
|
|
|
|
|
|
|
|
|
|
// used when setting use_global_stats True during training
|
|
|
|
// used when setting use_global_stats True during training
|
|
|
|
op->SetInput("Mean", Output("MeanOut"));
|
|
|
|
if (boost::get<bool>(GetAttr("use_global_stats"))) {
|
|
|
|
op->SetInput("Variance", Output("VarianceOut"));
|
|
|
|
op->SetInput("Mean", Output("MeanOut"));
|
|
|
|
|
|
|
|
op->SetInput("Variance", Output("VarianceOut"));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
|
|
|
|
|
|
|
|