|
|
|
@ -598,36 +598,13 @@ std::unique_ptr<framework::OpDesc> BatchNormGradMaker::Apply() const {
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class BatchNormInplaceInToOut : public framework::InplaceOpInference {
|
|
|
|
|
public:
|
|
|
|
|
std::unordered_map<std::string, std::string> operator()(
|
|
|
|
|
const framework::OpDesc &op_desc, bool use_cuda) const override {
|
|
|
|
|
return {{"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"}};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class BatchNormGradInplaceInToOut : public framework::InplaceOpInference {
|
|
|
|
|
public:
|
|
|
|
|
std::unordered_map<std::string, std::string> operator()(
|
|
|
|
|
const framework::OpDesc &op_desc, bool use_cuda) const override {
|
|
|
|
|
// Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C]
|
|
|
|
|
return {
|
|
|
|
|
{framework::GradVarName("Y"), framework::GradVarName("X")},
|
|
|
|
|
{"SavedMean", framework::GradVarName("Scale")},
|
|
|
|
|
{"SavedVariance", framework::GradVarName("Bias")},
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
|
|
|
|
|
ops::BatchNormOpInferVarType, ops::BatchNormGradMaker)
|
|
|
|
|
// ops::BatchNormInplaceInToOut);
|
|
|
|
|
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp)
|
|
|
|
|
// ops::BatchNormGradInplaceInToOut);
|
|
|
|
|
ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
|
|
|
|
|
REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|