Merge pull request #9299 from reyoung/feature/refactor_batch_norm

Shrink batch_norm_grad's inputs
helinwang-patch-1
Yu Yang 7 years ago committed by GitHub
commit 9e3e424ecb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -457,12 +457,39 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
}
};
class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *op = new framework::OpDesc();
op->SetType("batch_norm_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetInput("Scale", Input("Scale"));
op->SetInput("SavedMean", Output("SavedMean"));
op->SetInput("SavedVariance", Output("SavedVariance"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
batch_norm_grad, ops::BatchNormGradOp);
REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
ops::BatchNormGradMaker);
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
REGISTER_OP_CPU_KERNEL(
batch_norm,
ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>);

Loading…
Cancel
Save