|
|
|
@ -413,6 +413,38 @@ class BoxWrapper {
|
|
|
|
|
std::vector<std::pair<int, int>> cmatch_rank_v;
|
|
|
|
|
std::string cmatch_rank_varname_;
|
|
|
|
|
};
|
|
|
|
|
class MaskMetricMsg : public MetricMsg {
|
|
|
|
|
public:
|
|
|
|
|
MaskMetricMsg(const std::string& label_varname,
|
|
|
|
|
const std::string& pred_varname, int is_join,
|
|
|
|
|
const std::string& mask_varname, int bucket_size = 1000000) {
|
|
|
|
|
label_varname_ = label_varname;
|
|
|
|
|
pred_varname_ = pred_varname;
|
|
|
|
|
mask_varname_ = mask_varname;
|
|
|
|
|
is_join_ = is_join;
|
|
|
|
|
calculator = new BasicAucCalculator();
|
|
|
|
|
calculator->init(bucket_size);
|
|
|
|
|
}
|
|
|
|
|
virtual ~MaskMetricMsg() {}
|
|
|
|
|
void add_data(const Scope* exe_scope) override {
|
|
|
|
|
std::vector<int64_t> label_data;
|
|
|
|
|
get_data<int64_t>(exe_scope, label_varname_, &label_data);
|
|
|
|
|
std::vector<float> pred_data;
|
|
|
|
|
get_data<float>(exe_scope, pred_varname_, &pred_data);
|
|
|
|
|
std::vector<int64_t> mask_data;
|
|
|
|
|
get_data<int64_t>(exe_scope, mask_varname_, &mask_data);
|
|
|
|
|
auto cal = GetCalculator();
|
|
|
|
|
auto batch_size = label_data.size();
|
|
|
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
|
|
|
if (mask_data[i] == 1) {
|
|
|
|
|
cal->add_data(pred_data[i], label_data[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::string mask_varname_;
|
|
|
|
|
};
|
|
|
|
|
const std::vector<std::string>& GetMetricNameList() const {
|
|
|
|
|
return metric_name_list_;
|
|
|
|
|
}
|
|
|
|
@ -423,7 +455,8 @@ class BoxWrapper {
|
|
|
|
|
void InitMetric(const std::string& method, const std::string& name,
|
|
|
|
|
const std::string& label_varname,
|
|
|
|
|
const std::string& pred_varname,
|
|
|
|
|
const std::string& cmatch_rank_varname, bool is_join,
|
|
|
|
|
const std::string& cmatch_rank_varname,
|
|
|
|
|
const std::string& mask_varname, bool is_join,
|
|
|
|
|
const std::string& cmatch_rank_group,
|
|
|
|
|
int bucket_size = 1000000) {
|
|
|
|
|
if (method == "AucCalculator") {
|
|
|
|
@ -439,10 +472,14 @@ class BoxWrapper {
|
|
|
|
|
name, new CmatchRankMetricMsg(label_varname, pred_varname,
|
|
|
|
|
is_join ? 1 : 0, cmatch_rank_group,
|
|
|
|
|
cmatch_rank_varname, bucket_size));
|
|
|
|
|
} else if (method == "MaskAucCalculator") {
|
|
|
|
|
metric_lists_.emplace(
|
|
|
|
|
name, new MaskMetricMsg(label_varname, pred_varname, is_join ? 1 : 0,
|
|
|
|
|
mask_varname, bucket_size));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::Unimplemented(
|
|
|
|
|
"PaddleBox only support AucCalculator, MultiTaskAucCalculator and "
|
|
|
|
|
"CmatchRankAucCalculator"));
|
|
|
|
|
"PaddleBox only support AucCalculator, MultiTaskAucCalculator "
|
|
|
|
|
"CmatchRankAucCalculator and MaskAucCalculator"));
|
|
|
|
|
}
|
|
|
|
|
metric_name_list_.emplace_back(name);
|
|
|
|
|
}
|
|
|
|
|