|
|
|
@ -65,7 +65,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto &grad_rows = grad->rows();
|
|
|
|
|
|
|
|
|
|
size_t grad_row_numel = grad_value.numel() / grad_rows.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(grad_row_numel, param_out->numel() / grad_height);
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<int64_t>(grad_row_numel),
|
|
|
|
|
param_out->numel() / grad_height);
|
|
|
|
|
|
|
|
|
|
auto *grad_data = grad_value.data<T>();
|
|
|
|
|
auto *out_data = param_out->data<T>();
|
|
|
|
@ -73,7 +74,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (size_t i = 0; i < grad_rows.size(); i++) {
|
|
|
|
|
PADDLE_ENFORCE(grad_rows[i] < grad_height,
|
|
|
|
|
"Input rows index should less than height");
|
|
|
|
|
for (int64_t j = 0; j < grad_row_numel; j++) {
|
|
|
|
|
for (size_t j = 0; j < grad_row_numel; j++) {
|
|
|
|
|
out_data[grad_rows[i] * grad_row_numel + j] -=
|
|
|
|
|
lr[0] * grad_data[i * grad_row_numel + j];
|
|
|
|
|
}
|
|
|
|
@ -107,7 +108,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
|
|
|
|
|
"Input rows index should less than height");
|
|
|
|
|
int64_t id_index = param.index(grad.rows()[i]);
|
|
|
|
|
for (int64_t j = 0; j < grad_row_width; j++) {
|
|
|
|
|
for (size_t j = 0; j < grad_row_width; j++) {
|
|
|
|
|
out_data[id_index * grad_row_width + j] -=
|
|
|
|
|
lr[0] * grad_data[i * grad_row_width + j];
|
|
|
|
|
}
|
|
|
|
|