enhance api. test=develop

revert-14398-imperative
dengkaipeng 6 years ago
parent 95d5060ddd
commit f115eb0d1e

@ -288,7 +288,7 @@ paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'i
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'anchors', 'class_num', 'ignore_thresh', 'lambda_xy', 'lambda_wh', 'lambda_conf_obj', 'lambda_conf_noobj', 'lambda_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))

@ -25,11 +25,14 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
"Input(X) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("GTBox"),
"Input(GTBox) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("GTLabel"),
"Input(GTLabel) of Yolov3LossOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"),
"Output(Loss) of Yolov3LossOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
auto dim_gt = ctx->GetInputDim("GTBox");
auto dim_gtbox = ctx->GetInputDim("GTBox");
auto dim_gtlabel = ctx->GetInputDim("GTLabel");
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
auto class_num = ctx->Attrs().Get<int>("class_num");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
@ -38,8 +41,15 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dim_x[1], anchors.size() / 2 * (5 + class_num),
"Input(X) dim[1] should be equal to (anchor_number * (5 "
"+ class_num)).");
PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor");
PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5");
PADDLE_ENFORCE_EQ(dim_gtbox.size(), 3,
"Input(GTBox) should be a 3-D tensor");
PADDLE_ENFORCE_EQ(dim_gtbox[2], 4, "Input(GTBox) dim[2] should be 5");
PADDLE_ENFORCE_EQ(dim_gtlabel.size(), 2,
"Input(GTBox) should be a 2-D tensor");
PADDLE_ENFORCE_EQ(dim_gtlabel[0], dim_gtbox[0],
"Input(GTBox) and Input(GTLabel) dim[0] should be same");
PADDLE_ENFORCE_EQ(dim_gtlabel[1], dim_gtbox[1],
"Input(GTBox) and Input(GTLabel) dim[1] should be same");
PADDLE_ENFORCE_GT(anchors.size(), 0,
"Attr(anchors) length should be greater then 0.");
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
@ -73,11 +83,15 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
"The input tensor of ground truth boxes, "
"This is a 3-D tensor with shape of [N, max_box_num, 5], "
"max_box_num is the max number of boxes in each image, "
"In the third dimention, stores label, x, y, w, h, "
"label is an integer to specify box class, x, y is the "
"center cordinate of boxes and w, h is the width and height"
"and x, y, w, h should be divided by input image height to "
"scale to [0, 1].");
"In the third dimention, stores x, y, w, h coordinates, "
"x, y is the center cordinate of boxes and w, h is the "
"width and height and x, y, w, h should be divided by "
"input image height to scale to [0, 1].");
AddInput("GTLabel",
"The input tensor of ground truth label, "
"This is a 2-D tensor with shape of [N, max_box_num], "
"and each element shoudl be an integer to indicate the "
"box class id.");
AddOutput("Loss",
"The output yolov3 loss tensor, "
"This is a 1-D tensor with shape of [1]");
@ -88,19 +102,19 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
"it will be parsed pair by pair.");
AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss.");
AddAttr<float>("lambda_xy", "The weight of x, y location loss.")
AddAttr<float>("loss_weight_xy", "The weight of x, y location loss.")
.SetDefault(1.0);
AddAttr<float>("lambda_wh", "The weight of w, h location loss.")
AddAttr<float>("loss_weight_wh", "The weight of w, h location loss.")
.SetDefault(1.0);
AddAttr<float>(
"lambda_conf_obj",
"loss_weight_conf_target",
"The weight of confidence score loss in locations with target object.")
.SetDefault(1.0);
AddAttr<float>("lambda_conf_noobj",
AddAttr<float>("loss_weight_conf_notarget",
"The weight of confidence score loss in locations without "
"target object.")
.SetDefault(1.0);
AddAttr<float>("lambda_class", "The weight of classification loss.")
AddAttr<float>("loss_weight_class", "The weight of classification loss.")
.SetDefault(1.0);
AddComment(R"DOC(
This operator generate yolov3 loss by given predict result and ground
@ -141,10 +155,10 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
Final loss will be represented as follow.
$$
loss = \lambda_{xy} * loss_{xy} + \lambda_{wh} * loss_{wh}
+ \lambda_{conf_obj} * loss_{conf_obj}
+ \lambda_{conf_noobj} * loss_{conf_noobj}
+ \lambda_{class} * loss_{class}
loss = \loss_weight_{xy} * loss_{xy} + \loss_weight_{wh} * loss_{wh}
+ \loss_weight_{conf_target} * loss_{conf_target}
+ \loss_weight_{conf_notarget} * loss_{conf_notarget}
+ \loss_weight_{class} * loss_{class}
$$
)DOC");
}
@ -182,12 +196,14 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
op->SetType("yolov3_loss_grad");
op->SetInput("X", Input("X"));
op->SetInput("GTBox", Input("GTBox"));
op->SetInput("GTLabel", Input("GTLabel"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("GTBox"), {});
op->SetOutput(framework::GradVarName("GTLabel"), {});
return std::unique_ptr<framework::OpDesc>(op);
}
};

File diff suppressed because it is too large Load Diff

@ -409,32 +409,36 @@ def polygon_box_transform(input, name=None):
@templatedoc(op_type="yolov3_loss")
def yolov3_loss(x,
gtbox,
gtlabel,
anchors,
class_num,
ignore_thresh,
lambda_xy=None,
lambda_wh=None,
lambda_conf_obj=None,
lambda_conf_noobj=None,
lambda_class=None,
loss_weight_xy=None,
loss_weight_wh=None,
loss_weight_conf_target=None,
loss_weight_conf_notarget=None,
loss_weight_class=None,
name=None):
"""
${comment}
Args:
x (Variable): ${x_comment}
gtbox (Variable): groud truth boxes, shoulb be in shape of [N, B, 5],
in the third dimenstion, class_id, x, y, w, h should
be stored and x, y, w, h should be relative valud of
input image.
gtbox (Variable): groud truth boxes, should be in shape of [N, B, 4],
in the third dimenstion, x, y, w, h should be stored
and x, y, w, h should be relative value of input image.
N is the batch number and B is the max box number in
an image.
gtlabel (Variable): class id of ground truth boxes, shoud be ins shape
of [N, B].
anchors (list|tuple): ${anchors_comment}
class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment}
lambda_xy (float|None): ${lambda_xy_comment}
lambda_wh (float|None): ${lambda_wh_comment}
lambda_conf_obj (float|None): ${lambda_conf_obj_comment}
lambda_conf_noobj (float|None): ${lambda_conf_noobj_comment}
lambda_class (float|None): ${lambda_class_comment}
loss_weight_xy (float|None): ${loss_weight_xy_comment}
loss_weight_wh (float|None): ${loss_weight_wh_comment}
loss_weight_conf_target (float|None): ${loss_weight_conf_target_comment}
loss_weight_conf_notarget (float|None): ${loss_weight_conf_notarget_comment}
loss_weight_class (float|None): ${loss_weight_class_comment}
name (string): the name of yolov3 loss
Returns:
@ -443,6 +447,7 @@ def yolov3_loss(x,
Raises:
TypeError: Input x of yolov3_loss must be Variable
TypeError: Input gtbox of yolov3_loss must be Variable"
TypeError: Input gtlabel of yolov3_loss must be Variable"
TypeError: Attr anchors of yolov3_loss must be list or tuple
TypeError: Attr class_num of yolov3_loss must be an integer
TypeError: Attr ignore_thresh of yolov3_loss must be a float number
@ -450,8 +455,9 @@ def yolov3_loss(x,
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[10, 255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[10, 6, 5], dtype='float32')
x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[6, 5], dtype='float32')
gtlabel = fluid.layers.data(name='gtlabel', shape=[6, 1], dtype='int32')
anchors = [10, 13, 16, 30, 33, 23]
loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80
anchors=anchors, ignore_thresh=0.5)
@ -462,6 +468,8 @@ def yolov3_loss(x,
raise TypeError("Input x of yolov3_loss must be Variable")
if not isinstance(gtbox, Variable):
raise TypeError("Input gtbox of yolov3_loss must be Variable")
if not isinstance(gtlabel, Variable):
raise TypeError("Input gtlabel of yolov3_loss must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
if not isinstance(class_num, int):
@ -482,21 +490,24 @@ def yolov3_loss(x,
"ignore_thresh": ignore_thresh,
}
if lambda_xy is not None and isinstance(lambda_xy, float):
self.attrs['lambda_xy'] = lambda_xy
if lambda_wh is not None and isinstance(lambda_wh, float):
self.attrs['lambda_wh'] = lambda_wh
if lambda_conf_obj is not None and isinstance(lambda_conf_obj, float):
self.attrs['lambda_conf_obj'] = lambda_conf_obj
if lambda_conf_noobj is not None and isinstance(lambda_conf_noobj, float):
self.attrs['lambda_conf_noobj'] = lambda_conf_noobj
if lambda_class is not None and isinstance(lambda_class, float):
self.attrs['lambda_class'] = lambda_class
if loss_weight_xy is not None and isinstance(loss_weight_xy, float):
self.attrs['loss_weight_xy'] = loss_weight_xy
if loss_weight_wh is not None and isinstance(loss_weight_wh, float):
self.attrs['loss_weight_wh'] = loss_weight_wh
if loss_weight_conf_target is not None and isinstance(
loss_weight_conf_target, float):
self.attrs['loss_weight_conf_target'] = loss_weight_conf_target
if loss_weight_conf_notarget is not None and isinstance(
loss_weight_conf_notarget, float):
self.attrs['loss_weight_conf_notarget'] = loss_weight_conf_notarget
if loss_weight_class is not None and isinstance(loss_weight_class, float):
self.attrs['loss_weight_class'] = loss_weight_class
helper.append_op(
type='yolov3_loss',
inputs={'X': x,
"GTBox": gtbox},
inputs={"X": x,
"GTBox": gtbox,
"GTLabel": gtlabel},
outputs={'Loss': loss},
attrs=attrs)
return loss

@ -366,5 +366,18 @@ class TestGenerateProposals(unittest.TestCase):
print(rpn_rois.shape)
class TestYoloDetection(unittest.TestCase):
def test_yolov3_loss(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
gtbox = layers.data(name='gtbox', shape=[10, 4], dtype='float32')
gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32')
loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], 10,
0.5)
self.assertIsNotNone(loss)
if __name__ == '__main__':
unittest.main()

@ -911,15 +911,6 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(data_1)
print(str(program))
def test_yolov3_loss(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
gtbox = layers.data(name='gtbox', shape=[10, 5], dtype='float32')
loss = layers.yolov3_loss(x, gtbox, [10, 13, 30, 13], 10, 0.5)
self.assertIsNotNone(loss)
def test_bilinear_tensor_product_layer(self):
program = Program()
with program_guard(program):

@ -66,7 +66,7 @@ def box_iou(box1, box2):
return inter_area / (b1_area + b2_area + inter_area)
def build_target(gtboxs, attrs, grid_size):
def build_target(gtboxs, gtlabel, attrs, grid_size):
n, b, _ = gtboxs.shape
ignore_thresh = attrs["ignore_thresh"]
anchors = attrs["anchors"]
@ -87,11 +87,11 @@ def build_target(gtboxs, attrs, grid_size):
if gtboxs[i, j, :].sum() == 0:
continue
gt_label = int(gtboxs[i, j, 0])
gx = gtboxs[i, j, 1] * grid_size
gy = gtboxs[i, j, 2] * grid_size
gw = gtboxs[i, j, 3] * grid_size
gh = gtboxs[i, j, 4] * grid_size
gt_label = gtlabel[i, j]
gx = gtboxs[i, j, 0] * grid_size
gy = gtboxs[i, j, 1] * grid_size
gw = gtboxs[i, j, 2] * grid_size
gh = gtboxs[i, j, 3] * grid_size
gi = int(gx)
gj = int(gy)
@ -121,7 +121,7 @@ def build_target(gtboxs, attrs, grid_size):
return (tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask)
def YoloV3Loss(x, gtbox, attrs):
def YoloV3Loss(x, gtbox, gtlabel, attrs):
n, c, h, w = x.shape
an_num = len(attrs['anchors']) // 2
class_num = attrs["class_num"]
@ -134,7 +134,7 @@ def YoloV3Loss(x, gtbox, attrs):
pred_cls = sigmoid(x[:, :, :, :, 5:])
tx, ty, tw, th, tconf, tcls, obj_mask, noobj_mask = build_target(
gtbox, attrs, x.shape[2])
gtbox, gtlabel, attrs, x.shape[2])
obj_mask_expand = np.tile(
np.expand_dims(obj_mask, 4), (1, 1, 1, 1, int(attrs['class_num'])))
@ -142,73 +142,73 @@ def YoloV3Loss(x, gtbox, attrs):
loss_y = mse(pred_y * obj_mask, ty * obj_mask, obj_mask.sum())
loss_w = mse(pred_w * obj_mask, tw * obj_mask, obj_mask.sum())
loss_h = mse(pred_h * obj_mask, th * obj_mask, obj_mask.sum())
loss_conf_obj = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask)
loss_conf_noobj = bce(pred_conf * noobj_mask, tconf * noobj_mask,
noobj_mask)
loss_conf_target = bce(pred_conf * obj_mask, tconf * obj_mask, obj_mask)
loss_conf_notarget = bce(pred_conf * noobj_mask, tconf * noobj_mask,
noobj_mask)
loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand,
obj_mask_expand)
return attrs['lambda_xy'] * (loss_x + loss_y) \
+ attrs['lambda_wh'] * (loss_w + loss_h) \
+ attrs['lambda_conf_obj'] * loss_conf_obj \
+ attrs['lambda_conf_noobj'] * loss_conf_noobj \
+ attrs['lambda_class'] * loss_class
return attrs['loss_weight_xy'] * (loss_x + loss_y) \
+ attrs['loss_weight_wh'] * (loss_w + loss_h) \
+ attrs['loss_weight_conf_target'] * loss_conf_target \
+ attrs['loss_weight_conf_notarget'] * loss_conf_notarget \
+ attrs['loss_weight_class'] * loss_class
class TestYolov3LossOp(OpTest):
def setUp(self):
self.lambda_xy = 1.0
self.lambda_wh = 1.0
self.lambda_conf_obj = 1.0
self.lambda_conf_noobj = 1.0
self.lambda_class = 1.0
self.loss_weight_xy = 1.0
self.loss_weight_wh = 1.0
self.loss_weight_conf_target = 1.0
self.loss_weight_conf_notarget = 1.0
self.loss_weight_class = 1.0
self.initTestCase()
self.op_type = 'yolov3_loss'
x = np.random.random(size=self.x_shape).astype('float32')
gtbox = np.random.random(size=self.gtbox_shape).astype('float32')
gtbox[:, :, 0] = np.random.randint(0, self.class_num,
self.gtbox_shape[:2])
gtlabel = np.random.randint(0, self.class_num,
self.gtbox_shape[:2]).astype('int32')
self.attrs = {
"anchors": self.anchors,
"class_num": self.class_num,
"ignore_thresh": self.ignore_thresh,
"lambda_xy": self.lambda_xy,
"lambda_wh": self.lambda_wh,
"lambda_conf_obj": self.lambda_conf_obj,
"lambda_conf_noobj": self.lambda_conf_noobj,
"lambda_class": self.lambda_class,
"loss_weight_xy": self.loss_weight_xy,
"loss_weight_wh": self.loss_weight_wh,
"loss_weight_conf_target": self.loss_weight_conf_target,
"loss_weight_conf_notarget": self.loss_weight_conf_notarget,
"loss_weight_class": self.loss_weight_class,
}
self.inputs = {'X': x, 'GTBox': gtbox}
self.inputs = {'X': x, 'GTBox': gtbox, 'GTLabel': gtlabel}
self.outputs = {
'Loss':
np.array([YoloV3Loss(x, gtbox, self.attrs)]).astype('float32')
'Loss': np.array(
[YoloV3Loss(x, gtbox, gtlabel, self.attrs)]).astype('float32')
}
def test_check_output(self):
place = core.CPUPlace()
self.check_output_with_place(place, atol=1e-3)
# def test_check_grad_ignore_gtbox(self):
# place = core.CPUPlace()
# self.check_grad_with_place(
# place, ['X'],
# 'Loss',
# no_grad_set=set("GTBox"),
# max_relative_error=0.06)
def test_check_grad_ignore_gtbox(self):
place = core.CPUPlace()
self.check_grad_with_place(
place, ['X'],
'Loss',
no_grad_set=set("GTBox"),
max_relative_error=0.06)
def initTestCase(self):
self.anchors = [10, 13, 12, 12]
self.class_num = 10
self.ignore_thresh = 0.5
self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7)
self.gtbox_shape = (5, 5, 5)
self.lambda_xy = 2.5
self.lambda_wh = 0.8
self.lambda_conf_obj = 1.5
self.lambda_conf_noobj = 0.5
self.lambda_class = 1.2
self.gtbox_shape = (5, 10, 4)
self.loss_weight_xy = 2.5
self.loss_weight_wh = 0.8
self.loss_weight_conf_target = 1.5
self.loss_weight_conf_notarget = 0.5
self.loss_weight_class = 1.2
if __name__ == "__main__":

Loading…
Cancel
Save