!2410 Support inserting memcpy between two hccl ops when only part of the prior hccl op's outputs link to the next hccl op

Merge pull request !2410 from huanghui/insert-memcpy-async-pass
pull/2410/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 8f4bab4e75
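What the new pass targets, sketched in the style of the Python test IR further down (illustration only, not part of this change set; broadcast, all_reduce, memcpy_async, tuple_getitem and make_tuple are the helpers already used in the test file, and the exact rewritten graph may differ):

def cascade_example(a, b):
    y = broadcast((a, b))                        # prior hccl op with two outputs
    y0 = tuple_getitem(y, 0)                     # only part of its outputs feed the next hccl op
    res = all_reduce(y0)
    return make_tuple(res, tuple_getitem(y, 1))

# After InsertMemcpyAsyncForCascade, the second hccl op reads through a copy, roughly:
#     res = all_reduce(memcpy_async(tuple_getitem(y, 0)))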

@@ -87,6 +87,7 @@
 #include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
 #include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
 #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
+#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
 #include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
 #include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h"
 #include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h"
@@ -340,6 +341,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   other_pm->AddPass(std::make_shared<AllGatherFusion>());
   other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
   other_pm->AddPass(std::make_shared<BroadcastFusion>());
+  other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
   other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
   other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
   optimizer->AddPassManager(other_pm);

@@ -0,0 +1,114 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
#include <vector>
#include <set>
#include <string>
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/opt.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
namespace {
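// True when `node` is a TupleGetItem on another hccl op whose outputs are not all consumed by
// `cur_hccl`, i.e. the two hccl ops form a cascade that needs a memcpy_async in between.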
bool IsPartOutputsOfHcclOp(const AnfNodePtr &node, const CNodePtr &cur_hccl, const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(cur_hccl);
  MS_EXCEPTION_IF_NULL(graph);
  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto prev_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(prev_node);
  if (!AnfAlgo::IsCommunicationOp(prev_node)) {
    return false;
  }
  auto prev_hccl_op = prev_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(prev_hccl_op);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &node_users = manager->node_users();
  auto iter = node_users.find(prev_hccl_op);
  if (iter == node_users.end()) {
    MS_LOG(EXCEPTION) << "node has no output in manager";
  }
  for (const auto &node_index : iter->second) {
    AnfNodePtr output = node_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
      bool is_contain = false;
      for (size_t i = 1; i < cur_hccl->size(); ++i) {
        if (cur_hccl->input(i) == output) {
          is_contain = true;
          break;
        }
      }
      if (!is_contain) {
        return true;
      }
    }
  }
  return false;
}
} // namespace
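// Rebuild hccl_node with a memcpy_async inserted on every input that is a partially consumed
// output of another hccl op; returns nullptr when nothing had to be inserted.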
AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(hccl_node);
  std::vector<AnfNodePtr> memcpy_async_list;
  std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
  for (size_t i = 1; i < hccl_node->size(); ++i) {
    auto input = hccl_node->input(i);
    MS_EXCEPTION_IF_NULL(input);
    // when the input is also an hccl op and only part of its outputs link to cur_hccl_op
    if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) {
      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
      auto kernel_info = std::make_shared<device::KernelInfo>();
      memcpy_async->set_kernel_info(kernel_info);
      MS_EXCEPTION_IF_NULL(kernel_select_);
      kernel_select_->SelectKernel(memcpy_async->cast<CNodePtr>());
      new_inputs.push_back(memcpy_async);
      memcpy_async_list.push_back(memcpy_async);
    } else {
      new_inputs.push_back(input);
    }
  }
  if (!memcpy_async_list.empty()) {
    CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
    new_hccl_node->set_inputs(new_inputs);
    return new_hccl_node;
  }
  return nullptr;
}
const AnfNodePtr InsertMemcpyAsyncForCascade::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  if (!AnfAlgo::IsCommunicationOp(node)) {
    return nullptr;
  }
  return InsertMemcpyAsync(func_graph, cnode);
}
} // namespace opt
} // namespace mindspore

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
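// Inserts a memcpy_async between cascaded hccl ops when only part of the prior op's outputs link to
// the next one.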
class InsertMemcpyAsyncForCascade : public PatternProcessPass {
 public:
  explicit InsertMemcpyAsyncForCascade(bool multigraph = true)
      : PatternProcessPass("insert_memcpy_async_for_cascade", multigraph),
        kernel_select_(std::make_shared<KernelSelect>()) {}
  ~InsertMemcpyAsyncForCascade() override = default;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  AnfNodePtr InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
  KernelSelectPtr kernel_select_;
};
} // namespace opt
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_

@@ -32,12 +32,17 @@ const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNe
 bool IsParameterOrValueNode(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
-  return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>();
+  auto real_node = kernel_with_index.first;
+  MS_EXCEPTION_IF_NULL(real_node);
+  if (real_node->isa<Parameter>()) {
+    return true;
+  }
+  return real_node->isa<ValueNode>();
 }

-void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) {
+void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
+                     const FuncGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(hccl_node);
-  MS_EXCEPTION_IF_NULL(memcpy_async);
   MS_EXCEPTION_IF_NULL(graph);
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -48,49 +53,62 @@ void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async,
   }
   // find hccl_node's output which is a control depend
   for (const auto &node_index : iter->second) {
-    AnfNodePtr output = node_index.first;
-    int output_index = node_index.second;
-    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
-      CNodePtr control_depend = output->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(control_depend);
-      std::vector<AnfNodePtr> new_inputs;
-      for (size_t i = 0; i < control_depend->size(); ++i) {
-        if (i == IntToSize(output_index)) {
-          new_inputs.push_back(memcpy_async);
-        } else {
-          new_inputs.push_back(control_depend->input(i));
-        }
-      }
-      control_depend->set_inputs(new_inputs);
+    if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) {
+      continue;
+    }
+    CNodePtr control_depend = node_index.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(control_depend);
+    std::vector<AnfNodePtr> new_inputs;
+    for (size_t i = 0; i < control_depend->size(); ++i) {
+      if (i == IntToSize(node_index.second)) {
+        std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+        make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
+        make_tuple_inputs.emplace_back(hccl_node);
+        auto make_tuple = graph->NewCNode(make_tuple_inputs);
+        MS_EXCEPTION_IF_NULL(make_tuple);
+        new_inputs.push_back(make_tuple);
+      } else {
+        new_inputs.push_back(control_depend->input(i));
+      }
     }
+    control_depend->set_inputs(new_inputs);
   }
 }
 }  // namespace

-bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const {
+bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
+                                                  const CNodePtr &cur_node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(input);
+  MS_EXCEPTION_IF_NULL(cur_node);
   // when input is a parameter or is a value node
   if (IsParameterOrValueNode(input)) {
     return true;
   }
-  // when input is a Ref or some special cnodes
-  if (kernel_query_->IsTbeRef(input) ||
-      kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
-    return true;
-  }
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto &node_users = manager->node_users();
-  auto iter = node_users.find(input);
-  if (iter == node_users.end()) {
-    MS_LOG(EXCEPTION) << "node has no output in manager";
-  }
-  // when input is used by others
-  if (iter->second.size() > 1) {
-    return true;
+  if (input->isa<CNode>()) {
+    auto manager = graph->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    auto &node_users = manager->node_users();
+
+    // when input is a Ref cnode
+    if (kernel_query_->IsTbeRef(input)) {
+      return true;
+    }
+
+    // when input is some special cnodes
+    if (kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
+      return true;
+    }
+
+    // when input is used by others
+    auto iter = node_users.find(input);
+    if (iter == node_users.end()) {
+      MS_LOG(EXCEPTION) << "node has no output in manager";
+    }
+    if (iter->second.size() > 1) {
+      return true;
+    }
   }
   return false;
 }
@@ -98,21 +116,20 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
 void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(hccl_node);
-  bool has_insert_memcpy = false;
-  AnfNodePtr memcpy_async = nullptr;
+  std::vector<AnfNodePtr> memcpy_async_list;
   std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
   for (size_t i = 1; i < hccl_node->size(); ++i) {
     auto input = hccl_node->input(i);
-    if (NeedInsertMemcpy(graph, input)) {
-      memcpy_async = CreateMemcpyAsyncOp(graph, input);
-      has_insert_memcpy = true;
+    if (NeedInsertMemcpy(graph, input, hccl_node)) {
+      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
       new_inputs.push_back(memcpy_async);
+      memcpy_async_list.push_back(memcpy_async);
     } else {
       new_inputs.push_back(input);
     }
   }
-  if (has_insert_memcpy) {
+  if (!memcpy_async_list.empty()) {
     CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
     new_hccl_node->set_inputs(new_inputs);
     auto manager = graph->manager();
@@ -122,9 +139,7 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
     MS_LOG(DEBUG) << "end replace";

     // transfer hccl op's control to the memcpy_async
-    if (hccl_node->size() == 2) {
-      TransferControl(new_hccl_node, memcpy_async, graph);
-    }
+    TransferControl(new_hccl_node, memcpy_async_list, graph);
   }
 }

@@ -32,7 +32,7 @@ class InsertMemcpyAsyncForHcclOp : public PatternProcessPass {
  private:
   void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
-  bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const;
+  bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &cur_node) const;
   KernelQueryPtr kernel_query_;
 };
 }  // namespace opt

@@ -22,6 +22,7 @@
 #include "utils/utils.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "backend/optimizer/common/optimizer.h"
+#include "ir/param_value.h"
 #define private public
 #define protected public
 #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
@@ -44,12 +45,10 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
   ~MockInsertMemcpyForHcclKernelQuery() override = default;
   bool IsTbeRef(const AnfNodePtr &node) override {
     MS_EXCEPTION_IF_NULL(node);
-    auto cnode = node->cast<CNodePtr>();
-    if (cnode == nullptr) {
+    if (!node->isa<CNode>()) {
       return false;
     }
-    auto name = AnfAlgo::GetCNodeName(cnode);
-    return name == "ApplyMomentum";
+    return AnfAlgo::GetCNodeName(node->cast<CNodePtr>()) == "ApplyMomentum";
   }
 };
@@ -105,6 +104,11 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) {
   AbstractBasePtrList args_spec_list{x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
   EXPECT_NE(kg, nullptr);
+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
@@ -146,10 +150,16 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
   ASSERT_TRUE(g != nullptr);
   std::vector<int> shp_x{1, 64, 112, 112};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
-  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
   EXPECT_NE(kg, nullptr);
+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }
+
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
   auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
@@ -161,5 +171,34 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWInsertMemcpyForHccl, test_cond5) {
+  get_py_fun_.SetDoResolve(true);
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "before");
+  ASSERT_TRUE(g != nullptr);
+  std::vector<int> shp_x{1, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract};
+  auto kg = GetKernelGraph(g, args_spec_list);
+  EXPECT_NE(kg, nullptr);
+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
+  pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>();
+  pm->AddPass(pass);
+  optimizer->AddPassManager(pm);
+  auto new_graph = optimizer->Optimize(kg);
+  kg->SetExecOrderByDefault();
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore

@@ -17,6 +17,7 @@ from mindspore.ops import Primitive
 from mindspore.ops import operations as P

 all_reduce = P.AllReduce()
+broadcast = P.Broadcast(1)
 memcpy_async = Primitive('memcpy_async')
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
@@ -101,20 +102,40 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag):
     fns = FnDict()

     @fns
-    def before(a, b, c, d, e):
-        res1 = apply_momentun(a, b, c, d, e)
-        res2 = all_reduce(a)
-        res = control_depend(res1, res2)
-        res = make_tuple(res, res2)
+    def before(a, b):
+        x = relu(a)
+        y = all_reduce(b)
+        res = control_depend(x, y)
         return res

     @fns
-    def after(a, b, c, d, e):
-        res1 = apply_momentun(a, b, c, d, e)
-        res2 = memcpy_async(a)
-        res3 = all_reduce(res2)
-        res = control_depend(res1, res2)
-        res = make_tuple(res, res3)
+    def after(a, b):
+        x = relu(a)
+        y1 = memcpy_async(b)
+        y2 = all_reduce(y1)
+        res = control_depend(x, make_tuple(y1, y2))
+        return make_tuple(res)
+
+    return fns[tag]
+
+
+def test_insert_memcpy_async_for_hccl_op_cond5(tag):
+    fns = FnDict()
+
+    @fns
+    def before(a, b, c):
+        x = relu(a)
+        y = broadcast((b, c))
+        res = control_depend(x, y)
+        return res
+
+    @fns
+    def after(a, b, c):
+        x = relu(a)
+        m1 = memcpy_async(b)
+        m2 = memcpy_async(c)
+        y = broadcast(m1, m2)
+        res = control_depend(x, make_tuple(m1, m2, y))
         return make_tuple(res)

     return fns[tag]
