|
|
|
@ -47,12 +47,15 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
|
|
|
|
|
auto *Ids = ctx.Input<Tensor>("Ids");
|
|
|
|
|
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
// In place gradient: dX = dO
|
|
|
|
|
dX->ShareDataWith(*dOut);
|
|
|
|
|
dUpdates->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
// Gradient by Gather: dUpdates = dO[Ids]
|
|
|
|
|
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
|
|
|
|
|
if (dX) {
|
|
|
|
|
// Copy gradient (not in place): dX = dO
|
|
|
|
|
framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
|
|
|
|
|
}
|
|
|
|
|
if (dUpdates) {
|
|
|
|
|
dUpdates->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
// Gradient by Gather: dUpdates = dO[Ids]
|
|
|
|
|
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|