|
|
|
@ -26,10 +26,10 @@ using Tensor = framework::Tensor;
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class GatherOpKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto X = ctx.Input<Tensor>("X");
|
|
|
|
|
auto Index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto Y = ctx.Output<Tensor>("Y");
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto *X = ctx.Input<Tensor>("X");
|
|
|
|
|
auto *Index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto *Y = ctx.Output<Tensor>("Out");
|
|
|
|
|
|
|
|
|
|
Y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
Gather<T>(ctx.GetPlace(), X, Index, Y);
|
|
|
|
@ -39,12 +39,13 @@ class GatherOpKernel : public framework::OpKernel {
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class GatherGradientOpKernel : public framework::OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto Index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto *Index = ctx.Input<Tensor>("Index");
|
|
|
|
|
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
ScatterUpdate<T>(ctx.GetPlace(), dY, Index, dX);
|
|
|
|
|
dX->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
ScatterUpdate<T>(ctx.GetPlace(), dO, Index, dX);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|