gather_op modified

revert-3824-remove_grad_op_type
zchen0211 8 years ago
parent 2a42a73db1
commit f6bffd4e1f

@ -51,8 +51,10 @@ Y = X[Index]
class GatherGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
->Resize(ctx.Input<Tensor>("X")->dims());
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
X_grad->Resize(X->dims());
}
};

Loading…
Cancel
Save