fix bug if input not used

pull/8223/head
lichenever 4 years ago
parent a3f9be98c0
commit 7c7006f347

@@ -54,6 +54,7 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
// it will be one item in map with key: C, and value: (B, i)
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
static void HandleNoUsedParameter(const FuncGraphPtr &root);
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
  if (new_node_input.empty()) {
@@ -3032,6 +3033,68 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
  }
}

// Return true if the parameter is consumed by any real node. If its only user
// forwards it into a sub-graph call (or a J-wrapped sub-graph), recurse into
// that sub-graph and check the corresponding formal parameter instead.
bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter);
  auto manager = graph->manager();
  auto node_users = manager->node_users()[parameter];
  if (node_users.empty()) {
    return false;
  }
  for (auto node_user : node_users) {
    auto use_node = node_user.first->cast<CNodePtr>();
    if (IsValueNode<FuncGraph>(use_node->input(0))) {
      // The user is a direct call to a sub-graph: follow the matching parameter.
      auto graph_sub = GetValueNode<FuncGraphPtr>(use_node->input(0));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[node_user.second - 1];
      return IsUsedParameter(graph_sub, parameter_sub);
    }
    if (use_node->input(0)->isa<CNode>()) {
      auto cnode = use_node->input(0)->cast<CNodePtr>();
      if (!IsSomePrimitive(cnode, J) || !IsValueNode<FuncGraph>(cnode->input(1))) {
        return true;
      }
      // The user is a call to J(sub-graph): follow the matching parameter.
      auto graph_sub = GetValueNode<FuncGraphPtr>(cnode->input(1));
      auto parameters = graph_sub->parameters();
      auto parameter_sub = parameters[node_user.second - 1];
      return IsUsedParameter(graph_sub, parameter_sub);
    }
    return true;
  }
  return true;
}

// When not in full_batch mode, an unused parameter does not get its shape
// sliced elsewhere, so divide its first dimension by the device number and
// write the sliced shape back into its abstract.
static void HandleNoUsedParameter(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  bool full_batch = ParallelContext::GetInstance()->full_batch();
  if (full_batch) {
    return;
  }
  auto dev_num = g_device_manager->GetDeviceListByStageId(0).size();
  auto parameters = root->parameters();
  for (auto &parameter : parameters) {
    if (IsUsedParameter(root, parameter)) {
      continue;
    }
    auto parameter_shape = GetNodeShape(parameter);
    if (parameter_shape.empty()) {
      continue;
    }
    Shape slice_shape = parameter_shape[0];
    if (slice_shape.empty()) {
      continue;
    }
    slice_shape[0] = slice_shape[0] / dev_num;
    auto slice_shape_ptr = std::make_shared<abstract::Shape>(slice_shape);
    auto abstract = parameter->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    auto abstract_cloned = abstract->Clone();
    MS_EXCEPTION_IF_NULL(abstract_cloned);
    abstract_cloned->set_shape(slice_shape_ptr);
    parameter->set_abstract(abstract_cloned);
  }
}

bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(optimizer);
@@ -3103,6 +3166,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
  // cover Parallel shape
  CoverSliceShape(root);

  // handle the input that is not used
  HandleNoUsedParameter(root);

  // set the shape for optimizer's clone tensor
  SetClonedTensorShapeForOptimizer(root);

@@ -161,6 +161,8 @@ std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter);
void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr &distribute_operator,
                             const CNodePtr &cnode, const AnfNodePtr &parameter, size_t index);
}  // namespace parallel
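
The declaration added above exposes IsUsedParameter, which decides whether a parameter is really consumed by following it through sub-graph calls. A rough standalone sketch of that recursion on a toy graph structure follows; the Graph and Use types and the example in main are hypothetical stand-ins, not MindSpore types.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Graph;

// One use of a parameter: either a real consumer, or a call that forwards the
// value into a sub-graph's formal parameter at the given position.
struct Use {
  bool forwards_to_subgraph;          // true: the user is a call into a sub-graph
  std::shared_ptr<Graph> subgraph;    // target sub-graph when forwarding
  size_t subgraph_param_index;        // which formal parameter it becomes
};

struct Graph {
  std::vector<std::string> parameters;
  std::vector<std::vector<Use>> users;  // users[i]: how parameters[i] is consumed
};

// Mirror of the commit's logic: a parameter counts as used only if some user
// consumes it directly, or the formal parameter it is forwarded to is itself
// used inside the sub-graph.
bool IsUsedParameter(const Graph &graph, size_t param_index) {
  const auto &uses = graph.users[param_index];
  if (uses.empty()) {
    return false;
  }
  for (const auto &use : uses) {
    if (use.forwards_to_subgraph) {
      return IsUsedParameter(*use.subgraph, use.subgraph_param_index);
    }
    return true;  // consumed by a real operator
  }
  return true;
}

int main() {
  auto sub = std::make_shared<Graph>();
  sub->parameters = {"x"};
  sub->users = {{}};                     // "x" is never consumed inside sub

  Graph root;
  root.parameters = {"input"};
  root.users = {{Use{true, sub, 0}}};    // "input" is only passed into sub

  std::cout << std::boolalpha << IsUsedParameter(root, 0) << '\n';  // false
  return 0;
}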
