|
|
|
@ -45,7 +45,7 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)
|
|
|
|
|
REGISTER_OP(test_operator, OperatorTest, OperatorTestProtoAndCheckerMaker);
|
|
|
|
|
|
|
|
|
|
TEST(OperatorBase, all) {
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
@ -69,5 +69,55 @@ TEST(OperatorBase, all) {
|
|
|
|
|
delete op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("input", "input of test op");
|
|
|
|
|
AddOutput("output", "output of test op");
|
|
|
|
|
AddAttr<float>("scale", "scale of cosine op")
|
|
|
|
|
.SetDefault(1.0)
|
|
|
|
|
.LargerThan(0.0);
|
|
|
|
|
AddType("test_operator");
|
|
|
|
|
AddComment("This is test op");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpWithKernelTest : public OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CPUKernelTest : public OpKernel {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const KernelContext& context) const {
|
|
|
|
|
float scale = context.op_.GetAttr<float>("scale");
|
|
|
|
|
ASSERT_NEAR(scale, 3.14, 1e-5);
|
|
|
|
|
std::cout << "this is cpu kernel" << std::endl;
|
|
|
|
|
std::cout << context.op_.DebugString() << std::endl;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
REGISTER_OP(op_with_kernel, OpWithKernelTest, OpKernelTestProtoAndCheckerMaker);
|
|
|
|
|
REGISTER_OP_KERNEL(op_with_kernel, platform::CPUPlace, CPUKernelTest);
|
|
|
|
|
|
|
|
|
|
TEST(OpKernel, all) {
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("op_with_kernel");
|
|
|
|
|
*op_desc.mutable_inputs()->Add() = "IN1";
|
|
|
|
|
*op_desc.mutable_outputs()->Add() = "OUT1";
|
|
|
|
|
auto attr = op_desc.mutable_attrs()->Add();
|
|
|
|
|
attr->set_name("scale");
|
|
|
|
|
attr->set_type(paddle::framework::AttrType::FLOAT);
|
|
|
|
|
attr->set_f(3.14);
|
|
|
|
|
|
|
|
|
|
platform::CPUDeviceContext cpu_device_context;
|
|
|
|
|
auto scope = std::make_shared<Scope>();
|
|
|
|
|
|
|
|
|
|
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
op->Run(scope, cpu_device_context);
|
|
|
|
|
|
|
|
|
|
delete op;
|
|
|
|
|
}
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|