|
|
@ -54,6 +54,7 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
|
|
|
|
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
|
|
|
|
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
|
|
|
|
// it will be one item in map with key: C, and value: (B, i)
|
|
|
|
// it will be one item in map with key: C, and value: (B, i)
|
|
|
|
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
|
|
|
|
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
|
|
|
|
|
|
|
|
static void HandleNoUsedParameter(const FuncGraphPtr &root);
|
|
|
|
|
|
|
|
|
|
|
|
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
|
|
|
|
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
|
|
|
|
if (new_node_input.empty()) {
|
|
|
|
if (new_node_input.empty()) {
|
|
|
@ -3032,6 +3033,68 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
|
|
|
|
auto manager = graph->manager();
|
|
|
|
|
|
|
|
auto node_users = manager->node_users()[parameter];
|
|
|
|
|
|
|
|
if (node_users.empty()) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (auto node_user : node_users) {
|
|
|
|
|
|
|
|
auto use_node = node_user.first->cast<CNodePtr>();
|
|
|
|
|
|
|
|
if (IsValueNode<FuncGraph>(use_node->input(0))) {
|
|
|
|
|
|
|
|
auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
|
|
|
|
|
|
|
|
auto parameters = graph_sub->parameters();
|
|
|
|
|
|
|
|
auto parameter_sub = parameters[node_user.second - 1];
|
|
|
|
|
|
|
|
return IsUsedParameter(graph_sub, parameter_sub);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (use_node->input(0)->isa<CNode>()) {
|
|
|
|
|
|
|
|
auto cnode = use_node->input(0)->cast<CNodePtr>();
|
|
|
|
|
|
|
|
if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
|
|
|
|
|
|
|
auto parameters = graph_sub->parameters();
|
|
|
|
|
|
|
|
auto parameter_sub = parameters[node_user.second - 1];
|
|
|
|
|
|
|
|
return IsUsedParameter(graph_sub, parameter_sub);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static void HandleNoUsedParameter(const FuncGraphPtr &root) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(root);
|
|
|
|
|
|
|
|
bool full_batch = ParallelContext::GetInstance()->full_batch();
|
|
|
|
|
|
|
|
if (full_batch) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto dev_num = g_device_manager->GetDeviceListByStageId(0).size();
|
|
|
|
|
|
|
|
auto parameters = root->parameters();
|
|
|
|
|
|
|
|
for (auto ¶meter : parameters) {
|
|
|
|
|
|
|
|
if (IsUsedParameter(root, parameter)) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto parameter_shape = GetNodeShape(parameter);
|
|
|
|
|
|
|
|
if (parameter_shape.empty()) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
Shape slice_shape = parameter_shape[0];
|
|
|
|
|
|
|
|
if (slice_shape.empty()) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
slice_shape[0] = slice_shape[0] / dev_num;
|
|
|
|
|
|
|
|
auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
|
|
|
|
|
|
|
|
auto abstract = parameter->abstract();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract);
|
|
|
|
|
|
|
|
auto abstract_cloned = abstract->Clone();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract_cloned);
|
|
|
|
|
|
|
|
abstract_cloned->set_shape(slice_shape_ptr);
|
|
|
|
|
|
|
|
parameter->set_abstract(abstract_cloned);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
|
|
|
|
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
|
|
|
|
MS_EXCEPTION_IF_NULL(root);
|
|
|
|
MS_EXCEPTION_IF_NULL(root);
|
|
|
|
MS_EXCEPTION_IF_NULL(optimizer);
|
|
|
|
MS_EXCEPTION_IF_NULL(optimizer);
|
|
|
@ -3103,6 +3166,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|
|
|
// cover Parallel shape
|
|
|
|
// cover Parallel shape
|
|
|
|
CoverSliceShape(root);
|
|
|
|
CoverSliceShape(root);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// handle input is not used
|
|
|
|
|
|
|
|
HandleNoUsedParameter(root);
|
|
|
|
|
|
|
|
|
|
|
|
// set the shape for optimizer's clone tensor
|
|
|
|
// set the shape for optimizer's clone tensor
|
|
|
|
SetClonedTensorShapeForOptimizer(root);
|
|
|
|
SetClonedTensorShapeForOptimizer(root);
|
|
|
|
|
|
|
|
|
|
|
|