|
|
|
@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <set>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
@ -228,6 +229,11 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
}
|
|
|
|
|
std::vector<int64_t> merge_rows(merged_row_set.begin(),
|
|
|
|
|
merged_row_set.end());
|
|
|
|
|
std::map<int64_t, size_t> rows_to_id;
|
|
|
|
|
for (size_t i = 0; i < merge_rows.size(); ++i) {
|
|
|
|
|
rows_to_id[merge_rows[i]] = i;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
out.set_height(input_height);
|
|
|
|
|
out.mutable_value()->mutable_data<T>(
|
|
|
|
@ -245,7 +251,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
|
|
|
|
|
auto& input_rows = input->rows();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < input_rows.size(); i++) {
|
|
|
|
|
size_t out_i = FindPos(merge_rows, input_rows[i]);
|
|
|
|
|
size_t out_i = rows_to_id[input_rows[i]];
|
|
|
|
|
for (int64_t j = 0; j < input_width; j++) {
|
|
|
|
|
out_data[out_i * input_width + j] += input_data[i * input_width + j];
|
|
|
|
|
}
|
|
|
|
|