|
|
|
@ -637,14 +637,19 @@ class BoxWrapper {
|
|
|
|
|
const std::string& pred_varname, int metric_phase,
|
|
|
|
|
const std::string& cmatch_rank_group,
|
|
|
|
|
const std::string& cmatch_rank_varname,
|
|
|
|
|
int bucket_size = 1000000) {
|
|
|
|
|
bool ignore_rank = false, int bucket_size = 1000000) {
|
|
|
|
|
label_varname_ = label_varname;
|
|
|
|
|
pred_varname_ = pred_varname;
|
|
|
|
|
cmatch_rank_varname_ = cmatch_rank_varname;
|
|
|
|
|
metric_phase_ = metric_phase;
|
|
|
|
|
ignore_rank_ = ignore_rank;
|
|
|
|
|
calculator = new BasicAucCalculator();
|
|
|
|
|
calculator->init(bucket_size);
|
|
|
|
|
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
|
|
|
|
|
if (ignore_rank) { // CmatchAUC
|
|
|
|
|
cmatch_rank_v.emplace_back(atoi(cmatch_rank.c_str()), 0);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
const std::vector<std::string>& cur_cmatch_rank =
|
|
|
|
|
string::split_string(cmatch_rank, "_");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
@ -678,7 +683,13 @@ class BoxWrapper {
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
const auto& cur_cmatch_rank = parse_cmatch_rank(cmatch_rank_data[i]);
|
|
|
|
|
for (size_t j = 0; j < cmatch_rank_v.size(); ++j) {
|
|
|
|
|
if (cmatch_rank_v[j] == cur_cmatch_rank) {
|
|
|
|
|
bool is_matched = false;
|
|
|
|
|
if (ignore_rank_) {
|
|
|
|
|
is_matched = cmatch_rank_v[j].first == cur_cmatch_rank.first;
|
|
|
|
|
} else {
|
|
|
|
|
is_matched = cmatch_rank_v[j] == cur_cmatch_rank;
|
|
|
|
|
}
|
|
|
|
|
if (is_matched) {
|
|
|
|
|
cal->add_data(pred_data[i], label_data[i]);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
@ -689,6 +700,7 @@ class BoxWrapper {
|
|
|
|
|
protected:
|
|
|
|
|
std::vector<std::pair<int, int>> cmatch_rank_v;
|
|
|
|
|
std::string cmatch_rank_varname_;
|
|
|
|
|
bool ignore_rank_;
|
|
|
|
|
};
|
|
|
|
|
class MaskMetricMsg : public MetricMsg {
|
|
|
|
|
public:
|
|
|
|
@ -757,7 +769,7 @@ class BoxWrapper {
|
|
|
|
|
const std::string& pred_varname,
|
|
|
|
|
const std::string& cmatch_rank_varname,
|
|
|
|
|
const std::string& mask_varname, int metric_phase,
|
|
|
|
|
const std::string& cmatch_rank_group,
|
|
|
|
|
const std::string& cmatch_rank_group, bool ignore_rank,
|
|
|
|
|
int bucket_size = 1000000) {
|
|
|
|
|
if (method == "AucCalculator") {
|
|
|
|
|
metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname,
|
|
|
|
@ -768,10 +780,10 @@ class BoxWrapper {
|
|
|
|
|
metric_phase, cmatch_rank_group,
|
|
|
|
|
cmatch_rank_varname, bucket_size));
|
|
|
|
|
} else if (method == "CmatchRankAucCalculator") {
|
|
|
|
|
metric_lists_.emplace(
|
|
|
|
|
name, new CmatchRankMetricMsg(label_varname, pred_varname,
|
|
|
|
|
metric_phase, cmatch_rank_group,
|
|
|
|
|
cmatch_rank_varname, bucket_size));
|
|
|
|
|
metric_lists_.emplace(name, new CmatchRankMetricMsg(
|
|
|
|
|
label_varname, pred_varname, metric_phase,
|
|
|
|
|
cmatch_rank_group, cmatch_rank_varname,
|
|
|
|
|
ignore_rank, bucket_size));
|
|
|
|
|
} else if (method == "MaskAucCalculator") {
|
|
|
|
|
metric_lists_.emplace(
|
|
|
|
|
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
|
|
|
|
@ -955,9 +967,6 @@ class BoxHelper {
|
|
|
|
|
new_input_channel->Close();
|
|
|
|
|
dynamic_cast<MultiSlotDataset*>(dataset_)->SetInputChannel(
|
|
|
|
|
new_input_channel);
|
|
|
|
|
if (dataset_->EnablePvMerge()) {
|
|
|
|
|
dataset_->PreprocessInstance();
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
#ifdef PADDLE_WITH_BOX_PS
|
|
|
|
|