Add operator double support. test=develop

revert-14398-imperative
dengkaipeng 6 years ago
parent f115eb0d1e
commit 8ef6280c03

@ -215,9 +215,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
ops::Yolov3LossGradMaker);
REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
REGISTER_OP_CPU_KERNEL(
yolov3_loss,
ops::Yolov3LossKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
yolov3_loss_grad,
ops::Yolov3LossGradKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel<float>,
ops::Yolov3LossKernel<double>);
REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel<float>,
ops::Yolov3LossGradKernel<double>);

@ -323,7 +323,7 @@ static void AddAllGradToInputGrad(
}
}
template <typename DeviceContext, typename T>
template <typename T>
class Yolov3LossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@ -392,7 +392,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T>
class Yolov3LossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {

@ -195,7 +195,7 @@ class TestYolov3LossOp(OpTest):
self.check_grad_with_place(
place, ['X'],
'Loss',
no_grad_set=set("GTBox"),
no_grad_set=set(["GTBox", "GTLabel"]),
max_relative_error=0.06)
def initTestCase(self):

Loading…
Cancel
Save