|
|
|
@ -49,6 +49,9 @@ namespace mindspore {
|
|
|
|
|
namespace parallel {
|
|
|
|
|
const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
|
|
|
|
|
const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
|
|
|
|
|
// g_RefMap, for CNode B input i is a RefKey[Parameter C],
|
|
|
|
|
// it will be one item in map with key: C, and value: (B, i)
|
|
|
|
|
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
|
|
|
|
|
|
|
|
|
|
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
|
|
|
|
|
if (new_node_input.empty()) {
|
|
|
|
@ -1085,11 +1088,19 @@ std::vector<Shapes> ExtractShape(const CNodePtr& node) {
|
|
|
|
|
std::vector<AnfNodePtr> all_inputs = node->inputs();
|
|
|
|
|
std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
|
|
|
|
|
|
|
|
|
|
for (auto& input : node_inputs) {
|
|
|
|
|
size_t inputs_size = all_inputs.size();
|
|
|
|
|
for (size_t i = 1; i < inputs_size; ++i) {
|
|
|
|
|
Shapes input_shapes;
|
|
|
|
|
AnfNodePtr input = all_inputs[i];
|
|
|
|
|
if (IsValueNode<RefKey>(input)) {
|
|
|
|
|
auto func_graph = node->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
|
|
|
|
|
if (parameters.size() != 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
|
|
|
|
|
}
|
|
|
|
|
std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
|
|
|
|
|
g_RefMap[parameters[0]] = node_pair;
|
|
|
|
|
input_shapes = GetRefKeyNodeShape(input, func_graph);
|
|
|
|
|
} else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
|
|
|
|
|
input_shapes = GetNodeShape(input);
|
|
|
|
@ -1205,14 +1216,20 @@ void CoverSliceShape(const FuncGraphPtr& root) {
|
|
|
|
|
auto parameters = root->parameters();
|
|
|
|
|
for (auto& parameter : parameters) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter->Shape());
|
|
|
|
|
auto iter = g_RefMap.find(parameter);
|
|
|
|
|
if (iter != g_RefMap.end()) {
|
|
|
|
|
SetParallelShape(parameter, g_RefMap[parameter]);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
|
|
|
|
|
if (res.first == nullptr) {
|
|
|
|
|
MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
|
|
|
|
|
} else {
|
|
|
|
|
SetParallelShape(parameter, res);
|
|
|
|
|
MS_LOG(DEBUG) << "parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
|
|
|
|
|
MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
g_RefMap.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) {
|
|
|
|
|