Add operator double support. test=develop

revert-14398-imperative
dengkaipeng 6 years ago
parent f115eb0d1e
commit 8ef6280c03

@ -215,9 +215,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker, REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
ops::Yolov3LossGradMaker); ops::Yolov3LossGradMaker);
REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad); REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel<float>,
yolov3_loss, ops::Yolov3LossKernel<double>);
ops::Yolov3LossKernel<paddle::platform::CPUDeviceContext, float>); REGISTER_OP_CPU_KERNEL(yolov3_loss_grad, ops::Yolov3LossGradKernel<float>,
REGISTER_OP_CPU_KERNEL( ops::Yolov3LossGradKernel<double>);
yolov3_loss_grad,
ops::Yolov3LossGradKernel<paddle::platform::CPUDeviceContext, float>);

@ -323,7 +323,7 @@ static void AddAllGradToInputGrad(
} }
} }
template <typename DeviceContext, typename T> template <typename T>
class Yolov3LossKernel : public framework::OpKernel<T> { class Yolov3LossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
@ -392,7 +392,7 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T>
class Yolov3LossGradKernel : public framework::OpKernel<T> { class Yolov3LossGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {

@ -195,7 +195,7 @@ class TestYolov3LossOp(OpTest):
self.check_grad_with_place( self.check_grad_with_place(
place, ['X'], place, ['X'],
'Loss', 'Loss',
no_grad_set=set("GTBox"), no_grad_set=set(["GTBox", "GTLabel"]),
max_relative_error=0.06) max_relative_error=0.06)
def initTestCase(self): def initTestCase(self):

Loading…
Cancel
Save