|
|
|
@ -24,13 +24,9 @@ class GatherOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
// PADDLE_ENFORCE(ctx.InputSize() == 2, "");
|
|
|
|
|
// PADDLE_ENFORCE(ctx.OutputSize() == 1, "");
|
|
|
|
|
// PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
|
|
|
|
|
// "Inputs of GatherOp must all be set");
|
|
|
|
|
int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
|
|
|
|
|
PADDLE_ENFORCE(batch_size > 0);
|
|
|
|
|
paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims());
|
|
|
|
|
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
|
|
|
|
|
paddle::framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
|
|
|
|
|
output_dims[0] = batch_size;
|
|
|
|
|
ctx.Output<Tensor>("Y")->Resize(output_dims);
|
|
|
|
|
}
|
|
|
|
|