add operator base (#2725)
Add OperatorBase. issue: https://github.com/PaddlePaddle/Paddle/issues/2790 Paddle design the Operator with Kernel. OperatorBase has no type and device information when create, One operator can have multiple kernels, Operator will choose a kernel to run according to context. The kernel should be bind to Operator before or during Operator running.gangliao-patch-1
parent
267f9a2cdf
commit
a2e5f652d3
@ -0,0 +1,51 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
std::string OperatorBase::DebugString() const {
|
||||
std::stringstream ss;
|
||||
ss << "=================\n";
|
||||
ss << "type = " << desc_.type() << "\n";
|
||||
ss << "inputs = [";
|
||||
for (auto& ipt : inputs_) {
|
||||
ss << ipt << ", ";
|
||||
}
|
||||
ss << "]\n";
|
||||
ss << "outputs = [";
|
||||
for (auto& opt : outputs_) {
|
||||
ss << opt << ", ";
|
||||
}
|
||||
ss << "]\n";
|
||||
ss << "attr_keys = [";
|
||||
for (auto& attr : attrs_) {
|
||||
ss << attr.first << ", ";
|
||||
}
|
||||
ss << "]\n";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
const Variable* OpRunContext::Input(int index) const {
|
||||
return scope_->GetVariable(op_->inputs_[index]);
|
||||
}
|
||||
|
||||
Variable* OpRunContext::Output(int index) const {
|
||||
return scope_->GetVariable(op_->outputs_[index]);
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,107 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <boost/variant.hpp>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/framework/attr_checker.h"
|
||||
#include "paddle/framework/op_desc.pb.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/utils/Error.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class OperatorBase;
|
||||
|
||||
class DeviceContext {};
|
||||
|
||||
/**
|
||||
* OpRunContext is the only parameter of Operator's Run function.
|
||||
* Run will get input/output variables, state such as momentum and
|
||||
* device resource such as CUDA stream, cublas handle, etc. from
|
||||
* OpRunContext. User should construct it before run the Operator.
|
||||
*/
|
||||
class OpRunContext {
|
||||
public:
|
||||
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
|
||||
const DeviceContext* device_context)
|
||||
: op_(op), scope_(scope), device_context_(device_context) {}
|
||||
|
||||
const Variable* Input(int index) const;
|
||||
Variable* Output(int index) const;
|
||||
|
||||
public:
|
||||
const OperatorBase* op_;
|
||||
const std::shared_ptr<Scope> scope_;
|
||||
const DeviceContext* device_context_;
|
||||
};
|
||||
|
||||
/**
|
||||
* OperatorBase has the basic element that Net will call to do computation.
|
||||
* Only CreateOperator from OpRegistry will new Operator directly. User
|
||||
* should always construct a proto message OpDesc and call
|
||||
* OpRegistry::CreateOp(op_desc) to get an Operator instance.
|
||||
*/
|
||||
class OperatorBase {
|
||||
public:
|
||||
virtual ~OperatorBase() {}
|
||||
|
||||
template <typename T>
|
||||
inline const T& GetAttr(const std::string& name) const {
|
||||
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
|
||||
name);
|
||||
return boost::get<T>(attrs_.at(name));
|
||||
}
|
||||
|
||||
std::string DebugString() const;
|
||||
|
||||
/// InferShape infer the size of Variables used by this Operator with
|
||||
/// information inside scope
|
||||
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
|
||||
|
||||
/// Net will call this function to Run an op.
|
||||
virtual void Run(const std::shared_ptr<Scope>& scope,
|
||||
const DeviceContext* dev_ctx) const = 0;
|
||||
|
||||
public:
|
||||
OpDesc desc_;
|
||||
std::vector<std::string> inputs_;
|
||||
std::vector<std::string> outputs_;
|
||||
AttributeMap attrs_;
|
||||
};
|
||||
|
||||
class OperatorWithKernel : public OperatorBase {
|
||||
public:
|
||||
virtual ~OperatorWithKernel() {}
|
||||
|
||||
virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
|
||||
|
||||
void Run(const std::shared_ptr<Scope>& scope,
|
||||
const DeviceContext* dev_ctx) const {
|
||||
OpRunContext op_ctx(this, scope, dev_ctx);
|
||||
Run(&op_ctx);
|
||||
}
|
||||
|
||||
/// when implement an Op, your should implement this function.
|
||||
/// this function should be moved to OpKernel later
|
||||
virtual void Run(const OpRunContext* context) const = 0;
|
||||
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,80 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/operator.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class OperatorTest : public OperatorWithKernel {
|
||||
public:
|
||||
void Run(const OpRunContext* ctx) const override {
|
||||
float scale = ctx->op_->GetAttr<float>("scale");
|
||||
PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized");
|
||||
PADDLE_ENFORCE(ctx->Output(0) == nullptr,
|
||||
"Output(1) should not initialized");
|
||||
auto output1 = ctx->scope_->CreateVariable("output1");
|
||||
PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope");
|
||||
printf("get attr %s = %f\n", "scale", scale);
|
||||
printf("%s\n", DebugString().c_str());
|
||||
}
|
||||
};
|
||||
|
||||
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
||||
public:
|
||||
OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("input", "input of test op");
|
||||
AddOutput("output", "output of test op");
|
||||
AddAttr<float>("scale", "scale of cosine op")
|
||||
.SetDefault(1.0)
|
||||
.LargerThan(0.0);
|
||||
AddType("test_operator");
|
||||
AddComment("This is test op");
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)
|
||||
|
||||
TEST(OperatorBase, DebugString) {
|
||||
OpDesc op_desc;
|
||||
op_desc.set_type("test_operator");
|
||||
std::vector<std::string> inputs = {"IN1", "IN2"};
|
||||
for (auto& input : inputs) {
|
||||
op_desc.add_inputs(input);
|
||||
}
|
||||
std::vector<std::string> outputs = {"OUT1", "OUT2"};
|
||||
for (auto& output : outputs) {
|
||||
op_desc.add_outputs(output);
|
||||
}
|
||||
auto attr = op_desc.mutable_attrs()->Add();
|
||||
attr->set_name("scale");
|
||||
attr->set_type(paddle::framework::AttrType::FLOAT);
|
||||
float scale = 3.14;
|
||||
attr->set_f(scale);
|
||||
|
||||
DeviceContext device_context;
|
||||
auto scope = std::make_shared<Scope>();
|
||||
|
||||
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||
ASSERT_EQ(op->inputs_, inputs);
|
||||
ASSERT_EQ(op->outputs_, outputs);
|
||||
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
|
||||
op->Run(scope, &device_context);
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,5 @@
|
||||
---
|
||||
Language: Cpp
|
||||
BasedOnStyle: Google
|
||||
Standard: Cpp11
|
||||
...
|
@ -0,0 +1,59 @@
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
using namespace paddle::framework;
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class CosineOp : public OperatorWithKernel {
|
||||
public:
|
||||
void Run(const OpRunContext *context) const override {
|
||||
printf("%s\n", DebugString().c_str());
|
||||
}
|
||||
};
|
||||
|
||||
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
||||
public:
|
||||
CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("input", "input of cosine op");
|
||||
AddOutput("output", "output of cosine op");
|
||||
AddAttr<float>("scale", "scale of cosine op")
|
||||
.SetDefault(1.0)
|
||||
.LargerThan(0.0);
|
||||
AddType("cos");
|
||||
AddComment("This is cos op");
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
|
||||
|
||||
class MyTestOp : public OperatorWithKernel {
|
||||
public:
|
||||
void Run(const OpRunContext *context) const override {
|
||||
printf("%s\n", DebugString().c_str());
|
||||
}
|
||||
};
|
||||
|
||||
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
||||
public:
|
||||
MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("input", "input of cosine op");
|
||||
AddOutput("output", "output of cosine op");
|
||||
auto my_checker = [](int i) {
|
||||
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
|
||||
};
|
||||
AddAttr<int>("test_attr", "a simple test attribute")
|
||||
.AddCustomChecker(my_checker);
|
||||
AddType("my_test_op");
|
||||
AddComment("This is my_test op");
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
|
||||
|
||||
} // namespace operators
|
||||
} // namespace operators
|
Loading…
Reference in new issue