|
|
|
@ -94,8 +94,20 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
|
|
|
|
|
ctx->ShareLoD("X", framework::GradVarName("X"));
|
|
|
|
|
auto in_x = "X";
|
|
|
|
|
auto out_x_g_n = framework::GradVarName(in_x);
|
|
|
|
|
ctx->SetOutputsDim(out_x_g_n, ctx->GetInputsDim(in_x));
|
|
|
|
|
auto &in_names = ctx->Inputs(in_x);
|
|
|
|
|
auto &out_names = ctx->Outputs(out_x_g_n);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_names.size(), out_names.size(),
|
|
|
|
|
"The number of arguments in %s[%d] and %s[%d] is not equal.", in_x,
|
|
|
|
|
in_names.size(), out_x_g_n, out_names.size());
|
|
|
|
|
for (size_t i = 0; i < in_names.size(); ++i) {
|
|
|
|
|
if (out_names[i] != framework::kEmptyVarName) {
|
|
|
|
|
ctx->ShareLoD(in_x, out_x_g_n, i, i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|