|
|
|
@ -27,9 +27,9 @@ template <typename Place, typename T>
|
|
|
|
|
class GatherOpKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto X = ctx.Input<Tensor>("X");
|
|
|
|
|
auto Index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto Y = ctx.Output<Tensor>("Y");
|
|
|
|
|
auto *X = ctx.Input<Tensor>("X");
|
|
|
|
|
auto *Index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto *Y = ctx.Output<Tensor>("Out");
|
|
|
|
|
|
|
|
|
|
Y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
Gather<T>(ctx.GetPlace(), X, Index, Y);
|
|
|
|
@ -40,11 +40,12 @@ template <typename Place, typename T>
|
|
|
|
|
class GatherGradientOpKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto Index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
auto *Index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
ScatterUpdate<T>(ctx.GetPlace(), dY, Index, dX);
|
|
|
|
|
dX->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
ScatterUpdate<T>(ctx.GetPlace(), dO, Index, dX);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|