|
|
@ -63,6 +63,17 @@ class ExecutionContext;
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
class OperatorBase {
|
|
|
|
class OperatorBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
|
|
|
|
OperatorBase() {} // TODO(yi): This constructor is to be removed.
|
|
|
|
|
|
|
|
OperatorBase(const std::string& type, const std::vector<std::string>& inputs,
|
|
|
|
|
|
|
|
const std::vector<std::string>& outputs,
|
|
|
|
|
|
|
|
const AttributeMap& attrs,
|
|
|
|
|
|
|
|
std::unordered_map<std::string, int>* in_out_idxs)
|
|
|
|
|
|
|
|
: type_(type),
|
|
|
|
|
|
|
|
inputs_(inputs),
|
|
|
|
|
|
|
|
outputs_(outputs),
|
|
|
|
|
|
|
|
attrs_(attrs),
|
|
|
|
|
|
|
|
in_out_idxs_(in_out_idxs) {}
|
|
|
|
|
|
|
|
|
|
|
|
virtual ~OperatorBase() {}
|
|
|
|
virtual ~OperatorBase() {}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
@ -109,6 +120,9 @@ class OperatorBase {
|
|
|
|
const std::vector<std::string> Inputs() const { return inputs_; }
|
|
|
|
const std::vector<std::string> Inputs() const { return inputs_; }
|
|
|
|
const std::vector<std::string> Outputs() const { return outputs_; }
|
|
|
|
const std::vector<std::string> Outputs() const { return outputs_; }
|
|
|
|
const AttributeMap& Attrs() const { return attrs_; }
|
|
|
|
const AttributeMap& Attrs() const { return attrs_; }
|
|
|
|
|
|
|
|
const std::unordered_map<std::string, int>* InOutIdx() const {
|
|
|
|
|
|
|
|
return in_out_idxs_.get();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
std::string type_;
|
|
|
|
std::string type_;
|
|
|
@ -286,6 +300,14 @@ class OpKernel {
|
|
|
|
|
|
|
|
|
|
|
|
class OperatorWithKernel : public OperatorBase {
|
|
|
|
class OperatorWithKernel : public OperatorBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
|
|
|
|
OperatorWithKernel() {} // TODO(yi): This constructor is to be removed.
|
|
|
|
|
|
|
|
OperatorWithKernel(const std::string& type,
|
|
|
|
|
|
|
|
const std::vector<std::string>& inputs,
|
|
|
|
|
|
|
|
const std::vector<std::string>& outputs,
|
|
|
|
|
|
|
|
const AttributeMap& attrs,
|
|
|
|
|
|
|
|
std::unordered_map<std::string, int>* in_out_idxs)
|
|
|
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
|
|
|
|
|
|
|
|
|
|
|
|
struct OpKernelKey {
|
|
|
|
struct OpKernelKey {
|
|
|
|
platform::Place place_;
|
|
|
|
platform::Place place_;
|
|
|
|
|
|
|
|
|
|
|
@ -335,5 +357,15 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
virtual void InferShape(const InferShapeContext& ctx) const = 0;
|
|
|
|
virtual void InferShape(const InferShapeContext& ctx) const = 0;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
|
|
|
|
|
|
|
|
public: \
|
|
|
|
|
|
|
|
Class() { /* TODO(yi): This constructor is to be removed. */ \
|
|
|
|
|
|
|
|
} \
|
|
|
|
|
|
|
|
Class(const std::string& type, const std::vector<std::string>& inputs, \
|
|
|
|
|
|
|
|
const std::vector<std::string>& outputs, \
|
|
|
|
|
|
|
|
const ::paddle::framework::AttributeMap& attrs, \
|
|
|
|
|
|
|
|
std::unordered_map<std::string, int>* in_out_idxs) \
|
|
|
|
|
|
|
|
: ParentClass(type, inputs, outputs, attrs, in_out_idxs) {}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|