use typename DeviceContext. test=develop

revert-15296-async_double_buffered_py_reader
dengkaipeng 6 years ago
parent 0c4acc8305
commit 577a92d992

@ -204,7 +204,11 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
ops::Yolov3LossGradMaker);
REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel<float>,
ops::Yolov3LossKernel<double>);
REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel<float>,
ops::Yolov3LossGradKernel<double>);
REGISTER_OP_CPU_KERNEL(
yolov3_loss,
ops::Yolov3LossKernel<paddle::platform::CPUDeviceContext, float>,
ops::Yolov3LossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
yolov3_loss_grad,
ops::Yolov3LossGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Yolov3LossGradKernel<paddle::platform::CPUDeviceContext, double>);

File diff suppressed because it is too large Load Diff

@ -197,12 +197,12 @@ class TestYolov3LossOp(OpTest):
max_relative_error=0.31)
def initTestCase(self):
self.anchors = [12, 12, 11, 13]
self.anchors = [12, 12]
self.class_num = 5
self.ignore_thresh = 0.5
self.input_size = 416
self.x_shape = (3, len(self.anchors) // 2 * (5 + self.class_num), 5, 5)
self.gtbox_shape = (3, 5, 4)
self.x_shape = (1, len(self.anchors) // 2 * (5 + self.class_num), 3, 3)
self.gtbox_shape = (1, 5, 4)
if __name__ == "__main__":

Loading…
Cancel
Save