|
|
|
@ -152,13 +152,10 @@ static inline void CalcSCEGradWithWeight(const T* loss_grad, Tensor* grad,
|
|
|
|
|
const T* label_data = label.data<T>();
|
|
|
|
|
const T* weight_data = weight.data<T>();
|
|
|
|
|
|
|
|
|
|
// LOG(ERROR) << "SCE grad start";
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
|
for (int j = 0; j < stride; j++) {
|
|
|
|
|
grad_data[j] = (1.0 / (1.0 + std::exp(-x_data[j])) - label_data[j]) *
|
|
|
|
|
weight_data[j] * loss_grad[i];
|
|
|
|
|
// if (j == 18) LOG(ERROR) << x_data[j] << " " << label_data[j] << " " <<
|
|
|
|
|
// weight_data[j] << " " << loss_grad[i];
|
|
|
|
|
}
|
|
|
|
|
grad_data += stride;
|
|
|
|
|
x_data += stride;
|
|
|
|
@ -258,8 +255,7 @@ static void PreProcessGTBox(const Tensor& gt_box, const Tensor& gt_label,
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < n; i++) {
|
|
|
|
|
for (int j = 0; j < b; j++) {
|
|
|
|
|
if (isZero<T>(gt_box_t(i, j, 0)) && isZero<T>(gt_box_t(i, j, 1)) &&
|
|
|
|
|
isZero<T>(gt_box_t(i, j, 2)) && isZero<T>(gt_box_t(i, j, 3))) {
|
|
|
|
|
if (isZero<T>(gt_box_t(i, j, 2)) && isZero<T>(gt_box_t(i, j, 3))) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -425,12 +421,6 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
|
|
|
|
|
loss_weight_conf_notarget, loss_data);
|
|
|
|
|
CalcSCEWithWeight<T>(pred_class, tclass, obj_mask_expand, loss_weight_class,
|
|
|
|
|
loss_data);
|
|
|
|
|
|
|
|
|
|
// loss_data[0] = (loss_weight_xy * (loss_x + loss_y) +
|
|
|
|
|
// loss_weight_wh * (loss_w + loss_h) +
|
|
|
|
|
// loss_weight_conf_target * loss_conf_target +
|
|
|
|
|
// loss_weight_conf_notarget * loss_conf_notarget +
|
|
|
|
|
// loss_weight_class * loss_class) / n;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -494,8 +484,6 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto tweight_t = EigenTensor<T, 4>::From(tweight);
|
|
|
|
|
obj_weight_t = obj_mask_t * tweight_t;
|
|
|
|
|
|
|
|
|
|
// LOG(ERROR) << obj_mask_t;
|
|
|
|
|
|
|
|
|
|
Tensor obj_mask_expand;
|
|
|
|
|
obj_mask_expand.mutable_data<T>({n, an_num, h, w, class_num},
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|