|
|
|
@ -64,6 +64,17 @@ class ExecutionContext;
|
|
|
|
|
*/
|
|
|
|
|
class OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
using VarNameMap = std::map<std::string, std::vector<std::string>>;
|
|
|
|
|
|
|
|
|
|
OperatorBase() = default;
|
|
|
|
|
OperatorBase(const std::string& type, const VarNameMap& inputs,
|
|
|
|
|
const VarNameMap& outputs, const AttributeMap& attrs)
|
|
|
|
|
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
|
|
|
|
|
|
|
|
|
|
OperatorBase(const OperatorBase& o) = delete;
|
|
|
|
|
OperatorBase& operator=(const OperatorBase& o) = delete;
|
|
|
|
|
OperatorBase(OperatorBase&& o) = delete;
|
|
|
|
|
|
|
|
|
|
virtual ~OperatorBase() {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -151,6 +162,15 @@ class OperatorBase {
|
|
|
|
|
AttributeMap attrs_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
|
|
|
|
|
public: \
|
|
|
|
|
Class() : ParentClass() { /* TODO(yi): This constructor is to be removed. */ \
|
|
|
|
|
} \
|
|
|
|
|
Class(const std::string& type, const VarNameMap& inputs, \
|
|
|
|
|
const VarNameMap& outputs, \
|
|
|
|
|
const paddle::framework::AttributeMap& attrs) \
|
|
|
|
|
: ParentClass(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
class InferShapeContext {
|
|
|
|
|
public:
|
|
|
|
|
InferShapeContext(const OperatorBase& op, const Scope& scope)
|
|
|
|
@ -290,6 +310,8 @@ class OpKernel {
|
|
|
|
|
|
|
|
|
|
class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
DEFINE_OPERATOR_CTOR(OperatorWithKernel, OperatorBase)
|
|
|
|
|
|
|
|
|
|
struct OpKernelKey {
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
|
|
|
|
|