|
|
@ -117,7 +117,7 @@ class MineHardExamplesKernel : public framework::OpKernel<T> {
|
|
|
|
std::vector<int> neg_indices;
|
|
|
|
std::vector<int> neg_indices;
|
|
|
|
std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel,
|
|
|
|
std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel,
|
|
|
|
std::inserter(sel_indices, sel_indices.begin()),
|
|
|
|
std::inserter(sel_indices, sel_indices.begin()),
|
|
|
|
[](std::pair<T, size_t> l) -> int {
|
|
|
|
[](std::pair<T, size_t>& l) -> int {
|
|
|
|
return static_cast<int>(l.second);
|
|
|
|
return static_cast<int>(l.second);
|
|
|
|
});
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
@ -134,12 +134,8 @@ class MineHardExamplesKernel : public framework::OpKernel<T> {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
for (int m = 0; m < prior_num; ++m) {
|
|
|
|
neg_indices.resize(sel_indices.size());
|
|
|
|
if (match_indices(n, m) == -1 &&
|
|
|
|
std::copy(sel_indices.begin(), sel_indices.end(), neg_indices.begin());
|
|
|
|
sel_indices.find(m) != sel_indices.end()) {
|
|
|
|
|
|
|
|
neg_indices.push_back(m);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
all_neg_indices.push_back(neg_indices);
|
|
|
|
all_neg_indices.push_back(neg_indices);
|
|
|
|