|
|
|
@ -260,26 +260,18 @@ static void CalcYolov3Loss(T* loss_data, const Tensor& input, const Tensor& tx,
|
|
|
|
|
const int class_num = tclass.dims()[4];
|
|
|
|
|
const int grid_num = h * w;
|
|
|
|
|
|
|
|
|
|
// T l = 0.0;
|
|
|
|
|
CalcSCE<T>(loss_data, input_data, tx_data, tweight_data, obj_mask_data, n,
|
|
|
|
|
an_num, grid_num, class_num, 1);
|
|
|
|
|
CalcSCE<T>(loss_data, input_data + grid_num, ty_data, tweight_data,
|
|
|
|
|
obj_mask_data, n, an_num, grid_num, class_num, 1);
|
|
|
|
|
// LOG(ERROR) << "C++ xy: " << loss_data[0] - l;
|
|
|
|
|
// l = loss_data[0];
|
|
|
|
|
CalcL1Loss<T>(loss_data, input_data + 2 * grid_num, tw_data, tweight_data,
|
|
|
|
|
obj_mask_data, n, an_num, grid_num, class_num);
|
|
|
|
|
CalcL1Loss<T>(loss_data, input_data + 3 * grid_num, th_data, tweight_data,
|
|
|
|
|
obj_mask_data, n, an_num, grid_num, class_num);
|
|
|
|
|
// LOG(ERROR) << "C++ wh: " << loss_data[0] - l;
|
|
|
|
|
// l = loss_data[0];
|
|
|
|
|
CalcSCE<T>(loss_data, input_data + 4 * grid_num, tconf_data, conf_mask_data,
|
|
|
|
|
conf_mask_data, n, an_num, grid_num, class_num, 1);
|
|
|
|
|
// LOG(ERROR) << "C++ conf: " << loss_data[0] - l;
|
|
|
|
|
// l = loss_data[0];
|
|
|
|
|
CalcSCE<T>(loss_data, input_data + 5 * grid_num, tclass_data, obj_mask_data,
|
|
|
|
|
obj_mask_data, n, an_num, grid_num, class_num, class_num);
|
|
|
|
|
// LOG(ERROR) << "C++ class: " << loss_data[0] - l;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -329,7 +321,7 @@ static void CalcYolov3LossGrad(T* input_grad_data, const Tensor& loss_grad,
|
|
|
|
|
obj_mask_data, n, an_num, grid_num, class_num, class_num);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
template <typename T>
|
|
|
|
|
class Yolov3LossKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
@ -359,24 +351,24 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
|
|
|
|
|
tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
|
|
|
|
|
tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<DeviceContext, T> constant;
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &conf_mask,
|
|
|
|
|
static_cast<T>(1.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &obj_mask,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tx,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &ty,
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, T> constant;
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(),
|
|
|
|
|
&conf_mask, static_cast<T>(1.0));
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(),
|
|
|
|
|
&obj_mask, static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &tx,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tw,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &ty,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &th,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &tw,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tweight,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &th,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tconf,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(),
|
|
|
|
|
&tweight, static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &tconf,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tclass,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &tclass,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
|
|
|
|
|
PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, input_size,
|
|
|
|
@ -390,7 +382,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
template <typename T>
|
|
|
|
|
class Yolov3LossGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
@ -422,24 +414,24 @@ class Yolov3LossGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
|
|
|
|
|
tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<DeviceContext, T> constant;
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &conf_mask,
|
|
|
|
|
static_cast<T>(1.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &obj_mask,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tx,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &ty,
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, T> constant;
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(),
|
|
|
|
|
&conf_mask, static_cast<T>(1.0));
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(),
|
|
|
|
|
&obj_mask, static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &tx,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tw,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &ty,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &th,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &tw,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tweight,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &th,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tconf,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(),
|
|
|
|
|
&tweight, static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &tconf,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
constant(ctx.template device_context<DeviceContext>(), &tclass,
|
|
|
|
|
constant(ctx.template device_context<platform::CPUDeviceContext>(), &tclass,
|
|
|
|
|
static_cast<T>(0.0));
|
|
|
|
|
|
|
|
|
|
PreProcessGTBox<T>(*gt_box, *gt_label, ignore_thresh, anchors, input_size,
|
|
|
|
|