|
|
|
@ -269,12 +269,29 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const std::vector<const framework::SelectedRows*>& inputs,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input");
|
|
|
|
|
auto input_width = inputs[0]->value().dims()[1];
|
|
|
|
|
auto input_height = inputs[0]->height();
|
|
|
|
|
if (inputs.size() == 0) {
|
|
|
|
|
VLOG(3) << "no input! return";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
const framework::SelectedRows* has_value_input = nullptr;
|
|
|
|
|
for (auto* in : inputs) {
|
|
|
|
|
if (!in->rows().empty()) {
|
|
|
|
|
has_value_input = in;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (has_value_input == nullptr) {
|
|
|
|
|
VLOG(3) << "no input has value! just return" << std::endl;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto input_width = has_value_input->value().dims()[1];
|
|
|
|
|
auto input_height = has_value_input->height();
|
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
|
std::set<int64_t> merged_row_set;
|
|
|
|
|
for (auto* input : inputs) {
|
|
|
|
|
if (input->rows().empty()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
|
|
|
|
|
"all input should have same "
|
|
|
|
|
"dimension except for the first one");
|
|
|
|
@ -288,7 +305,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
for (size_t i = 0; i < merge_rows.size(); ++i) {
|
|
|
|
|
rows_to_id[merge_rows[i]] = i;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
out.set_height(input_height);
|
|
|
|
|
out.mutable_value()->mutable_data<T>(
|
|
|
|
@ -303,6 +319,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
|
|
|
|
|
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
|
|
|
|
|
for (auto* input : inputs) {
|
|
|
|
|
if (input->rows().empty()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto* input_data = input->value().data<T>();
|
|
|
|
|
auto& input_rows = input->rows();
|
|
|
|
|
|
|
|
|
|