@@ -62,6 +62,58 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
 public:
  using paddle::operators::BatchNormGradOp::BatchNormGradOp;
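
  // InferShape validates the gradient op's inputs/outputs and propagates
  // shapes. Note that everything is phrased in terms of Y rather than X:
  // the in-place activated batch norm overwrites X in the forward pass,
  // so the backward pass must recover what it needs from Y.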
  void InferShape(framework::InferShapeContext* ctx) const override {
    // check input
    OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "InplaceABNGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
                   "Y@GRAD", "InplaceABNGrad");
    OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
                   "InplaceABNGrad");
    OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance",
                   "InplaceABNGrad");

    // check output
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                   "X@GRAD", "InplaceABNGrad");
    const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
    const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));

    PADDLE_ENFORCE_EQ(
        has_scale_grad, has_bias_grad,
        platform::errors::InvalidArgument(
            "Output(Scale@GRAD) and Output(Bias@GRAD) must be null "
            "or non-null at the same time. But received: "
            "has Scale@GRAD=[%d], has Bias@GRAD=[%d]",
            has_scale_grad, has_bias_grad));
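
    // With use_global_stats the op normalizes with the running
    // mean/variance instead of batch statistics; the MKL-DNN batch_norm
    // gradient kernel does not support that mode, so reject it up front.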
    const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
    if (use_global_stats) {
      PADDLE_ENFORCE_EQ(
          !ctx->Attrs().Get<bool>("use_mkldnn"), true,
          platform::errors::InvalidArgument(
              "Using global stats during training is not supported "
              "in gradient op kernel of batch_norm_mkldnn_op now."));
    }

    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "InplaceABNGrad");
    const auto y_dims = ctx->GetInputDim("Y");
    const DataLayout data_layout = framework::StringToDataLayout(
        ctx->Attrs().Get<std::string>("data_layout"));
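
    // The channel dimension depends on layout: index 1 for NCHW (and for
    // MKL-DNN tensors, which keep their dims in NCHW order), otherwise
    // the last dimension (NHWC).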
    const int C =
        ((this->IsMKLDNNType() || data_layout == DataLayout::kNCHW)
             ? y_dims[1]
             : y_dims[y_dims.size() - 1]);

    ctx->SetOutputDim(framework::GradVarName("X"), y_dims);
    // has_scale_grad == has_bias_grad, so checking has_scale_grad is enough
    if (has_scale_grad) {
      ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
      ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {