|
|
|
@ -27,28 +27,15 @@ namespace details {
|
|
|
|
|
// String constant identifying the local execution scope (going by the
// constant's name; its users are not visible in this part of the file).
// NOTE(review): the literal spells "LCOAL", presumably a transposition of
// "LOCAL". It is left unchanged because it is a runtime value; confirm that
// nothing hard-codes the string before correcting it.
constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
|
|
|
|
|
|
|
|
|
|
class OpHandleBase {
|
|
|
|
|
private:
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(OpHandleBase);
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
std::vector<VarHandleBase *> inputs_;
|
|
|
|
|
std::vector<VarHandleBase *> outputs_;
|
|
|
|
|
std::unordered_map<platform::Place, platform::DeviceContext *,
|
|
|
|
|
platform::PlaceHash>
|
|
|
|
|
dev_ctxes_;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
std::unordered_map<int, cudaEvent_t> events_;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
OpHandleBase() {}
|
|
|
|
|
|
|
|
|
|
virtual ~OpHandleBase();
|
|
|
|
|
|
|
|
|
|
std::string DebugString() const;
|
|
|
|
|
|
|
|
|
|
// Pure virtual: every concrete subclass must override this and return its
// name as a std::string. Declared const, so it may be called on const handles.
virtual std::string Name() const = 0;
|
|
|
|
|
|
|
|
|
|
virtual ~OpHandleBase();
|
|
|
|
|
|
|
|
|
|
void Run(bool use_event);
|
|
|
|
|
|
|
|
|
|
virtual void Wait(platform::DeviceContext *waited_dev);
|
|
|
|
@ -61,6 +48,18 @@ class OpHandleBase {
|
|
|
|
|
// will likely block other computations.
|
|
|
|
|
virtual bool IsMultiDeviceTransfer() { return false; }
|
|
|
|
|
|
|
|
|
|
// Returns the device context stored in dev_ctxes_ for `place`, or nullptr
// when no context is stored for that place.
//
// The lookup uses find() instead of operator[]: operator[] would insert a
// value-initialized (null) entry for every unknown place that is queried,
// silently growing the map and making a pure query mutate the handle.
const platform::DeviceContext *DeviceContext(platform::Place place) {
  auto it = dev_ctxes_.find(place);
  return it != dev_ctxes_.end() ? it->second : nullptr;
}
|
|
|
|
|
|
|
|
|
|
// Associates `ctx_` with `place` in dev_ctxes_. An entry for `place` is
// created if none exists; an existing entry is overwritten.
void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
  auto &slot = dev_ctxes_[place];
  slot = ctx_;
}
|
|
|
|
|
|
|
|
|
|
// Read-only view of the input variable handles held in inputs_.
// Returns a reference to the member itself; no copy is made.
const std::vector<VarHandleBase *> &Inputs() const {
  return inputs_;
}
|
|
|
|
|
|
|
|
|
|
// Read-only view of the output variable handles held in outputs_.
// Returns a reference to the member itself; no copy is made.
const std::vector<VarHandleBase *> &Outputs() const {
  return outputs_;
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void RunAndRecordEvent(const std::function<void()> &callback);
|
|
|
|
|
|
|
|
|
@ -68,6 +67,18 @@ class OpHandleBase {
|
|
|
|
|
const std::function<void()> &callback);
|
|
|
|
|
|
|
|
|
|
// Pure virtual hook that every concrete subclass must implement.
// NOTE(review): presumably invoked from Run(); the definition of Run() is not
// visible here, so confirm the call path before relying on it.
virtual void RunImpl() = 0;
|
|
|
|
|
|
|
|
|
|
std::vector<VarHandleBase *> inputs_;
|
|
|
|
|
std::vector<VarHandleBase *> outputs_;
|
|
|
|
|
std::unordered_map<platform::Place, platform::DeviceContext *,
|
|
|
|
|
platform::PlaceHash>
|
|
|
|
|
dev_ctxes_;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
std::unordered_map<int, cudaEvent_t> events_;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(OpHandleBase);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace details
|
|
|
|
|