|
|
|
|
@ -126,7 +126,8 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
|
|
|
|
|
dim3 grid(1, in1_rows.size());
|
|
|
|
|
SelectedRowsAddTensorKernel<
|
|
|
|
|
T, block_size><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
|
in1_data, in1_rows.cuda_data(), out_data, in1_row_numel);
|
|
|
|
|
in1_data, in1_rows.CUDAData(context.GetPlace()), out_data,
|
|
|
|
|
in1_row_numel);
|
|
|
|
|
|
|
|
|
|
auto out_eigen = framework::EigenVector<T>::Flatten(*output);
|
|
|
|
|
auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
|
|
|
|
|
@ -153,7 +154,7 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
|
|
|
|
|
auto* in2_value = input2->mutable_value();
|
|
|
|
|
|
|
|
|
|
// concat rows
|
|
|
|
|
in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end());
|
|
|
|
|
in2_rows.Extend(in1_rows.begin(), in1_rows.end());
|
|
|
|
|
|
|
|
|
|
auto in1_place = input1.place();
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
|
|
|
|
|
@ -216,7 +217,8 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
|
|
|
|
|
dim3 grid(1, in1_rows.size());
|
|
|
|
|
SelectedRowsAddToTensorKernel<
|
|
|
|
|
T, block_size><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
|
in1_data, in1_rows.cuda_data(), in2_data, in1_row_numel);
|
|
|
|
|
in1_data, in1_rows.CUDAData(context.GetPlace()), in2_data,
|
|
|
|
|
in1_row_numel);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
@ -283,9 +285,10 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
|
|
|
|
|
MergeAddKernel<
|
|
|
|
|
T, 256><<<grid1, threads, 0,
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(context)
|
|
|
|
|
.stream()>>>(input_data, input_rows.cuda_data(), out_data,
|
|
|
|
|
out.mutable_rows()->cuda_data(),
|
|
|
|
|
out.rows().size(), input_width);
|
|
|
|
|
.stream()>>>(
|
|
|
|
|
input_data, input_rows.CUDAData(context.GetPlace()), out_data,
|
|
|
|
|
out.mutable_rows()->CUDAMutableData(context.GetPlace()),
|
|
|
|
|
out.rows().size(), input_width);
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|