You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/framework/op_registry_test.cc

374 lines
12 KiB

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
namespace pd = paddle::framework;
namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void Run(const Scope& scope, const platform::Place& place) const override {}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.GreaterThan(0.0);
AddComment("This is cos op");
}
};
class MyTestOp : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void Run(const Scope& scope, const platform::Place& place) const override {}
};
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddComment("This is my_test op");
}
};
} // namespace framework
} // namespace paddle
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
}
}
REGISTER_OP_WITHOUT_GRADIENT(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(scale);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;
op->Run(scope, cpu_place);
float scale_get = op->Attr<float>("scale");
ASSERT_EQ(scale_get, scale);
}
TEST(OpRegistry, IllegalAttr) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(-2.0);
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "larger_than check fail";
const char* err_msg = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(err_msg[i], msg[i]);
}
}
ASSERT_TRUE(caught);
}
TEST(OpRegistry, DefaultValue) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
ASSERT_TRUE(op_desc.IsInitialized());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUPlace cpu_place;
op->Run(scope, cpu_place);
ASSERT_EQ(op->Attr<float>("scale"), 1.0);
}
TEST(OpRegistry, CustomChecker) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("my_test_op");
BuildVar("input", {"ii"}, op_desc.add_inputs());
BuildVar("output", {"oo"}, op_desc.add_outputs());
// attr 'test_attr' is not set
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "Attribute 'test_attr' is required!";
const char* err_msg = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(err_msg[i], msg[i]);
}
}
ASSERT_TRUE(caught);
// set 'test_attr' set to an illegal value
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr");
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(3);
caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "'test_attr' must be even!";
const char* err_msg = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(err_msg[i], msg[i]);
}
}
ASSERT_TRUE(caught);
// set 'test_attr' set to a legal value
op_desc.mutable_attrs()->Clear();
attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr");
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
op->Run(scope, cpu_place);
int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
class CosineOpComplete : public paddle::framework::CosineOp {
public:
DEFINE_OP_CONSTRUCTOR(CosineOpComplete, paddle::framework::CosineOp);
DEFINE_OP_CLONE_METHOD(CosineOpComplete);
};
TEST(OperatorRegistrar, Test) {
using namespace paddle::framework;
OperatorRegistrar<CosineOpComplete, CosineOpProtoAndCheckerMaker> reg("cos");
}
namespace paddle {
namespace framework {
class OpKernelTestMaker : public OpProtoAndCheckerMaker {
public:
OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment("NoGradOp, same input output. no Grad");
}
};
class OpWithKernelTest : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
}
};
template <typename DeviceContext, typename T>
class OpKernelTest : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const {}
};
} // namespace framework
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(op_with_kernel,
paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestMaker);
REGISTER_OP_CPU_KERNEL(
op_with_kernel,
paddle::framework::OpKernelTest<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(op_with_kernel,
paddle::framework::OpKernelTest<
paddle::platform::CUDADeviceContext, float>);
TEST(OperatorRegistrar, CPU) {
paddle::framework::proto::OpDesc op_desc;
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
op_desc.set_type("op_with_kernel");
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_place);
}
TEST(OperatorRegistrar, CUDA) {
paddle::framework::proto::OpDesc op_desc;
paddle::platform::CUDAPlace cuda_place(0);
paddle::framework::Scope scope;
op_desc.set_type("op_with_kernel");
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cuda_place);
}
static int op_test_value = 0;
using paddle::platform::DeviceContext;
using paddle::platform::CPUDeviceContext;
using paddle::platform::CUDADeviceContext;
namespace paddle {
namespace framework {
class OpWithMultiKernelTest : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
proto::DataType::FP32, platform::CUDAPlace(0), DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
};
template <typename DeviceContext, typename T>
class OpMultiKernelTest : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const;
};
template <typename T>
class OpMultiKernelTest<CPUDeviceContext, T>
: public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const {
++op_test_value;
}
};
template <typename T>
class OpMultiKernelTest<CUDADeviceContext, T>
: public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const {
--op_test_value;
}
};
template <typename DeviceContext, typename T>
class OpMultiKernelTest2 : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const;
};
template <typename T>
class OpMultiKernelTest2<CPUDeviceContext, T>
: public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const {
op_test_value += 10;
}
};
template <typename T>
class OpMultiKernelTest2<CUDADeviceContext, T>
: public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const {
op_test_value -= 10;
}
};
} // namespace framework
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(op_with_multi_kernel,
paddle::framework::OpWithMultiKernelTest,
paddle::framework::OpKernelTestMaker);
REGISTER_OP_KERNEL(
op_with_multi_kernel, CPU, paddle::platform::CPUPlace,
paddle::framework::OpMultiKernelTest<CPUDeviceContext, float>);
REGISTER_OP_KERNEL(
op_with_multi_kernel, MKLDNN, paddle::platform::CPUPlace,
paddle::framework::OpMultiKernelTest2<CPUDeviceContext, float>);
REGISTER_OP_KERNEL(
op_with_multi_kernel, CUDA, paddle::platform::CUDAPlace,
paddle::framework::OpMultiKernelTest<CUDADeviceContext, float>);
REGISTER_OP_KERNEL(
op_with_multi_kernel, CUDNN, paddle::platform::CUDAPlace,
paddle::framework::OpMultiKernelTest2<CUDADeviceContext, float>);
TEST(OperatorRegistrar, OpWithMultiKernel) {
paddle::framework::proto::OpDesc op_desc;
paddle::platform::CUDAPlace cuda_place(0);
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
op_desc.set_type("op_with_multi_kernel");
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
// TODO(qiao) add priority back
// use all available kernels
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -10);
}