|
|
|
@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// In place output: Out = Ref, Out[Index] += Updates
|
|
|
|
|
Out->ShareDataWith<T>(*Ref);
|
|
|
|
|
// Apply ScatterUpdate: Out[index] += Updates[:]
|
|
|
|
|
ScatterAssign<T>(ctx.GetPlace(), Updates, Index, Out);
|
|
|
|
|
ScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
dRef->ShareDataWith<T>(*dOut);
|
|
|
|
|
dUpdates->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
// Gradient by Gather: dUpdates += dO[Index]
|
|
|
|
|
CPUGather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
|
|
|
|
|
CPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|