|
|
|
@ -67,10 +67,6 @@ class OperatorBase {
|
|
|
|
|
OperatorBase(const std::string& type, const VarNameMap& inputs,
|
|
|
|
|
const VarNameMap& outputs, const AttributeMap& attrs);
|
|
|
|
|
|
|
|
|
|
OperatorBase(const OperatorBase& o) = delete;
|
|
|
|
|
OperatorBase& operator=(const OperatorBase& o) = delete;
|
|
|
|
|
OperatorBase(OperatorBase&& o) = delete;
|
|
|
|
|
|
|
|
|
|
virtual ~OperatorBase() {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -116,10 +112,14 @@ class OperatorBase {
|
|
|
|
|
void SetType(const std::string& type) { type_ = type; }
|
|
|
|
|
const AttributeMap& Attrs() const { return attrs_; }
|
|
|
|
|
|
|
|
|
|
// Return a new operator instance, which is as same as this.
|
|
|
|
|
// Use unique_ptr to prevent caller forget to delete this pointer.
|
|
|
|
|
virtual std::unique_ptr<OperatorBase> Clone() const = 0;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::string type_;
|
|
|
|
|
// NOTE: in case of OpGrad, inputs_ contains:
|
|
|
|
|
// I (Inputs)
|
|
|
|
|
// I (Inputs)opear
|
|
|
|
|
// O (Outputs)
|
|
|
|
|
// OG (Output Gradients)
|
|
|
|
|
VarNameMap inputs_;
|
|
|
|
@ -130,12 +130,32 @@ class OperatorBase {
|
|
|
|
|
AttributeMap attrs_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Macro for define a clone method.
|
|
|
|
|
// If you are writing an kernel operator, `Clone` will be defined when you
|
|
|
|
|
// register it. i.e. `Clone` method is not needed to define by yourself.
|
|
|
|
|
#define DEFINE_OP_CLONE_METHOD(CLS) \
|
|
|
|
|
std::unique_ptr<OperatorBase> Clone() const final { \
|
|
|
|
|
return std::unique_ptr<OperatorBase>(new CLS(*this)); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Macro for define a default constructor for Operator.
|
|
|
|
|
// You can also use
|
|
|
|
|
// using PARENT_CLASS::PARENT_CLASS;
|
|
|
|
|
// to use parent's constructor.
|
|
|
|
|
#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \
|
|
|
|
|
CLS(const std::string& type, const VarNameMap& inputs, \
|
|
|
|
|
const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
|
|
|
|
|
: PARENT_CLS(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
class NOP : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
using OperatorBase::OperatorBase;
|
|
|
|
|
void InferShape(const Scope& scope) const override {}
|
|
|
|
|
void Run(const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const override {}
|
|
|
|
|
std::unique_ptr<OperatorBase> Clone() const override {
|
|
|
|
|
return std::unique_ptr<OperatorBase>(new NOP(*this));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// this class not only make proto but also init attribute checkers.
|
|
|
|
|