add operator base (#2725)
Add OperatorBase. issue: https://github.com/PaddlePaddle/Paddle/issues/2790 Paddle design the Operator with Kernel. OperatorBase has no type and device information when create, One operator can have multiple kernels, Operator will choose a kernel to run according to context. The kernel should be bind to Operator before or during Operator running.gangliao-patch-1
parent
267f9a2cdf
commit
a2e5f652d3
@ -0,0 +1,51 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/framework/operator.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
std::string OperatorBase::DebugString() const {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "=================\n";
|
||||||
|
ss << "type = " << desc_.type() << "\n";
|
||||||
|
ss << "inputs = [";
|
||||||
|
for (auto& ipt : inputs_) {
|
||||||
|
ss << ipt << ", ";
|
||||||
|
}
|
||||||
|
ss << "]\n";
|
||||||
|
ss << "outputs = [";
|
||||||
|
for (auto& opt : outputs_) {
|
||||||
|
ss << opt << ", ";
|
||||||
|
}
|
||||||
|
ss << "]\n";
|
||||||
|
ss << "attr_keys = [";
|
||||||
|
for (auto& attr : attrs_) {
|
||||||
|
ss << attr.first << ", ";
|
||||||
|
}
|
||||||
|
ss << "]\n";
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Variable* OpRunContext::Input(int index) const {
|
||||||
|
return scope_->GetVariable(op_->inputs_[index]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Variable* OpRunContext::Output(int index) const {
|
||||||
|
return scope_->GetVariable(op_->outputs_[index]);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,107 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <boost/variant.hpp>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "paddle/framework/attr_checker.h"
|
||||||
|
#include "paddle/framework/op_desc.pb.h"
|
||||||
|
#include "paddle/framework/scope.h"
|
||||||
|
#include "paddle/utils/Error.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
class OperatorBase;
|
||||||
|
|
||||||
|
class DeviceContext {};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OpRunContext is the only parameter of Operator's Run function.
|
||||||
|
* Run will get input/output variables, state such as momentum and
|
||||||
|
* device resource such as CUDA stream, cublas handle, etc. from
|
||||||
|
* OpRunContext. User should construct it before run the Operator.
|
||||||
|
*/
|
||||||
|
class OpRunContext {
|
||||||
|
public:
|
||||||
|
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
|
||||||
|
const DeviceContext* device_context)
|
||||||
|
: op_(op), scope_(scope), device_context_(device_context) {}
|
||||||
|
|
||||||
|
const Variable* Input(int index) const;
|
||||||
|
Variable* Output(int index) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const OperatorBase* op_;
|
||||||
|
const std::shared_ptr<Scope> scope_;
|
||||||
|
const DeviceContext* device_context_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OperatorBase has the basic element that Net will call to do computation.
|
||||||
|
* Only CreateOperator from OpRegistry will new Operator directly. User
|
||||||
|
* should always construct a proto message OpDesc and call
|
||||||
|
* OpRegistry::CreateOp(op_desc) to get an Operator instance.
|
||||||
|
*/
|
||||||
|
class OperatorBase {
|
||||||
|
public:
|
||||||
|
virtual ~OperatorBase() {}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline const T& GetAttr(const std::string& name) const {
|
||||||
|
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
|
||||||
|
name);
|
||||||
|
return boost::get<T>(attrs_.at(name));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string DebugString() const;
|
||||||
|
|
||||||
|
/// InferShape infer the size of Variables used by this Operator with
|
||||||
|
/// information inside scope
|
||||||
|
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
|
||||||
|
|
||||||
|
/// Net will call this function to Run an op.
|
||||||
|
virtual void Run(const std::shared_ptr<Scope>& scope,
|
||||||
|
const DeviceContext* dev_ctx) const = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
OpDesc desc_;
|
||||||
|
std::vector<std::string> inputs_;
|
||||||
|
std::vector<std::string> outputs_;
|
||||||
|
AttributeMap attrs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class OperatorWithKernel : public OperatorBase {
|
||||||
|
public:
|
||||||
|
virtual ~OperatorWithKernel() {}
|
||||||
|
|
||||||
|
virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
|
||||||
|
|
||||||
|
void Run(const std::shared_ptr<Scope>& scope,
|
||||||
|
const DeviceContext* dev_ctx) const {
|
||||||
|
OpRunContext op_ctx(this, scope, dev_ctx);
|
||||||
|
Run(&op_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// when implement an Op, your should implement this function.
|
||||||
|
/// this function should be moved to OpKernel later
|
||||||
|
virtual void Run(const OpRunContext* context) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,80 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/framework/operator.h"
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "paddle/framework/op_registry.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
class OperatorTest : public OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
void Run(const OpRunContext* ctx) const override {
|
||||||
|
float scale = ctx->op_->GetAttr<float>("scale");
|
||||||
|
PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized");
|
||||||
|
PADDLE_ENFORCE(ctx->Output(0) == nullptr,
|
||||||
|
"Output(1) should not initialized");
|
||||||
|
auto output1 = ctx->scope_->CreateVariable("output1");
|
||||||
|
PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope");
|
||||||
|
printf("get attr %s = %f\n", "scale", scale);
|
||||||
|
printf("%s\n", DebugString().c_str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
OperatorTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
|
||||||
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||||
|
AddInput("input", "input of test op");
|
||||||
|
AddOutput("output", "output of test op");
|
||||||
|
AddAttr<float>("scale", "scale of cosine op")
|
||||||
|
.SetDefault(1.0)
|
||||||
|
.LargerThan(0.0);
|
||||||
|
AddType("test_operator");
|
||||||
|
AddComment("This is test op");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)
|
||||||
|
|
||||||
|
TEST(OperatorBase, DebugString) {
|
||||||
|
OpDesc op_desc;
|
||||||
|
op_desc.set_type("test_operator");
|
||||||
|
std::vector<std::string> inputs = {"IN1", "IN2"};
|
||||||
|
for (auto& input : inputs) {
|
||||||
|
op_desc.add_inputs(input);
|
||||||
|
}
|
||||||
|
std::vector<std::string> outputs = {"OUT1", "OUT2"};
|
||||||
|
for (auto& output : outputs) {
|
||||||
|
op_desc.add_outputs(output);
|
||||||
|
}
|
||||||
|
auto attr = op_desc.mutable_attrs()->Add();
|
||||||
|
attr->set_name("scale");
|
||||||
|
attr->set_type(paddle::framework::AttrType::FLOAT);
|
||||||
|
float scale = 3.14;
|
||||||
|
attr->set_f(scale);
|
||||||
|
|
||||||
|
DeviceContext device_context;
|
||||||
|
auto scope = std::make_shared<Scope>();
|
||||||
|
|
||||||
|
OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
|
||||||
|
ASSERT_EQ(op->inputs_, inputs);
|
||||||
|
ASSERT_EQ(op->outputs_, outputs);
|
||||||
|
ASSERT_EQ(op->GetAttr<float>("scale"), scale);
|
||||||
|
op->Run(scope, &device_context);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
Language: Cpp
|
||||||
|
BasedOnStyle: Google
|
||||||
|
Standard: Cpp11
|
||||||
|
...
|
@ -0,0 +1,59 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/framework/op_registry.h"
|
||||||
|
|
||||||
|
using namespace paddle::framework;
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
class CosineOp : public OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
void Run(const OpRunContext *context) const override {
|
||||||
|
printf("%s\n", DebugString().c_str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
CosineOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||||
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||||
|
AddInput("input", "input of cosine op");
|
||||||
|
AddOutput("output", "output of cosine op");
|
||||||
|
AddAttr<float>("scale", "scale of cosine op")
|
||||||
|
.SetDefault(1.0)
|
||||||
|
.LargerThan(0.0);
|
||||||
|
AddType("cos");
|
||||||
|
AddComment("This is cos op");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
|
||||||
|
|
||||||
|
class MyTestOp : public OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
void Run(const OpRunContext *context) const override {
|
||||||
|
printf("%s\n", DebugString().c_str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
MyTestOpProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
|
||||||
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||||
|
AddInput("input", "input of cosine op");
|
||||||
|
AddOutput("output", "output of cosine op");
|
||||||
|
auto my_checker = [](int i) {
|
||||||
|
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
|
||||||
|
};
|
||||||
|
AddAttr<int>("test_attr", "a simple test attribute")
|
||||||
|
.AddCustomChecker(my_checker);
|
||||||
|
AddType("my_test_op");
|
||||||
|
AddComment("This is my_test op");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace operators
|
Loading…
Reference in new issue