|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
@ -97,41 +98,39 @@ struct MergeAdd<platform::CPUDeviceContext, float> {
|
|
|
|
|
const framework::SelectedRows& input,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
|
auto input_rows = input.rows();
|
|
|
|
|
std::vector<int64_t> merge_rows;
|
|
|
|
|
merge_rows.reserve(input_rows.size());
|
|
|
|
|
std::unordered_map<int64_t, size_t> rows_pos_map;
|
|
|
|
|
rows_pos_map.reserve(input_rows.size());
|
|
|
|
|
size_t idx = 0u;
|
|
|
|
|
for (std::vector<int64_t>::iterator iter = input_rows.begin();
|
|
|
|
|
iter != input_rows.end(); ++iter) {
|
|
|
|
|
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
|
|
|
|
|
rows_pos_map[*iter] = idx++;
|
|
|
|
|
merge_rows.emplace_back(*iter);
|
|
|
|
|
}
|
|
|
|
|
std::vector<int64_t> input_rows(input.rows());
|
|
|
|
|
|
|
|
|
|
std::map<int64_t, std::vector<int64_t>> merge_row_map;
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); ++i) {
|
|
|
|
|
merge_row_map[input_rows[i]].push_back(i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input_width = input.value().dims()[1];
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
std::vector<int64_t> merge_rows(merge_row_map.size());
|
|
|
|
|
size_t idx = 0;
|
|
|
|
|
int64_t input_width = input.value().dims()[1];
|
|
|
|
|
out.set_height(input.height());
|
|
|
|
|
out.mutable_value()->mutable_data<float>(
|
|
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->mutable_data<float>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(merge_rows.size()), input_width}),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, float> constant_functor;
|
|
|
|
|
constant_functor(context, out.mutable_value(), 0.0);
|
|
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->data<float>();
|
|
|
|
|
auto* input_data = input.value().data<float>();
|
|
|
|
|
auto* in_data = input.value().data<float>();
|
|
|
|
|
|
|
|
|
|
auto blas = GetBlas<platform::CPUDeviceContext, float>(context);
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
|
size_t out_i = rows_pos_map[input_rows[i]];
|
|
|
|
|
float* y = out_data + out_i * input_width;
|
|
|
|
|
const float* x = input_data + i * input_width;
|
|
|
|
|
blas.AXPY(input_width, 1., x, y);
|
|
|
|
|
for (auto& row_pair : merge_row_map) {
|
|
|
|
|
auto* out_ptr = out_data + idx * input_width;
|
|
|
|
|
auto& rows = row_pair.second;
|
|
|
|
|
merge_rows[idx] = row_pair.first;
|
|
|
|
|
++idx;
|
|
|
|
|
// rows.size() is always larger than 0
|
|
|
|
|
blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 1; i < rows.size(); ++i) {
|
|
|
|
|
blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -148,41 +147,39 @@ struct MergeAdd<platform::CPUDeviceContext, double> {
|
|
|
|
|
const framework::SelectedRows& input,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
|
auto input_rows = input.rows();
|
|
|
|
|
std::vector<int64_t> merge_rows;
|
|
|
|
|
merge_rows.reserve(input_rows.size());
|
|
|
|
|
std::unordered_map<int64_t, size_t> rows_pos_map;
|
|
|
|
|
rows_pos_map.reserve(input_rows.size());
|
|
|
|
|
size_t idx = 0u;
|
|
|
|
|
for (std::vector<int64_t>::iterator iter = input_rows.begin();
|
|
|
|
|
iter != input_rows.end(); ++iter) {
|
|
|
|
|
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
|
|
|
|
|
rows_pos_map[*iter] = idx++;
|
|
|
|
|
merge_rows.emplace_back(*iter);
|
|
|
|
|
}
|
|
|
|
|
std::vector<int64_t> input_rows(input.rows());
|
|
|
|
|
|
|
|
|
|
std::map<int64_t, std::vector<int64_t>> merge_row_map;
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); ++i) {
|
|
|
|
|
merge_row_map[input_rows[i]].push_back(i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input_width = input.value().dims()[1];
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
std::vector<int64_t> merge_rows(merge_row_map.size());
|
|
|
|
|
size_t idx = 0;
|
|
|
|
|
int64_t input_width = input.value().dims()[1];
|
|
|
|
|
out.set_height(input.height());
|
|
|
|
|
out.mutable_value()->mutable_data<double>(
|
|
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->mutable_data<double>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(merge_rows.size()), input_width}),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, double> constant_functor;
|
|
|
|
|
constant_functor(context, out.mutable_value(), 0.0);
|
|
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->data<double>();
|
|
|
|
|
auto* input_data = input.value().data<double>();
|
|
|
|
|
auto* in_data = input.value().data<double>();
|
|
|
|
|
|
|
|
|
|
auto blas = GetBlas<platform::CPUDeviceContext, double>(context);
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
|
size_t out_i = rows_pos_map[input_rows[i]];
|
|
|
|
|
double* y = out_data + out_i * input_width;
|
|
|
|
|
const double* x = input_data + i * input_width;
|
|
|
|
|
blas.AXPY(input_width, 1., x, y);
|
|
|
|
|
for (auto& row_pair : merge_row_map) {
|
|
|
|
|
auto* out_ptr = out_data + idx * input_width;
|
|
|
|
|
auto& rows = row_pair.second;
|
|
|
|
|
merge_rows[idx] = row_pair.first;
|
|
|
|
|
++idx;
|
|
|
|
|
// rows.size() is always larger than 0
|
|
|
|
|
blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 1; i < rows.size(); ++i) {
|
|
|
|
|
blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|