|
|
|
@ -89,8 +89,9 @@ class OperatorBase {
|
|
|
|
|
|
|
|
|
|
std::string DebugString() const { return DebugStringEx(nullptr); }
|
|
|
|
|
|
|
|
|
|
/// Net will call this function to Run an op.
|
|
|
|
|
virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
|
|
|
|
|
/// Net will call this interface function to Run an op.
|
|
|
|
|
// The implementation should be written at RunImpl
|
|
|
|
|
void Run(const Scope& scope, const platform::Place& place);
|
|
|
|
|
|
|
|
|
|
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
|
|
|
|
|
virtual void Stop() {}
|
|
|
|
@ -144,6 +145,8 @@ class OperatorBase {
|
|
|
|
|
private:
|
|
|
|
|
void GenerateTemporaryNames();
|
|
|
|
|
void CheckAllInputOutputSet() const;
|
|
|
|
|
virtual void RunImpl(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Macro for define a clone method.
|
|
|
|
@ -168,10 +171,13 @@ class OperatorBase {
|
|
|
|
|
class NOP : public OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
using OperatorBase::OperatorBase;
|
|
|
|
|
void Run(const Scope& scope, const platform::Place& place) const override {}
|
|
|
|
|
std::unique_ptr<OperatorBase> Clone() const override {
|
|
|
|
|
return std::unique_ptr<OperatorBase>(new NOP(*this));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void RunImpl(const Scope& scope,
|
|
|
|
|
const platform::Place& place) const override {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ExecutionContext {
|
|
|
|
@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
const VariableNameMap& outputs, const AttributeMap& attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void Run(const Scope& scope, const platform::Place& place) const final;
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
|
|
|
|
|
AllOpKernels() {
|
|
|
|
|
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
|
|
|
|
@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
// indicate kernel DataType by input data. Defaultly all input data must be
|
|
|
|
|
// same.
|
|
|
|
|
proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
|
|
|
|
|
void RunImpl(const Scope& scope, const platform::Place& place) const final;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
extern bool OpSupportGPU(const std::string& op_type);
|
|
|
|
|