|
|
|
@ -424,16 +424,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::SelectedRows cpu_grad_merge;
|
|
|
|
|
const framework::SelectedRows* grad_merge_ptr;
|
|
|
|
|
if (is_strict_sorted) {
|
|
|
|
|
grad_merge_ptr = &grad;
|
|
|
|
|
} else {
|
|
|
|
|
// merge duplicated rows if any.
|
|
|
|
|
// The rows of grad_merge have been sorted inside MergeAdd functor
|
|
|
|
|
framework::SelectedRows* grad_merge_var;
|
|
|
|
|
scatter::MergeAdd<DeviceContext, T> merge_func;
|
|
|
|
|
auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
|
|
|
|
|
.Var()
|
|
|
|
|
->GetMutable<framework::SelectedRows>();
|
|
|
|
|
if (platform::is_cpu_place(ctx.GetPlace())) {
|
|
|
|
|
grad_merge_var = &cpu_grad_merge;
|
|
|
|
|
} else {
|
|
|
|
|
// FIXME(qiao): the GPU path also needs to be fixed.
|
|
|
|
|
grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
|
|
|
|
|
.Var()
|
|
|
|
|
->GetMutable<framework::SelectedRows>();
|
|
|
|
|
}
|
|
|
|
|
merge_func(ctx.template device_context<DeviceContext>(), grad,
|
|
|
|
|
grad_merge_var, true);
|
|
|
|
|
grad_merge_ptr = grad_merge_var;
|
|
|
|
|