|
|
|
@ -11,11 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/init.h"
|
|
|
|
|
#include "paddle/framework/op_info.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -27,8 +28,7 @@ class OpWithoutKernelTest : public OperatorBase {
|
|
|
|
|
OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
|
|
|
|
|
const VariableNameMap& outputs, const AttributeMap& attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
|
|
|
|
|
void Run(const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {
|
|
|
|
|
void Run(const Scope& scope, const platform::Place& place) const override {
|
|
|
|
|
++op_run_num;
|
|
|
|
|
ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
|
|
|
|
|
ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
|
|
|
|
@ -41,10 +41,9 @@ class OpWithoutKernelTest : public OperatorBase {
|
|
|
|
|
int x{0};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
OpeWithoutKernelTestProtoAndCheckerMaker(OpProto* proto,
|
|
|
|
|
OpAttrChecker* op_checker)
|
|
|
|
|
OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("input", "input of test op");
|
|
|
|
|
AddOutput("output", "output of test op");
|
|
|
|
@ -65,11 +64,12 @@ static void BuildVar(const std::string& param_name,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(
|
|
|
|
|
test_operator, paddle::framework::OpWithoutKernelTest,
|
|
|
|
|
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(test_operator,
|
|
|
|
|
paddle::framework::OpWithoutKernelTest,
|
|
|
|
|
paddle::framework::OpWithoutKernelCheckerMaker);
|
|
|
|
|
|
|
|
|
|
TEST(OperatorBase, all) {
|
|
|
|
|
paddle::framework::InitDevices({"CPU"});
|
|
|
|
|
paddle::framework::proto::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("test_operator");
|
|
|
|
|
BuildVar("input", {"IN1"}, op_desc.add_inputs());
|
|
|
|
@ -80,13 +80,13 @@ TEST(OperatorBase, all) {
|
|
|
|
|
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
|
|
|
|
|
attr->set_f(3.14);
|
|
|
|
|
|
|
|
|
|
paddle::platform::CPUDeviceContext device_context;
|
|
|
|
|
paddle::platform::CPUPlace cpu_place;
|
|
|
|
|
paddle::framework::Scope scope;
|
|
|
|
|
|
|
|
|
|
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
scope.Var("OUT1");
|
|
|
|
|
ASSERT_EQ(paddle::framework::op_run_num, 0);
|
|
|
|
|
op->Run(scope, device_context);
|
|
|
|
|
op->Run(scope, cpu_place);
|
|
|
|
|
ASSERT_EQ(paddle::framework::op_run_num, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -123,7 +123,6 @@ template <typename T1, typename T2>
|
|
|
|
|
class CPUKernelTest : public OpKernel<float> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const ExecutionContext& ctx) const {
|
|
|
|
|
std::cout << "this is cpu kernel" << std::endl;
|
|
|
|
|
std::cout << ctx.op().DebugString() << std::endl;
|
|
|
|
|
cpu_kernel_run_num++;
|
|
|
|
|
ASSERT_EQ(ctx.op().Input("x"), "IN1");
|
|
|
|
@ -195,6 +194,7 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
|
|
|
|
|
|
|
|
|
|
// test with single input
|
|
|
|
|
TEST(OpKernel, all) {
|
|
|
|
|
paddle::framework::InitDevices({"CPU"});
|
|
|
|
|
paddle::framework::proto::OpDesc op_desc;
|
|
|
|
|
op_desc.set_type("op_with_kernel");
|
|
|
|
|
BuildVar("x", {"IN1"}, op_desc.add_inputs());
|
|
|
|
@ -205,12 +205,12 @@ TEST(OpKernel, all) {
|
|
|
|
|
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
|
|
|
|
|
attr->set_f(3.14);
|
|
|
|
|
|
|
|
|
|
paddle::platform::CPUDeviceContext cpu_device_context;
|
|
|
|
|
paddle::platform::CPUPlace cpu_place;
|
|
|
|
|
paddle::framework::Scope scope;
|
|
|
|
|
|
|
|
|
|
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
|
|
|
|
|
op->Run(scope, cpu_device_context);
|
|
|
|
|
op->Run(scope, cpu_place);
|
|
|
|
|
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -224,7 +224,9 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
|
|
|
|
|
TEST(OpKernel, multi_inputs) {
|
|
|
|
|
using namespace paddle::framework;
|
|
|
|
|
|
|
|
|
|
paddle::framework::InitDevices({"CPU"});
|
|
|
|
|
proto::OpDesc op_desc;
|
|
|
|
|
|
|
|
|
|
op_desc.set_type("op_multi_inputs_with_kernel");
|
|
|
|
|
BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
|
|
|
|
|
BuildVar("k", {"k0"}, op_desc.add_inputs());
|
|
|
|
@ -235,7 +237,7 @@ TEST(OpKernel, multi_inputs) {
|
|
|
|
|
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
|
|
|
|
|
attr->set_f(3.14);
|
|
|
|
|
|
|
|
|
|
paddle::platform::CPUDeviceContext cpu_device_context;
|
|
|
|
|
paddle::platform::CPUPlace cpu_place;
|
|
|
|
|
paddle::framework::Scope scope;
|
|
|
|
|
scope.Var("x0")->GetMutable<LoDTensor>();
|
|
|
|
|
scope.Var("x1")->GetMutable<LoDTensor>();
|
|
|
|
@ -245,7 +247,7 @@ TEST(OpKernel, multi_inputs) {
|
|
|
|
|
scope.Var("y1")->GetMutable<LoDTensor>();
|
|
|
|
|
|
|
|
|
|
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
op->Run(scope, cpu_device_context);
|
|
|
|
|
op->Run(scope, cpu_place);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class OperatorClone : public paddle::framework::OperatorBase {
|
|
|
|
@ -257,10 +259,11 @@ class OperatorClone : public paddle::framework::OperatorBase {
|
|
|
|
|
const paddle::framework::AttributeMap& attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs) {}
|
|
|
|
|
void Run(const paddle::framework::Scope& scope,
|
|
|
|
|
const paddle::platform::DeviceContext& dev_ctx) const override {}
|
|
|
|
|
const paddle::platform::Place& place) const override {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
TEST(Operator, Clone) {
|
|
|
|
|
paddle::framework::InitDevices({"CPU"});
|
|
|
|
|
OperatorClone a("ABC", paddle::framework::VariableNameMap{},
|
|
|
|
|
paddle::framework::VariableNameMap{},
|
|
|
|
|
paddle::framework::AttributeMap{});
|
|
|
|
|