|
|
|
@ -296,6 +296,52 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
|
|
|
|
|
out.mutable_rows()->CUDAMutableData(context.GetPlace()),
|
|
|
|
|
out.rows().size(), input_width);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void operator()(const platform::CUDADeviceContext& context,
|
|
|
|
|
const std::vector<const framework::SelectedRows*>& inputs,
|
|
|
|
|
framework::SelectedRows* output) {
|
|
|
|
|
PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input");
|
|
|
|
|
auto input_width = inputs[0]->value().dims()[1];
|
|
|
|
|
auto input_height = inputs[0]->height();
|
|
|
|
|
framework::SelectedRows& out = *output;
|
|
|
|
|
std::set<int64_t> merged_row_set;
|
|
|
|
|
for (auto* input : inputs) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
|
|
|
|
|
"all input should have same "
|
|
|
|
|
"dimension except for the first one");
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_height, input->height(),
|
|
|
|
|
"all input should have same height");
|
|
|
|
|
merged_row_set.insert(input->rows().begin(), input->rows().end());
|
|
|
|
|
}
|
|
|
|
|
std::vector<int64_t> merge_rows(merged_row_set.begin(),
|
|
|
|
|
merged_row_set.end());
|
|
|
|
|
|
|
|
|
|
out.set_rows(merge_rows);
|
|
|
|
|
out.set_height(input_height);
|
|
|
|
|
out.mutable_value()->mutable_data<T>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(merge_rows.size()), input_width}),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
|
|
|
|
|
constant_functor(context, out.mutable_value(), 0.0);
|
|
|
|
|
|
|
|
|
|
auto* out_data = out.mutable_value()->data<T>();
|
|
|
|
|
|
|
|
|
|
const int block_size = 256;
|
|
|
|
|
dim3 threads(block_size, 1);
|
|
|
|
|
|
|
|
|
|
for (auto* input : inputs) {
|
|
|
|
|
auto* input_data = input->value().data<T>();
|
|
|
|
|
auto& input_rows = input->rows();
|
|
|
|
|
dim3 grid1(input_rows.size(), 1);
|
|
|
|
|
|
|
|
|
|
MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
|
|
|
|
|
input_data, input_rows.CUDAData(context.GetPlace()), out_data,
|
|
|
|
|
out.mutable_rows()->CUDAMutableData(context.GetPlace()),
|
|
|
|
|
out.rows().size(), input_width);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template struct MergeAdd<platform::CUDADeviceContext, float>;
|
|
|
|
|