fix all reduce and loss overlap

pull/329/head
baker 4 years ago
parent b676e97ffa
commit 4decdfc4f3

@ -363,13 +363,10 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr
} }
} }
// Update stream id for nodes belong to skipped engine subgraph
GE_CHK_STATUS_RET(UpdateForSkippedEngine(graph, subgraphs));
return SUCCESS; return SUCCESS;
} }
int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { int64_t UpdateForSkippedEnginePass::GetSingleInoutStream(const NodePtr &node) const {
set<int64_t> stream_ids; set<int64_t> stream_ids;
for (const auto &in_node : node->GetInAllNodes()) { for (const auto &in_node : node->GetInAllNodes()) {
@ -398,8 +395,7 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const {
return kInvalidStream; return kInvalidStream;
} }
Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph, Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) {
const vector<SubgraphPtr> &subgraphs) {
set<OpDescPtr> ops_without_label; set<OpDescPtr> ops_without_label;
// Check if subgraph is engine skipped and without stream label or not // Check if subgraph is engine skipped and without stream label or not
@ -441,7 +437,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph
return SUCCESS; return SUCCESS;
} }
bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const { bool UpdateForSkippedEnginePass::AreAllPredStreamsInvalid(const NodePtr &node) const {
for (const auto &pre_node : node->GetInAllNodes()) { for (const auto &pre_node : node->GetInAllNodes()) {
auto pre_node_desc = pre_node->GetOpDesc(); auto pre_node_desc = pre_node->GetOpDesc();
if (pre_node_desc != nullptr) { if (pre_node_desc != nullptr) {
@ -653,12 +649,14 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec
if (context_.enable_single_stream) { if (context_.enable_single_stream) {
passes.emplace_back(MakeShared<SingleStreamPass>()); passes.emplace_back(MakeShared<SingleStreamPass>());
passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); passes.emplace_back(MakeShared<NodeStreamUpdatePass>());
passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>());
} else { } else {
passes.emplace_back(MakeShared<AssignByLabelPass>()); passes.emplace_back(MakeShared<AssignByLabelPass>());
passes.emplace_back(MakeShared<IndependentStreamPass>()); passes.emplace_back(MakeShared<IndependentStreamPass>());
passes.emplace_back(MakeShared<AssignByDependencyPass>()); passes.emplace_back(MakeShared<AssignByDependencyPass>());
passes.emplace_back(MakeShared<NodeStreamUpdatePass>()); passes.emplace_back(MakeShared<NodeStreamUpdatePass>());
passes.emplace_back(MakeShared<AllReduceParallelPass>()); passes.emplace_back(MakeShared<AllReduceParallelPass>());
passes.emplace_back(MakeShared<UpdateForSkippedEnginePass>());
} }
for (auto &pass : passes) { for (auto &pass : passes) {

@ -147,15 +147,20 @@ class NodeStreamUpdatePass : public LogicalStreamPass {
public: public:
STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass);
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};
private: // Update the stream of subgraphs to nodes.
class UpdateForSkippedEnginePass : public LogicalStreamPass {
public:
STREAM_PASS_DEFAULT_FUNC(UpdateForSkippedEnginePass);
/// Optimize for case like: /// Optimize for case like:
/// NodeA(stream1) -> Const(stream2) -> NodeB(stream1) /// NodeA(stream1) -> Const(stream2) -> NodeB(stream1)
/// To case: /// To case:
/// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) /// NodeA(stream1) -> Const(stream1) -> NodeB(stream1)
/// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) /// Which could reduce event number (Const could be other type which belong to skipped engine subgraph)
Status UpdateForSkippedEngine(const ComputeGraphPtr &graph, const std::vector<SubgraphPtr> &subgraphs); Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
private:
int64_t GetSingleInoutStream(const NodePtr &node) const; int64_t GetSingleInoutStream(const NodePtr &node) const;
// Judge if all predecessors' streams of node are kInvalidStream // Judge if all predecessors' streams of node are kInvalidStream
bool AreAllPredStreamsInvalid(const NodePtr &node) const; bool AreAllPredStreamsInvalid(const NodePtr &node) const;

Loading…
Cancel
Save