|
|
@ -296,6 +296,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
auto input_height = has_value_input->height();
|
|
|
|
auto input_height = has_value_input->height();
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
std::set<int64_t> merged_row_set;
|
|
|
|
std::set<int64_t> merged_row_set;
|
|
|
|
|
|
|
|
size_t row_num = 0;
|
|
|
|
for (auto* input : inputs) {
|
|
|
|
for (auto* input : inputs) {
|
|
|
|
if (input->rows().size() == 0) {
|
|
|
|
if (input->rows().size() == 0) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
@ -305,42 +306,71 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
"dimension except for the first one");
|
|
|
|
"dimension except for the first one");
|
|
|
|
PADDLE_ENFORCE_EQ(input_height, input->height(),
|
|
|
|
PADDLE_ENFORCE_EQ(input_height, input->height(),
|
|
|
|
"all input should have same height");
|
|
|
|
"all input should have same height");
|
|
|
|
|
|
|
|
row_num += input->rows().size();
|
|
|
|
merged_row_set.insert(input->rows().begin(), input->rows().end());
|
|
|
|
merged_row_set.insert(input->rows().begin(), input->rows().end());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
std::vector<int64_t> merge_rows(merged_row_set.begin(),
|
|
|
|
|
|
|
|
merged_row_set.end());
|
|
|
|
|
|
|
|
if (sorted_result) {
|
|
|
|
|
|
|
|
std::sort(merge_rows.begin(), merge_rows.end());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
std::unordered_map<int64_t, size_t> rows_to_id;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < merge_rows.size(); ++i) {
|
|
|
|
|
|
|
|
rows_to_id[merge_rows[i]] = i;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
|
|
|
out.set_height(input_height);
|
|
|
|
out.set_height(input_height);
|
|
|
|
out.mutable_value()->mutable_data<T>(
|
|
|
|
out.mutable_value()->mutable_data<T>(
|
|
|
|
framework::make_ddim(
|
|
|
|
framework::make_ddim(
|
|
|
|
{static_cast<int64_t>(merge_rows.size()), input_width}),
|
|
|
|
{static_cast<int64_t>(merged_row_set.size()), input_width}),
|
|
|
|
context.GetPlace());
|
|
|
|
context.GetPlace());
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->data<T>();
|
|
|
|
|
|
|
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
|
|
|
|
if (merged_row_set.size() == row_num && !sorted_result) {
|
|
|
|
constant_functor(context, out.mutable_value(), 0.0);
|
|
|
|
// no duplicated ids, just concat the result together
|
|
|
|
|
|
|
|
std::vector<int64_t> merge_rows;
|
|
|
|
|
|
|
|
merge_rows.reserve(row_num);
|
|
|
|
|
|
|
|
// concat rows
|
|
|
|
|
|
|
|
for (auto* in : inputs) {
|
|
|
|
|
|
|
|
merge_rows.insert(merge_rows.end(), in->rows().begin(),
|
|
|
|
|
|
|
|
in->rows().end());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
|
|
|
auto in_place = inputs[0]->place();
|
|
|
|
|
|
|
|
auto out_place = out.place();
|
|
|
|
|
|
|
|
int64_t copied_numel = 0;
|
|
|
|
|
|
|
|
for (auto* in : inputs) {
|
|
|
|
|
|
|
|
auto* in_data = in->value().data<T>();
|
|
|
|
|
|
|
|
auto in_numel = in->value().numel();
|
|
|
|
|
|
|
|
memory::Copy(boost::get<platform::CPUPlace>(out_place),
|
|
|
|
|
|
|
|
out_data + copied_numel,
|
|
|
|
|
|
|
|
boost::get<platform::CPUPlace>(in_place), in_data,
|
|
|
|
|
|
|
|
in_numel * sizeof(T));
|
|
|
|
|
|
|
|
copied_numel += in_numel;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
std::vector<int64_t> merge_rows(merged_row_set.begin(),
|
|
|
|
|
|
|
|
merged_row_set.end());
|
|
|
|
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->data<T>();
|
|
|
|
if (sorted_result) {
|
|
|
|
|
|
|
|
std::sort(merge_rows.begin(), merge_rows.end());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
for (auto* input : inputs) {
|
|
|
|
|
|
|
|
if (input->rows().size() == 0) {
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
|
|
|
|
continue;
|
|
|
|
constant_functor(context, out.mutable_value(), 0.0);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::unordered_map<int64_t, size_t> rows_to_id;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < merge_rows.size(); ++i) {
|
|
|
|
|
|
|
|
rows_to_id[merge_rows[i]] = i;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto* input_data = input->value().data<T>();
|
|
|
|
|
|
|
|
auto& input_rows = input->rows();
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
|
|
|
for (auto* input : inputs) {
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
if (input->rows().size() == 0) {
|
|
|
|
size_t out_i = rows_to_id[input_rows[i]];
|
|
|
|
continue;
|
|
|
|
elementwise_add_to<platform::CPUDeviceContext, T>(
|
|
|
|
}
|
|
|
|
context, &blas, static_cast<size_t>(input_width),
|
|
|
|
auto* input_data = input->value().data<T>();
|
|
|
|
&input_data[i * input_width], &out_data[out_i * input_width]);
|
|
|
|
auto& input_rows = input->rows();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
|
|
|
|
size_t out_i = rows_to_id[input_rows[i]];
|
|
|
|
|
|
|
|
elementwise_add_to<platform::CPUDeviceContext, T>(
|
|
|
|
|
|
|
|
context, &blas, static_cast<size_t>(input_width),
|
|
|
|
|
|
|
|
&input_data[i * input_width], &out_data[out_i * input_width]);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|