modify gather_op with test

revert-3824-remove_grad_op_type
zchen0211 8 years ago
parent caaa5f86b9
commit 2a42a73db1

@ -27,6 +27,9 @@ class GatherOp : public framework::OperatorWithKernel {
"Inputs of GatherOp must all be set"); "Inputs of GatherOp must all be set");
int batch_size = ctx.Input<Tensor>(1)->dims()[0]; int batch_size = ctx.Input<Tensor>(1)->dims()[0];
PADDLE_ENFORCE(batch_size > 0); PADDLE_ENFORCE(batch_size > 0);
paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims());
output_dims[0] = batch_size;
ctx.Output<Tensor>(0)->Resize(output_dims);
} }
}; };
@ -48,8 +51,8 @@ Y = X[Index]
class GatherGradOp : public framework::OperatorWithKernel { class GatherGradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
// ctx.Output<Tensor>("X" + framework::kGradVarSuffix) ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
// ->Resize(ctx.Input<Tensor>("X")->dims()); ->Resize(ctx.Input<Tensor>("X")->dims());
} }
}; };

Loading…
Cancel
Save