Merge pull request #15111 from jacquesqiao/fix-adam-tmp-var

Fix adam tmp var on cpu
revert-15207-remove_op_handle_lock_and_fix_var
乔龙飞 Qiao Longfei 6 years ago committed by GitHub
commit 7c891e1ecc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -424,16 +424,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
}
}
framework::SelectedRows cpu_grad_merge;
const framework::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = &grad;
} else {
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
framework::SelectedRows* grad_merge_var;
scatter::MergeAdd<DeviceContext, T> merge_func;
auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
if (platform::is_cpu_place(ctx.GetPlace())) {
grad_merge_var = &cpu_grad_merge;
} else {
// FIXME(qiao): GPU also need to fix this
grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
}
merge_func(ctx.template device_context<DeviceContext>(), grad,
grad_merge_var, true);
grad_merge_ptr = grad_merge_var;

Loading…
Cancel
Save