|
|
@ -103,11 +103,24 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void EmbSeqPool(const T* table, const int64_t* idx, T* out,
|
|
|
|
void EmbSeqPool(const T* table, const int64_t* idx, T* out,
|
|
|
|
const emb_seq_pool_attr_t* attr) {
|
|
|
|
const emb_seq_pool_attr_t* attr) {
|
|
|
|
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
|
|
|
attr->table_width * attr->index_width, attr->out_width,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The attribute table_width * index_width of EmbSeqPool should "
|
|
|
|
|
|
|
|
"be equal to out_width. But table_width * index_width is %d, "
|
|
|
|
|
|
|
|
"out_width is %d.",
|
|
|
|
|
|
|
|
attr->table_width * attr->index_width, attr->out_width));
|
|
|
|
auto check_idx_value_valid = [&](int64_t i) {
|
|
|
|
auto check_idx_value_valid = [&](int64_t i) {
|
|
|
|
PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d",
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
idx[i], i);
|
|
|
|
idx[i], attr->table_height,
|
|
|
|
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The idx shoud be lower than the attribute table_height of "
|
|
|
|
|
|
|
|
"EmbSeqPool. But %dth of idx is %d and table_height is %d.",
|
|
|
|
|
|
|
|
i, idx[i], attr->table_height));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(idx[i], 0, platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The idx shoud be equal to or larger than "
|
|
|
|
|
|
|
|
"the 0. But %dth of idx is %d.",
|
|
|
|
|
|
|
|
i, idx[i]));
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
for (int64_t w = 0; w != attr->index_width; ++w) {
|
|
|
|
for (int64_t w = 0; w != attr->index_width; ++w) {
|
|
|
@ -168,22 +181,50 @@ void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
|
|
|
|
void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
|
|
|
|
T* out, const sgd_attr_t* attr) {
|
|
|
|
T* out, const sgd_attr_t* attr) {
|
|
|
|
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width);
|
|
|
|
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width,
|
|
|
|
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height);
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The attribute param_width of Sgd should be "
|
|
|
|
|
|
|
|
"equal to the attribute grad_width. But param_width "
|
|
|
|
|
|
|
|
"is %d and grad_width is %d.",
|
|
|
|
|
|
|
|
attr->param_width, attr->grad_width));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The attribute selected_rows_size of Sgd should be "
|
|
|
|
|
|
|
|
"equal to or less than the attribute grad_height. "
|
|
|
|
|
|
|
|
"But selected_rows_size is %d and grad_height is %d.",
|
|
|
|
|
|
|
|
attr->selected_rows_size, attr->grad_height));
|
|
|
|
T scalar = -lr[0];
|
|
|
|
T scalar = -lr[0];
|
|
|
|
int width = attr->grad_width;
|
|
|
|
int width = attr->grad_width;
|
|
|
|
if (out == param) {
|
|
|
|
if (out == param) {
|
|
|
|
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
|
|
|
|
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
|
|
|
|
auto h_idx = rows[i];
|
|
|
|
auto h_idx = rows[i];
|
|
|
|
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
|
|
|
|
PADDLE_ENFORCE_LT(h_idx, attr->param_height,
|
|
|
|
PADDLE_ENFORCE_GE(h_idx, 0);
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The rows of Sgd should be "
|
|
|
|
|
|
|
|
"less than the attribute. But %dth of rows "
|
|
|
|
|
|
|
|
"is %d and grad_width is %d.",
|
|
|
|
|
|
|
|
i, h_idx, attr->param_height));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(h_idx, 0, platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The rows of Sgd should be "
|
|
|
|
|
|
|
|
"larger than 0. But %dth of rows "
|
|
|
|
|
|
|
|
"is %d.",
|
|
|
|
|
|
|
|
i, h_idx));
|
|
|
|
VAXPY(scalar, grad + i * width, out + h_idx * width, width);
|
|
|
|
VAXPY(scalar, grad + i * width, out + h_idx * width, width);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
|
|
|
|
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
|
|
|
|
auto h_idx = rows[i];
|
|
|
|
auto h_idx = rows[i];
|
|
|
|
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
|
|
|
|
PADDLE_ENFORCE_LT(h_idx, attr->param_height,
|
|
|
|
PADDLE_ENFORCE_GE(h_idx, 0);
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The rows of Sgd should be "
|
|
|
|
|
|
|
|
"less than the attribute. But %dth of rows "
|
|
|
|
|
|
|
|
"is %d and grad_width is %d.",
|
|
|
|
|
|
|
|
i, h_idx, attr->param_height));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(h_idx, 0, platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The rows of Sgd should be "
|
|
|
|
|
|
|
|
"larger than 0. But %dth of rows "
|
|
|
|
|
|
|
|
"is %d.",
|
|
|
|
|
|
|
|
i, h_idx));
|
|
|
|
VScal(&scalar, grad + i * width, out + h_idx * width, width);
|
|
|
|
VScal(&scalar, grad + i * width, out + h_idx * width, width);
|
|
|
|
VAdd(param + h_idx * width, out + h_idx * width, out + h_idx * width,
|
|
|
|
VAdd(param + h_idx * width, out + h_idx * width, out + h_idx * width,
|
|
|
|
width);
|
|
|
|
width);
|
|
|
|