|
|
|
@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("input", "input of cosine op").SetMultiple();
|
|
|
|
|
AddOutput("output", "output of cosine op").SetTemporary();
|
|
|
|
|
AddInput("input", "input of cosine op").SetDuplicable();
|
|
|
|
|
AddOutput("output", "output of cosine op").SetIntermediate();
|
|
|
|
|
auto my_checker = [](int i) {
|
|
|
|
|
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
|
|
|
|
|
};
|
|
|
|
@ -51,6 +51,15 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
static void ConstructVars(const std::string& param_name,
|
|
|
|
|
std::initializer_list<const char*> arguments,
|
|
|
|
|
paddle::framework::OpDesc::Var* var) {
|
|
|
|
|
var->set_parameter(param_name);
|
|
|
|
|
for (auto& arg_name : arguments) {
|
|
|
|
|
*var->mutable_arguments()->Add() = arg_name;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_OP(cos_sim, paddle::framework::CosineOp,
|
|
|
|
|
paddle::framework::CosineOpProtoAndCheckerMaker);
|
|
|
|
|
REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
|
|
|
|
@ -59,13 +68,11 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
|
|
|
|
|
TEST(OpRegistry, CreateOp) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("cos_sim");
|
|
|
|
|
auto input = op_desc.add_inputs();
|
|
|
|
|
input->set_parameter("input");
|
|
|
|
|
*input->mutable_arguments()->Add() = "aa";
|
|
|
|
|
auto* input = op_desc.add_inputs();
|
|
|
|
|
ConstructVars("input", {"aa"}, input);
|
|
|
|
|
|
|
|
|
|
auto output = op_desc.add_outputs();
|
|
|
|
|
output->set_parameter("output");
|
|
|
|
|
*output->mutable_arguments()->Add() = "bb";
|
|
|
|
|
auto* output = op_desc.add_outputs();
|
|
|
|
|
ConstructVars("output", {"bb"}, output);
|
|
|
|
|
|
|
|
|
|
float scale = 3.3;
|
|
|
|
|
auto attr = op_desc.mutable_attrs()->Add();
|
|
|
|
@ -85,13 +92,11 @@ TEST(OpRegistry, CreateOp) {
|
|
|
|
|
TEST(OpRegistry, IllegalAttr) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("cos_sim");
|
|
|
|
|
auto input = op_desc.add_inputs();
|
|
|
|
|
input->set_parameter("input");
|
|
|
|
|
*input->mutable_arguments()->Add() = "aa";
|
|
|
|
|
auto* input = op_desc.add_inputs();
|
|
|
|
|
ConstructVars("input", {"aa"}, input);
|
|
|
|
|
|
|
|
|
|
auto output = op_desc.add_outputs();
|
|
|
|
|
output->set_parameter("output");
|
|
|
|
|
*output->mutable_arguments()->Add() = "bb";
|
|
|
|
|
auto* output = op_desc.add_outputs();
|
|
|
|
|
ConstructVars("output", {"bb"}, output);
|
|
|
|
|
|
|
|
|
|
auto attr = op_desc.mutable_attrs()->Add();
|
|
|
|
|
attr->set_name("scale");
|
|
|
|
@ -115,13 +120,11 @@ TEST(OpRegistry, IllegalAttr) {
|
|
|
|
|
TEST(OpRegistry, DefaultValue) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("cos_sim");
|
|
|
|
|
auto input = op_desc.add_inputs();
|
|
|
|
|
input->set_parameter("input");
|
|
|
|
|
*input->mutable_arguments()->Add() = "aa";
|
|
|
|
|
auto* input = op_desc.add_inputs();
|
|
|
|
|
ConstructVars("input", {"aa"}, input);
|
|
|
|
|
|
|
|
|
|
auto output = op_desc.add_outputs();
|
|
|
|
|
output->set_parameter("output");
|
|
|
|
|
*output->mutable_arguments()->Add() = "bb";
|
|
|
|
|
auto* output = op_desc.add_outputs();
|
|
|
|
|
ConstructVars("output", {"bb"}, output);
|
|
|
|
|
|
|
|
|
|
ASSERT_TRUE(op_desc.IsInitialized());
|
|
|
|
|
|
|
|
|
@ -136,13 +139,11 @@ TEST(OpRegistry, DefaultValue) {
|
|
|
|
|
TEST(OpRegistry, CustomChecker) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("my_test_op");
|
|
|
|
|
auto input = op_desc.add_inputs();
|
|
|
|
|
input->set_parameter("input");
|
|
|
|
|
*input->mutable_arguments()->Add() = "ii";
|
|
|
|
|
auto* input = op_desc.add_inputs();
|
|
|
|
|
ConstructVars("input", {"ii"}, input);
|
|
|
|
|
|
|
|
|
|
auto output = op_desc.add_outputs();
|
|
|
|
|
output->set_parameter("output");
|
|
|
|
|
*output->mutable_arguments()->Add() = "oo";
|
|
|
|
|
auto* output = op_desc.add_outputs();
|
|
|
|
|
ConstructVars("output", {"oo"}, output);
|
|
|
|
|
|
|
|
|
|
// attr 'test_attr' is not set
|
|
|
|
|
bool caught = false;
|
|
|
|
|