|
|
|
@ -19,17 +19,33 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
class GatherOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx.InputSize() == 2, "");
|
|
|
|
|
PADDLE_ENFORCE(ctx.OutputSize() == 1, "");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
|
|
|
|
|
"Inputs of GatherOp must all be set");
|
|
|
|
|
int batch_size = ctx.Input<Tensor>(1)->dims()[0];
|
|
|
|
|
// PADDLE_ENFORCE(ctx.InputSize() == 2, "");
|
|
|
|
|
// PADDLE_ENFORCE(ctx.OutputSize() == 1, "");
|
|
|
|
|
// PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
|
|
|
|
|
// "Inputs of GatherOp must all be set");
|
|
|
|
|
int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
|
|
|
|
|
PADDLE_ENFORCE(batch_size > 0);
|
|
|
|
|
paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims());
|
|
|
|
|
output_dims[0] = batch_size;
|
|
|
|
|
ctx.Output<Tensor>(0)->Resize(output_dims);
|
|
|
|
|
ctx.Output<Tensor>("Y")->Resize(output_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class GatherGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto X = ctx.Input<Tensor>("X");
|
|
|
|
|
|
|
|
|
|
X_grad->Resize(X->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -47,25 +63,14 @@ Y = X[Index]
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class GatherGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto X = ctx.Input<Tensor>("X");
|
|
|
|
|
|
|
|
|
|
X_grad->Resize(X->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker);
|
|
|
|
|
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
|
|
|
|
|
ops::GatherGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(gather,
|
|
|
|
|
ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
REGISTER_GRADIENT_OP(gather, gather_grad, ops::GatherGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
gather_grad,
|
|
|
|
|
ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|