|
|
|
@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
|
|
|
|
|
|
|
|
|
|
bool UseGPU() const;
|
|
|
|
|
|
|
|
|
|
bool NeedCollectiveForGrad(const std::string &grad_name,
|
|
|
|
|
std::vector<ir::Node *> ops) const;
|
|
|
|
|
virtual bool NeedCollectiveForGrad(const std::string &grad_name,
|
|
|
|
|
std::vector<ir::Node *> ops) const;
|
|
|
|
|
|
|
|
|
|
bool IsScaleLossOp(ir::Node *node) const;
|
|
|
|
|
|
|
|
|
@ -117,7 +117,10 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
|
|
|
|
|
void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
|
|
|
|
|
const std::string &g_name) const override {}
|
|
|
|
|
|
|
|
|
|
bool NeedCollectiveOps() const override { return false; }
|
|
|
|
|
bool NeedCollectiveForGrad(const std::string &grad_name,
|
|
|
|
|
std::vector<ir::Node *> ops) const {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
|
|
|
|
|
if (node->Op()->Type() == "recv") {
|
|
|
|
|