|
|
|
@ -96,8 +96,12 @@ class SGDOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t param_row_width = param.value().numel() / param.rows().size();
|
|
|
|
|
size_t grad_row_width = grad.value().numel() / grad.rows().size();
|
|
|
|
|
auto param_row_width = param.value().dims()[1];
|
|
|
|
|
auto grad_row_width = grad.value().dims()[1];
|
|
|
|
|
VLOG(4) << " param rows: " << param.rows().size()
|
|
|
|
|
<< " param memory rows: " << param.value().dims()[0]
|
|
|
|
|
<< " grad rows: " << grad.rows().size()
|
|
|
|
|
<< " grad memory rows: " << grad.value().dims()[0];
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
|
|
|
|
|
"param_row should have the same size with grad_row");
|
|
|
|
|
|
|
|
|
|