|
|
|
@ -253,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input) {
|
|
|
|
|
const framework::SelectedRows& input,
|
|
|
|
|
const bool sorted_result = false) {
|
|
|
|
|
framework::SelectedRows out;
|
|
|
|
|
(*this)(context, input, &out);
|
|
|
|
|
(*this)(context, input, &out, sorted_result);
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
framework::SelectedRows* output,
|
|
|
|
|
const bool sorted_result = false) {
|
|
|
|
|
std::vector<const framework::SelectedRows*> inputs;
|
|
|
|
|
inputs.push_back(&input);
|
|
|
|
|
(*this)(context, inputs, output);
|
|
|
|
|
(*this)(context, inputs, output, sorted_result);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const std::vector<const framework::SelectedRows*>& inputs,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
framework::SelectedRows* output,
|
|
|
|
|
const bool sorted_result = false) {
|
|
|
|
|
if (inputs.size() == 0) {
|
|
|
|
|
VLOG(3) << "no input! return";
|
|
|
|
|
return;
|
|
|
|
@ -302,8 +305,8 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
}
|
|
|
|
|
std::vector<int64_t> merge_rows(merged_row_set.begin(),
|
|
|
|
|
merged_row_set.end());
|
|
|
|
|
if (sorted_result_) {
|
|
|
|
|
std::sort(merge_rows);
|
|
|
|
|
if (sorted_result) {
|
|
|
|
|
std::sort(merge_rows.begin(), merge_rows.end());
|
|
|
|
|
}
|
|
|
|
|
std::unordered_map<int64_t, size_t> rows_to_id;
|
|
|
|
|
for (size_t i = 0; i < merge_rows.size(); ++i) {
|
|
|
|
|