|
|
|
@ -7,9 +7,9 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
class CosineOp : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
void Run(const ScopePtr& scope,
|
|
|
|
|
void Run(const std::shared_ptr<Scope>& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {}
|
|
|
|
|
void InferShape(const ScopePtr& scope) const override {}
|
|
|
|
|
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
|
|
|
@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
|
|
|
|
|
class MyTestOp : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
void InferShape(const ScopePtr& scope) const override {}
|
|
|
|
|
void Run(const ScopePtr& scope,
|
|
|
|
|
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
|
|
|
|
|
void Run(const std::shared_ptr<Scope>& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) {
|
|
|
|
|
attr->set_type(paddle::framework::AttrType::FLOAT);
|
|
|
|
|
attr->set_f(scale);
|
|
|
|
|
|
|
|
|
|
paddle::framework::OperatorPtr op =
|
|
|
|
|
std::shared_ptr<paddle::framework::OperatorBase> op =
|
|
|
|
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
auto scope = std::make_shared<paddle::framework::Scope>();
|
|
|
|
|
paddle::platform::CPUDeviceContext dev_ctx;
|
|
|
|
@ -89,7 +89,6 @@ TEST(OpRegistry, IllegalAttr) {
|
|
|
|
|
|
|
|
|
|
bool caught = false;
|
|
|
|
|
try {
|
|
|
|
|
paddle::framework::OperatorPtr op __attribute__((unused)) =
|
|
|
|
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
} catch (std::runtime_error& err) {
|
|
|
|
|
caught = true;
|
|
|
|
@ -110,7 +109,7 @@ TEST(OpRegistry, DefaultValue) {
|
|
|
|
|
|
|
|
|
|
ASSERT_TRUE(op_desc.IsInitialized());
|
|
|
|
|
|
|
|
|
|
paddle::framework::OperatorPtr op =
|
|
|
|
|
std::shared_ptr<paddle::framework::OperatorBase> op =
|
|
|
|
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
auto scope = std::make_shared<paddle::framework::Scope>();
|
|
|
|
|
paddle::platform::CPUDeviceContext dev_ctx;
|
|
|
|
@ -136,7 +135,6 @@ TEST(OpRegistry, CustomChecker) {
|
|
|
|
|
// attr 'test_attr' is not set
|
|
|
|
|
bool caught = false;
|
|
|
|
|
try {
|
|
|
|
|
paddle::framework::OperatorPtr op __attribute__((unused)) =
|
|
|
|
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
} catch (std::runtime_error& err) {
|
|
|
|
|
caught = true;
|
|
|
|
@ -155,7 +153,6 @@ TEST(OpRegistry, CustomChecker) {
|
|
|
|
|
attr->set_i(3);
|
|
|
|
|
caught = false;
|
|
|
|
|
try {
|
|
|
|
|
paddle::framework::OperatorPtr op __attribute__((unused)) =
|
|
|
|
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
} catch (std::runtime_error& err) {
|
|
|
|
|
caught = true;
|
|
|
|
@ -174,8 +171,7 @@ TEST(OpRegistry, CustomChecker) {
|
|
|
|
|
attr->set_type(paddle::framework::AttrType::INT);
|
|
|
|
|
attr->set_i(4);
|
|
|
|
|
SetInputFormat(&op_desc);
|
|
|
|
|
paddle::framework::OperatorPtr op =
|
|
|
|
|
paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
|
|
|
|
|
paddle::platform::CPUDeviceContext dev_ctx;
|
|
|
|
|
auto scope = std::make_shared<paddle::framework::Scope>();
|
|
|
|
|
op->Run(scope, dev_ctx);
|
|
|
|
|