|
|
|
@ -51,8 +51,10 @@ Y = X[Index]
|
|
|
|
|
class GatherGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
|
|
|
|
|
->Resize(ctx.Input<Tensor>("X")->dims());
|
|
|
|
|
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto X = ctx.Input<Tensor>("X");
|
|
|
|
|
|
|
|
|
|
X_grad->Resize(X->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|