|
|
|
@ -63,7 +63,7 @@ class DetectionOutputKernel : public framework::OpKernel<T> {
|
|
|
|
|
float nms_threshold = context.template Attr<float>("nms_threshold");
|
|
|
|
|
float confidence_threshold =
|
|
|
|
|
context.template Attr<float>("confidence_threshold");
|
|
|
|
|
int batch_size = in_conf->dims()[1];
|
|
|
|
|
size_t batch_size = in_conf->dims()[1];
|
|
|
|
|
int conf_sum_size = in_conf->numel();
|
|
|
|
|
// for softmax
|
|
|
|
|
std::vector<int64_t> conf_shape_softmax_vec(
|
|
|
|
|