|
|
|
@ -27,6 +27,9 @@ class GatherOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Inputs of GatherOp must all be set");
|
|
|
|
|
int batch_size = ctx.Input<Tensor>(1)->dims()[0];
|
|
|
|
|
PADDLE_ENFORCE(batch_size > 0);
|
|
|
|
|
paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims());
|
|
|
|
|
output_dims[0] = batch_size;
|
|
|
|
|
ctx.Output<Tensor>(0)->Resize(output_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -48,8 +51,8 @@ Y = X[Index]
|
|
|
|
|
class GatherGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
// ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
|
|
|
|
|
// ->Resize(ctx.Input<Tensor>("X")->dims());
|
|
|
|
|
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
|
|
|
|
|
->Resize(ctx.Input<Tensor>("X")->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|