|
|
|
@ -58,12 +58,12 @@ TEST(OpRegistry, CreateOp) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("cos_sim");
|
|
|
|
|
auto input = op_desc.add_inputs();
|
|
|
|
|
input->set_op_proto_name("input");
|
|
|
|
|
*input->mutable_var_names()->Add() = "aa";
|
|
|
|
|
input->set_parameter("input");
|
|
|
|
|
*input->mutable_arguments()->Add() = "aa";
|
|
|
|
|
|
|
|
|
|
auto output = op_desc.add_outputs();
|
|
|
|
|
output->set_op_proto_name("output");
|
|
|
|
|
*output->mutable_var_names()->Add() = "bb";
|
|
|
|
|
output->set_parameter("output");
|
|
|
|
|
*output->mutable_arguments()->Add() = "bb";
|
|
|
|
|
|
|
|
|
|
float scale = 3.3;
|
|
|
|
|
auto attr = op_desc.mutable_attrs()->Add();
|
|
|
|
@ -84,12 +84,12 @@ TEST(OpRegistry, IllegalAttr) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("cos_sim");
|
|
|
|
|
auto input = op_desc.add_inputs();
|
|
|
|
|
input->set_op_proto_name("input");
|
|
|
|
|
*input->mutable_var_names()->Add() = "aa";
|
|
|
|
|
input->set_parameter("input");
|
|
|
|
|
*input->mutable_arguments()->Add() = "aa";
|
|
|
|
|
|
|
|
|
|
auto output = op_desc.add_outputs();
|
|
|
|
|
output->set_op_proto_name("output");
|
|
|
|
|
*output->mutable_var_names()->Add() = "bb";
|
|
|
|
|
output->set_parameter("output");
|
|
|
|
|
*output->mutable_arguments()->Add() = "bb";
|
|
|
|
|
|
|
|
|
|
auto attr = op_desc.mutable_attrs()->Add();
|
|
|
|
|
attr->set_name("scale");
|
|
|
|
@ -114,12 +114,12 @@ TEST(OpRegistry, DefaultValue) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("cos_sim");
|
|
|
|
|
auto input = op_desc.add_inputs();
|
|
|
|
|
input->set_op_proto_name("input");
|
|
|
|
|
*input->mutable_var_names()->Add() = "aa";
|
|
|
|
|
input->set_parameter("input");
|
|
|
|
|
*input->mutable_arguments()->Add() = "aa";
|
|
|
|
|
|
|
|
|
|
auto output = op_desc.add_outputs();
|
|
|
|
|
output->set_op_proto_name("output");
|
|
|
|
|
*output->mutable_var_names()->Add() = "bb";
|
|
|
|
|
output->set_parameter("output");
|
|
|
|
|
*output->mutable_arguments()->Add() = "bb";
|
|
|
|
|
|
|
|
|
|
ASSERT_TRUE(op_desc.IsInitialized());
|
|
|
|
|
|
|
|
|
@ -143,12 +143,12 @@ TEST(OpRegistry, CustomChecker) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("my_test_op");
|
|
|
|
|
auto input = op_desc.add_inputs();
|
|
|
|
|
input->set_op_proto_name("input");
|
|
|
|
|
*input->mutable_var_names()->Add() = "ii";
|
|
|
|
|
input->set_parameter("input");
|
|
|
|
|
*input->mutable_arguments()->Add() = "ii";
|
|
|
|
|
|
|
|
|
|
auto output = op_desc.add_outputs();
|
|
|
|
|
output->set_op_proto_name("output");
|
|
|
|
|
*output->mutable_var_names()->Add() = "oo";
|
|
|
|
|
output->set_parameter("output");
|
|
|
|
|
*output->mutable_arguments()->Add() = "oo";
|
|
|
|
|
SetInputFormat(&op_desc);
|
|
|
|
|
|
|
|
|
|
// attr 'test_attr' is not set
|
|
|
|
|