|
|
|
@ -282,6 +282,10 @@ class AdamOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
} else if (grad_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto& grad =
|
|
|
|
|
Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
|
|
|
|
|
if (grad.rows().size() == 0) {
|
|
|
|
|
VLOG(3) << "grad row size is 0!!";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
// merge duplicated rows if any.
|
|
|
|
|
scatter::MergeAdd<DeviceContext, T> merge_func;
|
|
|
|
|
auto grad_merge =
|
|
|
|
|