fix unit test

fix_recordio_link
Qiao Longfei 7 years ago
parent 0225957515
commit 14f5a40898

@ -267,10 +267,15 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
framework::Vector<int64_t> input_rows(input.rows());
if (input_rows.size() == 0) {
return;
}
framework::SelectedRows& out = *output;
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
std::vector<int64_t> merge_rows_cpu(row_set.begin(), row_set.end());
framework::Vector<int64_t> merge_rows(merge_rows_cpu);
auto input_width = input.value().dims()[1];
@ -313,8 +318,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
"all input should have same height");
merged_row_set.insert(input->rows().begin(), input->rows().end());
}
std::vector<int64_t> merge_rows(merged_row_set.begin(),
std::vector<int64_t> merge_rows_cpu(merged_row_set.begin(),
merged_row_set.end());
framework::Vector<int64_t> merge_rows(merge_rows_cpu);
out.set_rows(merge_rows);
out.set_height(input_height);
@ -334,6 +340,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
for (auto* input : inputs) {
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows();
if (input_rows.size() == 0) {
continue;
}
dim3 grid1(input_rows.size(), 1);
MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(

Loading…
Cancel
Save