modify gather_op with test

revert-3824-remove_grad_op_type
zchen0211 8 years ago
parent caaa5f86b9
commit 2a42a73db1

@ -27,6 +27,9 @@ class GatherOp : public framework::OperatorWithKernel {
"Inputs of GatherOp must all be set");
int batch_size = ctx.Input<Tensor>(1)->dims()[0];
PADDLE_ENFORCE(batch_size > 0);
paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims());
output_dims[0] = batch_size;
ctx.Output<Tensor>(0)->Resize(output_dims);
}
};
@ -48,8 +51,8 @@ Y = X[Index]
class GatherGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
// ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
// ->Resize(ctx.Input<Tensor>("X")->dims());
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
->Resize(ctx.Input<Tensor>("X")->dims());
}
};

Loading…
Cancel
Save