|
|
|
@ -17,7 +17,7 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
template <typename T>
|
|
|
|
|
class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
class SyncBatchNormGradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
@ -55,6 +55,6 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(sync_batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
|
|
|
|
|
ops::BatchNormOpInferVarType,
|
|
|
|
|
ops::BatchNormGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::BatchNormGradMaker<paddle::imperative::OpBase>);
|
|
|
|
|
ops::SyncBatchNormGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::SyncBatchNormGradMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(sync_batch_norm_grad, ops::BatchNormGradOp);
|
|
|
|
|