diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc
index b4f4cb5b22..30173e533c 100644
--- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc
+++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc
@@ -399,7 +399,12 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
   ret_ = ret;
   root_graph_ = ret_->func_graph();
   MS_EXCEPTION_IF_NULL(root_graph_);
-  auto forward_graph = ForwardGraph(root_graph_);
+  auto graph_set = ForwardGraph(root_graph_);
+  if (graph_set.size() > 1) {
+    MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now.";
+    return SUCCESS;
+  }
+  auto forward_graph = *(graph_set.begin());
   MS_EXCEPTION_IF_NULL(forward_graph);
   forward_ret_ = forward_graph->get_return();
   MS_EXCEPTION_IF_NULL(forward_ret_);
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index d1390db899..c24c14abf6 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -1607,72 +1607,79 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
-// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-bool IsGradSensNode(const AnfNodePtr &node) {
-  if (!node->isa<CNode>()) {
-    return false;
+CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  CNodePtr return_node = func_graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  if (return_node->size() < 2) {
+    MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
   }
+  AnfNodePtr pre_node = return_node->input(1);
+  MS_EXCEPTION_IF_NULL(pre_node);
 
-  // cnode(sens)-->cnode(tuple_getitem)
-  auto cnode = node->cast<CNodePtr>();
-  AnfNodePtr expect_tuple_getitem = cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
-  if (!expect_tuple_getitem->isa<CNode>()) {
-    return false;
-  }
-  auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode);
-  if (!IsValueNode<Primitive>(expect_tuple_getitem_cnode->input(0))) {
-    return false;
+  auto pre_cnode = pre_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(pre_cnode);
+  auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
+
+  // return -> cast
+  if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
+    pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(pre_cnode);
+    current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
   }
-  ValueNodePtr expect_tuple_getitem_value_node = expect_tuple_getitem_cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_value_node);
-  PrimitivePtr expect_tuple_getitem_prim = expect_tuple_getitem_value_node->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_prim);
-  if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) {
-    return false;
+
+  // notice: the GetNext op has not input
+  if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
+    MS_LOG(INFO) << "The loss is: " << current_prim->name();
+    return pre_cnode;
   }
 
-  // cnode(sens)-->cnode(tuple_getitem)-->cnode
-  AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(expect_anonymous);
-  if (!expect_anonymous->isa<CNode>()) {
-    return false;
+  // size of common cnode is larger than 1
+  if (pre_cnode->size() < 2) {
+    MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
   }
-  // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-  auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_anonymous_cnode);
-  AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_j);
-  if (!expect_j->isa<CNode>()) {
-    return false;
+  // return -> tuple_getitem -> loss
+  if (current_prim->name() == TUPLE_GETITEM) {
+    AnfNodePtr pre_pre_node = pre_cnode->input(1);
+    MS_EXCEPTION_IF_NULL(pre_pre_node);
+
+    auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
+    auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value);
+    PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(prim);
+    MS_LOG(DEBUG) << "The loss name is " << prim->name();
+    return pre_pre_cnode;
   }
-  auto expect_j_cnode = expect_j->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_cnode);
-  if (!IsValueNode<Primitive>(expect_j_cnode->input(0))) {
-    return false;
+
+  // return -> make_tuple
+  if (current_prim->name() == MAKE_TUPLE) {
+    MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
   }
-  ValueNodePtr expect_j_value_node = expect_j_cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_value_node);
-  PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_prim);
-  return (expect_j_prim->name() == J);
+
+  // return -> loss
+  MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
+  return pre_cnode;
 }
 
-TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
+TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  TensorLayouts ret;
+  if (!IsValueNode<FuncGraph>(cnode->input(1))) {
+    MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
+  }
+  auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+  auto loss_cnode = FindLossCNode(func_graph);
   MS_EXCEPTION_IF_NULL(loss_cnode);
   AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
 
   LossNodeInfo node_info = GetLossNodeInfo(node);
-
   ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
   MS_EXCEPTION_IF_NULL(prim_anf_node);
   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(prim);
-
-  TensorLayouts ret;
   if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) {
     MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now";
     return ret;
@@ -1680,7 +1687,6 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
 
   OperatorInfoPtr operator_info = loss_cnode->operator_info();
   MS_EXCEPTION_IF_NULL(operator_info);
-
   TensorInfo loss_grad_tensor_info;
   size_t op_output_size = operator_info->outputs_tensor_info().size();
   MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is "
@@ -1805,6 +1811,100 @@ void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePt
   HandleDropoutNode(distribute_operator, cnode);
 }
 
+std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
+  // J->CNode->Graph
+  std::set<FuncGraphPtr> graph_set;
+  for (auto &node : root_all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
+      continue;
+    }
+    auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+    if (expect_j_prim->name() != J) {
+      continue;
+    }
+    if (IsValueNode<FuncGraph>(cnode->input(1))) {
+      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+      MS_LOG(DEBUG) << "Find the forward graph success";
+      graph_set.insert(graph);
+    }
+  }
+  return graph_set;
+}
+
+// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+void StepSplitSens(const AnfNodePtr &node) {
+  if (!node->isa<CNode>()) {
+    return;
+  }
+
+  // cnode(sens)-->cnode(tuple_getitem)
+  auto cnode = node->cast<CNodePtr>();
+  AnfNodePtr expect_tuple_getitem = cnode->input(0);
+  MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
+  if (!expect_tuple_getitem->isa<CNode>()) {
+    return;
+  }
+  auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode);
+  if (!IsValueNode<Primitive>(expect_tuple_getitem_cnode->input(0))) {
+    return;
+  }
+  auto expect_tuple_getitem_prim = GetValueNode<PrimitivePtr>(expect_tuple_getitem_cnode->input(0));
+  if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) {
+    return;
+  }
+
+  // cnode(sens)-->cnode(tuple_getitem)-->cnode
+  AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
+  MS_EXCEPTION_IF_NULL(expect_anonymous);
+  if (!expect_anonymous->isa<CNode>()) {
+    return;
+  }
+
+  // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+  auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(expect_anonymous_cnode);
+  AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
+  MS_EXCEPTION_IF_NULL(expect_j);
+  if (!expect_j->isa<CNode>()) {
+    return;
+  }
+  auto expect_j_cnode = expect_j->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(expect_j_cnode);
+  if (!IsValueNode<Primitive>(expect_j_cnode->input(0))) {
+    return;
+  }
+  auto expect_j_prim = GetValueNode<PrimitivePtr>(expect_j_cnode->input(0));
+  if (expect_j_prim->name() == J) {
+    auto loss_grad_layout = GetLossNodeGradOutputLayout(expect_j_cnode);
+    if (!loss_grad_layout.empty()) {
+      SplitSens(node, loss_grad_layout[0]);
+    }
+  }
+}
+
+std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  AnfNodePtr root_return_node = root->get_return();
+  MS_EXCEPTION_IF_NULL(root_return_node);
+  std::vector<CNodePtr> loss_node;
+  const auto &all_nodes = root->nodes();
+  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
+  if (graph_set.empty()) {
+    loss_node.push_back(FindLossCNode(root));
+  }
+  (void)std::transform(graph_set.begin(), graph_set.end(), std::back_inserter(loss_node),
+                       [](const FuncGraphPtr &graph) { return FindLossCNode(graph); });
+  return loss_node;
+}
+
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager) {
   MS_EXCEPTION_IF_NULL(root);
@@ -1812,18 +1912,15 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
   MS_EXCEPTION_IF_NULL(manager);
   TensorRedistribution tensor_redistribution;
-  AnfNodePtr grad_sens_node = nullptr;
 
-  CNodePtr loss_cnode = FindLossCNodeFromRoot(root);
-  MS_EXCEPTION_IF_NULL(loss_cnode);
+  std::vector<CNodePtr> loss_cnode = FindLossCNodeFromRoot(root);
 
+  // split sens must before inserting the operators.
   for (auto &node : all_nodes) {
-    // find sens node
-    if ((grad_sens_node == nullptr) && IsGradSensNode(node)) {
-      grad_sens_node = node;
-      MS_LOG(INFO) << "Find the sens node success";
-    }
+    // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
+    // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
+    StepSplitSens(node);
+  }
+
+  for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
@@ -1837,7 +1934,8 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
-CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  CNodePtr return_node = func_graph->get_return();
-  MS_EXCEPTION_IF_NULL(return_node);
-  if (return_node->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
-  }
-  AnfNodePtr pre_node = return_node->input(1);
-  MS_EXCEPTION_IF_NULL(pre_node);
-
-  auto pre_cnode = pre_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(pre_cnode);
-  auto current_value = pre_cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(current_value);
-  PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(current_prim);
-
-  // return -> cast
-  if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
-    pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(pre_cnode);
-    current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
-  }
-
-  // notice: the GetNext op has not input
-  if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
-    MS_LOG(INFO) << "The loss is: " << current_prim->name();
-    return pre_cnode;
-  }
-
-  // size of common cnode is larger than 1
-  if (pre_cnode->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
-  }
-
-  // return -> tuple_getitem -> loss
-  if (current_prim->name() == TUPLE_GETITEM) {
-    AnfNodePtr pre_pre_node = pre_cnode->input(1);
-    MS_EXCEPTION_IF_NULL(pre_pre_node);
-
-    auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
-    auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value);
-    PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(prim);
-    MS_LOG(INFO) << "The loss name is " << prim->name();
-    return pre_pre_cnode;
-  } else if (current_prim->name() == MAKE_TUPLE) {
-    MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
-  }
-
-  // return -> loss
-  MS_LOG(INFO) << "The loss name is " << current_prim->name();
-  return pre_cnode;
+std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  const auto &all_nodes = root->nodes();
+  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
+  return graph_set;
 }
 
-FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
-  for (auto &node : root_all_nodes) {
+std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto loss_cnode = FindLossCNode(graph);
+  MS_EXCEPTION_IF_NULL(loss_cnode);
+  auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy();
+  std::vector<AnfNodePtr> root_forward_nodes;
+  for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     if (!node->isa<CNode>()) {
       continue;
     }
-
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
-    if ((cnode->inputs().size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
-      continue;
-    }
-    ValueNodePtr expect_j_value_node = cnode->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(expect_j_value_node);
-    PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(expect_j_prim);
-    if (expect_j_prim->name() != J) {
-      continue;
-    }
-    MS_LOG(DEBUG) << "Find J prim: " << expect_j_value_node->DebugString() << ".";
-    if (IsValueNode<FuncGraph>(cnode->input(1))) {
-      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
-      MS_LOG(INFO) << "Find the forward graph success";
-      return graph;
+    auto root_node_id = node->UniqueIdThroughCopy();
+    if (loss_cnode_id == root_node_id) {
+      root_forward_nodes = DeepLinkedGraphSearch(cnode);
+      break;
     }
   }
-  return nullptr;
-}
-
-CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root) {
-  MS_EXCEPTION_IF_NULL(root);
-  AnfNodePtr root_return_node = root->get_return();
-  MS_EXCEPTION_IF_NULL(root_return_node);
-  const auto &all_nodes = root->nodes();
-  FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes);
-  if (func_graph == nullptr) {
-    return FindLossCNode(root);
-  } else {
-    return FindLossCNode(func_graph);
-  }
-}
-
-FuncGraphPtr ForwardGraph(const FuncGraphPtr &root) {
-  FuncGraphPtr forward_graph = root;
-  MS_EXCEPTION_IF_NULL(root);
-  AnfNodePtr root_return_node = root->get_return();
-  MS_EXCEPTION_IF_NULL(root_return_node);
-  const auto &all_nodes = root->nodes();
-  FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes);
-  if (func_graph != nullptr) {
-    forward_graph = func_graph;
-  }
-  return forward_graph;
+  return root_forward_nodes;
 }
 
 void MarkForwardCNode(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
-  AnfNodePtr root_return_node = root->get_return();
-  MS_EXCEPTION_IF_NULL(root_return_node);
-  auto &all_nodes = root->nodes();
-  FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes);
+  auto all_nodes = root->nodes();
+  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
 
-  if (func_graph == nullptr) {
-    // Can not find the forward graph, so the ops in root graph are forward.
+  if (graph_set.empty()) {
     MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph";
     SetForwardFlag(all_nodes);
   } else {
-    MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
-    AnfNodePtr return_node = func_graph->get_return();
-    MS_EXCEPTION_IF_NULL(return_node);
-    std::vector<AnfNodePtr> all_dfs_nodes = DeepLinkedGraphSearch(return_node);
-    SetForwardFlag(all_dfs_nodes);
+    for (auto &func_graph : graph_set) {
+      MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
+      auto return_node = func_graph->get_return();
+      MS_EXCEPTION_IF_NULL(return_node);
+      auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);
+      SetForwardFlag(all_dfs_nodes);
+      auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes);
+      if (root_forward_nodes.empty()) {
+        continue;
+      }
+      // Mark forward flag for the nodes in root graph.
+      SetForwardFlag(root_forward_nodes);
+    }
   }
 }
diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h
index 184d11d173..b0d128f515 100644
--- a/mindspore/ccsrc/parallel/step_parallel.h
+++ b/mindspore/ccsrc/parallel/step_parallel.h
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <set>
 
 #include "./common.h"
 #include "optimizer/opt.h"
@@ -142,13 +143,13 @@ bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optim
 
 int32_t GetTupleGetItemIndex(const CNodePtr &cnode);
 
-CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root);
+std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root);
 
 Status ParallelInit();
 
 std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
 
-FuncGraphPtr ForwardGraph(const FuncGraphPtr &root);
+std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
 
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/tests/ut/python/parallel/test_semi_auto_two_subgraphs.py b/tests/ut/python/parallel/test_semi_auto_two_subgraphs.py
new file mode 100644
index 0000000000..b572968a4f
--- /dev/null
+++ b/tests/ut/python/parallel/test_semi_auto_two_subgraphs.py
@@ -0,0 +1,108 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import mindspore as ms
+from mindspore import Tensor, Parameter, ParameterTuple, context
+from mindspore import nn
+from mindspore.common.api import _executor
+from mindspore.nn.optim import Adam, FTRL
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+import numpy as np
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.mul = P.Mul()
+        self.relu = P.ReLU()
+        self.param1 = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide")
+        self.param2 = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="deep")
+
+    def construct(self, x):
+        out = self.mul(x, self.param1)
+        out = self.mul(out, self.param2)
+        out = self.relu(out)
+        return out
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.sum = P.ReduceSum(keep_dims=False).set_strategy(strategy=((4, 1, 1, 1),))
+        self.mean = P.ReduceMean(keep_dims=False).set_strategy(strategy=((8, 1, 1, 1),))
+        self.net = network
+
+    def construct(self, x):
+        net_out = self.net(x)
+        loss1 = self.sum(net_out, -1)
+        loss2 = self.mean(net_out, -1)
+        return loss1, loss2
+
+
+class IthOutputCell(nn.Cell):
+    def __init__(self, network, output_index):
+        super(IthOutputCell, self).__init__()
+        self.network = network
+        self.output_index = output_index
+
+    def construct(self, x1):
+        predict = self.network(x1)[self.output_index]
+        return predict
+
+
+class TrainStepWrap(nn.Cell):
+    def __init__(self, network, sens=1000.0):
+        super(TrainStepWrap, self).__init__()
+        self.network = network
+        self.network.set_train()
+        self.trainable_params = network.trainable_params()
+        weights_w = []
+        weights_d = []
+        for params in self.trainable_params:
+            weights_w.append(params)
+            weights_d.append(params)
+
+        self.weights_w = ParameterTuple(weights_w)
+        self.weights_d = ParameterTuple(weights_d)
+        self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
+                                l1=1e-8, l2=1e-8, initial_accum=1.0)
+        self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8,
+                                loss_scale=sens)
+        self.hyper_map = C.HyperMap()
+        self.grad_w = C.GradOperation('grad_w', get_by_list=True,
+                                      sens_param=True)
+        self.grad_d = C.GradOperation('grad_d', get_by_list=True,
+                                      sens_param=True)
+        self.sens = sens
+        self.loss_net_w = IthOutputCell(network, output_index=0)
+        self.loss_net_d = IthOutputCell(network, output_index=1)
+
+    def construct(self, x):
+        weights_w = self.weights_w
+        weights_d = self.weights_d
+        loss_w, loss_d = self.network(x)
+        sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
+        sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
+        grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
+        grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)
+        return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d))
+
+
+def test_two_subgraphs():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    net = TrainStepWrap(NetWithLoss(Net()))
+    input_x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
+    _executor.compile(net, input_x)