|
|
@ -35,7 +35,7 @@ class ScatterOpKernel : public framework::OpKernel<T> {
|
|
|
|
auto *Out = ctx.Output<Tensor>("Out");
|
|
|
|
auto *Out = ctx.Output<Tensor>("Out");
|
|
|
|
|
|
|
|
|
|
|
|
// In place output: Out = X, Out[Ids] += Updates
|
|
|
|
// In place output: Out = X, Out[Ids] += Updates
|
|
|
|
Out->ShareDataWith(*X);
|
|
|
|
framework::TensorCopySync(*X, ctx.GetPlace(), Out);
|
|
|
|
// Apply ScatterUpdate: Out[index] += Updates[:]
|
|
|
|
// Apply ScatterUpdate: Out[index] += Updates[:]
|
|
|
|
ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
|
|
|
|
ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -53,7 +53,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
|
|
|
|
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
|
|
// In place gradient: dX = dO
|
|
|
|
// In place gradient: dX = dO
|
|
|
|
dX->ShareDataWith(*dOut);
|
|
|
|
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
|
|
|
|
dUpdates->mutable_data<T>(ctx.GetPlace());
|
|
|
|
dUpdates->mutable_data<T>(ctx.GetPlace());
|
|
|
|
// Gradient by Gather: dUpdates += dO[Ids]
|
|
|
|
// Gradient by Gather: dUpdates += dO[Ids]
|
|
|
|
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
|
|
|
|
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
|
|
|
|