fix sgd for SelectedRows bug

release/0.13.0
qiaolongfei 7 years ago
parent c797adede7
commit 5825196db9

@ -96,8 +96,12 @@ class SGDOpKernel : public framework::OpKernel<T> {
return;
}
size_t param_row_width = param.value().numel() / param.rows().size();
size_t grad_row_width = grad.value().numel() / grad.rows().size();
auto param_row_width = param.value().dims()[1];
auto grad_row_width = grad.value().dims()[1];
VLOG(4) << " param rows: " << param.rows().size()
<< " param memory rows: " << param.value().dims()[0]
<< " grad rows: " << grad.rows().size()
<< " grad memory rows: " << grad.value().dims()[0];
PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
"param_row should have the same size with grad_row");

Loading…
Cancel
Save