|
|
|
@ -267,10 +267,15 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
|
|
|
|
|
void operator()(const platform::CUDADeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& input,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
|
framework::Vector<int64_t> input_rows(input.rows());
|
|
|
|
|
if (input_rows.size() == 0) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
|
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
|
|
|
|
|
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
|
|
|
|
|
std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
|
|
|
|
|
framework::Vector<int64_t> merge_rows(merge_rows_cpu);
|
|
|
|
|
|
|
|
|
|
auto input_width = input.value().dims()[1];
|
|
|
|
|
|
|
|
|
@ -313,8 +318,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
|
|
|
|
|
"all input should have same height");
|
|
|
|
|
merged_row_set.insert(input->rows().begin(), input->rows().end());
|
|
|
|
|
}
|
|
|
|
|
std::vector<int64_t> merge_rows(merged_row_set.begin(),
|
|
|
|
|
std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
|
|
|
|
|
merged_row_set.end());
|
|
|
|
|
framework::Vector<int64_t> merge_rows(merge_rows_cpu);
|
|
|
|
|
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
out.set_height(input_height);
|
|
|
|
@ -334,6 +340,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
|
|
|
|
|
for (auto* input : inputs) {
|
|
|
|
|
auto* input_data = input->value().data<T>();
|
|
|
|
|
auto& input_rows = input->rows();
|
|
|
|
|
if (input_rows.size() == 0) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
dim3 grid1(input_rows.size(), 1);
|
|
|
|
|
|
|
|
|
|
MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
|
|
|
|
|