|
|
|
@ -87,108 +87,6 @@ struct MergeAdd {
|
|
|
|
|
framework::SelectedRows* output);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct MergeAdd<platform::CPUDeviceContext, float> {
|
|
|
|
|
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input) {
|
|
|
|
|
framework::SelectedRows out;
|
|
|
|
|
(*this)(context, input, &out);
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
|
auto input_rows = input.rows();
|
|
|
|
|
std::vector<int64_t> merge_rows;
|
|
|
|
|
merge_rows.reserve(input_rows.size());
|
|
|
|
|
std::unordered_map<int64_t, size_t> rows_pos_map;
|
|
|
|
|
rows_pos_map.reserve(input_rows.size());
|
|
|
|
|
size_t idx = 0u;
|
|
|
|
|
for (std::vector<int64_t>::iterator iter = input_rows.begin();
|
|
|
|
|
iter != input_rows.end(); ++iter) {
|
|
|
|
|
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
|
|
|
|
|
rows_pos_map[*iter] = idx++;
|
|
|
|
|
merge_rows.emplace_back(*iter);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input_width = input.value().dims()[1];
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
out.set_height(input.height());
|
|
|
|
|
out.mutable_value()->mutable_data<float>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(merge_rows.size()), input_width}),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, float> constant_functor;
|
|
|
|
|
constant_functor(context, out.mutable_value(), 0.0);
|
|
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->data<float>();
|
|
|
|
|
auto* input_data = input.value().data<float>();
|
|
|
|
|
|
|
|
|
|
auto blas = GetBlas<platform::CPUDeviceContext, float>(context);
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
|
size_t out_i = rows_pos_map[input_rows[i]];
|
|
|
|
|
float* y = out_data + out_i * input_width;
|
|
|
|
|
const float* x = input_data + i * input_width;
|
|
|
|
|
blas.AXPY(input_width, 1., x, y);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct MergeAdd<platform::CPUDeviceContext, double> {
|
|
|
|
|
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input) {
|
|
|
|
|
framework::SelectedRows out;
|
|
|
|
|
(*this)(context, input, &out);
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
|
auto input_rows = input.rows();
|
|
|
|
|
std::vector<int64_t> merge_rows;
|
|
|
|
|
merge_rows.reserve(input_rows.size());
|
|
|
|
|
std::unordered_map<int64_t, size_t> rows_pos_map;
|
|
|
|
|
rows_pos_map.reserve(input_rows.size());
|
|
|
|
|
size_t idx = 0u;
|
|
|
|
|
for (std::vector<int64_t>::iterator iter = input_rows.begin();
|
|
|
|
|
iter != input_rows.end(); ++iter) {
|
|
|
|
|
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
|
|
|
|
|
rows_pos_map[*iter] = idx++;
|
|
|
|
|
merge_rows.emplace_back(*iter);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input_width = input.value().dims()[1];
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
out.set_height(input.height());
|
|
|
|
|
out.mutable_value()->mutable_data<double>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(merge_rows.size()), input_width}),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, double> constant_functor;
|
|
|
|
|
constant_functor(context, out.mutable_value(), 0.0);
|
|
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->data<double>();
|
|
|
|
|
auto* input_data = input.value().data<double>();
|
|
|
|
|
|
|
|
|
|
auto blas = GetBlas<platform::CPUDeviceContext, double>(context);
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
|
size_t out_i = rows_pos_map[input_rows[i]];
|
|
|
|
|
double* y = out_data + out_i * input_width;
|
|
|
|
|
const double* x = input_data + i * input_width;
|
|
|
|
|
blas.AXPY(input_width, 1., x, y);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
struct Add {
|
|
|
|
|
framework::SelectedRows operator()(const DeviceContext& context,
|
|
|
|
|