|
|
|
@ -457,12 +457,39 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
auto *op = new framework::OpDesc();
|
|
|
|
|
op->SetType("batch_norm_grad");
|
|
|
|
|
op->SetInput("X", Input("X"));
|
|
|
|
|
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
|
|
|
|
|
|
|
|
|
|
op->SetInput("Scale", Input("Scale"));
|
|
|
|
|
op->SetInput("SavedMean", Output("SavedMean"));
|
|
|
|
|
op->SetInput("SavedVariance", Output("SavedVariance"));
|
|
|
|
|
|
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
|
|
|
|
|
|
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
|
|
|
|
|
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(op);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
|
|
|
|
|
batch_norm_grad, ops::BatchNormGradOp);
|
|
|
|
|
REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
|
|
|
|
|
ops::BatchNormGradMaker);
|
|
|
|
|
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
batch_norm,
|
|
|
|
|
ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|