|
|
|
@ -40,12 +40,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
|
|
|
|
|
size_t device_id) const;
|
|
|
|
|
void Init() const;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
mutable std::string loss_var_name_;
|
|
|
|
|
mutable std::vector<platform::Place> places_;
|
|
|
|
|
mutable std::vector<Scope *> local_scopes_;
|
|
|
|
|
mutable std::unordered_set<std::string> grad_names_;
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
mutable platform::NCCLContextMap *nccl_ctxs_;
|
|
|
|
|
#endif
|
|
|
|
@ -95,13 +89,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
|
|
|
|
|
size_t GetAppropriateDeviceID(
|
|
|
|
|
const std::vector<std::string> &var_names) const;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void SetCommunicationContext(OpHandleBase *op_handle,
|
|
|
|
|
const platform::Place &p) const;
|
|
|
|
|
|
|
|
|
|
mutable std::string loss_var_name_;
|
|
|
|
|
mutable std::vector<platform::Place> places_;
|
|
|
|
|
mutable std::vector<Scope *> local_scopes_;
|
|
|
|
|
mutable std::unordered_set<std::string> grad_names_;
|
|
|
|
|
|
|
|
|
|
mutable BuildStrategy strategy_;
|
|
|
|
|
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
|
|
|
|
|
mutable std::vector<int64_t> balance_vars_;
|
|
|
|
|
|
|
|
|
|
void SetCommunicationContext(OpHandleBase *op_handle,
|
|
|
|
|
const platform::Place &p) const;
|
|
|
|
|
};
|
|
|
|
|
} // namespace details
|
|
|
|
|
} // namespace framework
|
|
|
|
|