|
|
|
|
@ -27,12 +27,12 @@ class OpWithoutKernelTest : public OperatorBase {
|
|
|
|
|
void InferShape(const Scope& scope) const override {}
|
|
|
|
|
void Run(const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {
|
|
|
|
|
op_run_num++;
|
|
|
|
|
ASSERT_EQ((int)inputs_.size(), 1);
|
|
|
|
|
ASSERT_EQ((int)outputs_.size(), 1);
|
|
|
|
|
ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr);
|
|
|
|
|
++op_run_num;
|
|
|
|
|
ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
|
|
|
|
|
ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
|
|
|
|
|
ASSERT_EQ(scope.FindVar(inputs_.at("input")[0]), nullptr);
|
|
|
|
|
ASSERT_EQ(x, 1);
|
|
|
|
|
ASSERT_NE(scope.FindVar(outputs_[0]), nullptr);
|
|
|
|
|
ASSERT_NE(scope.FindVar(outputs_.at("output")[0]), nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
@ -60,8 +60,13 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
|
|
|
|
|
TEST(OperatorBase, all) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("test_operator");
|
|
|
|
|
*op_desc.mutable_inputs()->Add() = "IN1";
|
|
|
|
|
*op_desc.mutable_outputs()->Add() = "OUT1";
|
|
|
|
|
auto* ipt = op_desc.mutable_inputs()->Add();
|
|
|
|
|
*ipt->mutable_var_names()->Add() = "IN1";
|
|
|
|
|
ipt->set_op_proto_name("input");
|
|
|
|
|
|
|
|
|
|
auto* output = op_desc.mutable_outputs()->Add();
|
|
|
|
|
*output->mutable_var_names()->Add() = "OUT1";
|
|
|
|
|
output->set_op_proto_name("output");
|
|
|
|
|
auto attr = op_desc.mutable_attrs()->Add();
|
|
|
|
|
attr->set_name("scale");
|
|
|
|
|
attr->set_type(paddle::framework::AttrType::FLOAT);
|
|
|
|
|
@ -113,24 +118,6 @@ class CPUKernelTest : public OpKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// multiple inputs test
|
|
|
|
|
class OperatorMultiInputsTest : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
void Init() override { x = 1; }
|
|
|
|
|
void InferShape(const Scope& scope) const override {}
|
|
|
|
|
void Run(const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {
|
|
|
|
|
ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr);
|
|
|
|
|
ASSERT_EQ(x, 1);
|
|
|
|
|
ASSERT_NE(scope.FindVar(outputs_[0]), nullptr);
|
|
|
|
|
ASSERT_EQ(Input("x"), "IN1");
|
|
|
|
|
ASSERT_EQ(Input("y"), "OUT1");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
float x = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpKernelTestMultiInputsProtoAndCheckerMaker
|
|
|
|
|
: public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
@ -196,8 +183,14 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
|
|
|
|
|
TEST(OpKernel, all) {
|
|
|
|
|
paddle::framework::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("op_with_kernel");
|
|
|
|
|
*op_desc.mutable_inputs()->Add() = "IN1";
|
|
|
|
|
*op_desc.mutable_outputs()->Add() = "OUT1";
|
|
|
|
|
auto* ipt = op_desc.mutable_inputs()->Add();
|
|
|
|
|
*ipt->mutable_var_names()->Add() = "IN1";
|
|
|
|
|
ipt->set_op_proto_name("input");
|
|
|
|
|
|
|
|
|
|
auto* output = op_desc.mutable_outputs()->Add();
|
|
|
|
|
*output->mutable_var_names()->Add() = "OUT1";
|
|
|
|
|
output->set_op_proto_name("output");
|
|
|
|
|
|
|
|
|
|
auto attr = op_desc.mutable_attrs()->Add();
|
|
|
|
|
attr->set_name("scale");
|
|
|
|
|
attr->set_type(paddle::framework::AttrType::FLOAT);
|
|
|
|
|
@ -223,12 +216,19 @@ TEST(OpKernel, multi_inputs) {
|
|
|
|
|
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("op_multi_inputs_with_kernel");
|
|
|
|
|
*op_desc.mutable_inputs()->Add() = "x0";
|
|
|
|
|
*op_desc.mutable_inputs()->Add() = "x1";
|
|
|
|
|
*op_desc.mutable_inputs()->Add() = "x2";
|
|
|
|
|
*op_desc.mutable_inputs()->Add() = "k0";
|
|
|
|
|
*op_desc.mutable_outputs()->Add() = "y0";
|
|
|
|
|
*op_desc.mutable_outputs()->Add() = "y1";
|
|
|
|
|
auto x = op_desc.mutable_inputs()->Add();
|
|
|
|
|
x->set_op_proto_name("xs");
|
|
|
|
|
*x->mutable_var_names()->Add() = "x0";
|
|
|
|
|
*x->mutable_var_names()->Add() = "x1";
|
|
|
|
|
*x->mutable_var_names()->Add() = "x2";
|
|
|
|
|
auto k = op_desc.mutable_inputs()->Add();
|
|
|
|
|
k->set_op_proto_name("k");
|
|
|
|
|
*k->mutable_var_names()->Add() = "k0";
|
|
|
|
|
auto y = op_desc.mutable_outputs()->Add();
|
|
|
|
|
y->set_op_proto_name("ys");
|
|
|
|
|
*y->mutable_var_names()->Add() = "y0";
|
|
|
|
|
*y->mutable_var_names()->Add() = "y1";
|
|
|
|
|
|
|
|
|
|
auto attr = op_desc.mutable_attrs()->Add();
|
|
|
|
|
attr->set_name("scale");
|
|
|
|
|
attr->set_type(paddle::framework::AttrType::FLOAT);
|
|
|
|
|
|