!7724 set last operator strategy euqal to label in eval or predict

Merge pull request !7724 from yao_yf/set_last_operator_strategy_euqal_to_label_in_eval
pull/7724/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit fcee224c0e

@ -1512,7 +1512,87 @@ Status ValidStageCheck(const std::vector<int32_t> &stages, int32_t strategy_stag
}
}
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
// find previous parallel care node.
bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids) {
MS_EXCEPTION_IF_NULL(unique_ids);
// if previous node is a parameter, handle it in the outsize.
if (node->isa<Parameter>()) {
return false;
}
if (!node->isa<CNode>()) {
return false;
}
CNodePtr cnode = node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
if (IsParallelCareNode(cnode) && prim->name() != MAKE_TUPLE && prim->name() != MAKE_LIST) {
unique_ids->push_back(cnode->UniqueId());
return true;
}
bool find = false;
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
if (prim->name() == DEPEND && index != 1) {
continue;
}
if (FindPreNodes(cnode->inputs()[index], unique_ids)) {
find = true;
continue;
}
}
return find;
}
void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, vector<std::string> *unique_ids) {
MS_EXCEPTION_IF_NULL(unique_ids);
for (auto &node : all_nodes) {
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == RETURN) {
if (!FindPreNodes(cnode, unique_ids)) {
MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
}
}
}
}
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim) {
MS_EXCEPTION_IF_NULL(operator_);
MS_EXCEPTION_IF_NULL(prim);
StrategyPtr strategyPtr;
std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies();
MS_EXCEPTION_IF_NULL(strategy_v_ptr);
strategyPtr = NewStrategy(0, *strategy_v_ptr);
std::vector<ValuePtr> elements;
for (size_t i = 0; i < strategy_v_ptr->size(); i++) {
elements.push_back(MakeValue((*strategy_v_ptr)[i]));
}
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
// display the strategy generated by batch parallel
auto attrs = prim->attrs();
attrs[GEN_STRATEGY] = strategy;
(void)prim->SetAttrs(attrs);
MS_LOG(INFO) << "prim " << prim->name() << " batch parallel strategy is " << attrs[GEN_STRATEGY]->ToString();
return strategyPtr;
}
void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
auto strategys = strategyPtr->GetInputDim();
for (size_t i = 0; i < strategys.size(); ++i) {
for (size_t j = 0; j < strategys[i].size(); ++j) {
strategys[i][j] = 1;
}
}
strategyPtr->ResetInputs(strategys);
}
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) {
// load strategy map from checkpoint
StrategyMap stra_map;
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
@ -1520,7 +1600,11 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
}
}
vector<std::string> last_forward_node_ids;
if (!is_training) {
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
}
// Get global rank after the checkpoint?
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
std::vector<int32_t> stages = ParallelContext::GetInstance()->stage();
@ -1572,30 +1656,22 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
}
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
bool is_last_nodes = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId()) !=
last_forward_node_ids.end();
bool full_batch = ParallelContext::GetInstance()->full_batch();
if ((is_last_nodes && !full_batch) || (!StrategyFound(attrs) && !load_strategy_from_ckpt)) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies();
if (strategy_v_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed";
}
std::vector<ValuePtr> elements;
for (size_t i = 0; i < strategy_v_ptr->size(); i++) {
elements.push_back(MakeValue((*strategy_v_ptr)[i]));
}
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
// display the strategy generated by batch parallel
attrs[GEN_STRATEGY] = strategy;
(void)prim->SetAttrs(attrs);
MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is "
<< attrs[GEN_STRATEGY]->ToString();
strategyPtr = NewStrategy(0, *strategy_v_ptr);
strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
} else if (load_strategy_from_ckpt) {
strategyPtr = stra_map[strategy_key_name];
} else {
strategyPtr = ExtractStrategy(attrs);
}
if (strategyPtr != nullptr) {
if (is_last_nodes && full_batch) {
SetLastNodeStrategy(strategyPtr);
}
(*operator_).set_stage_id(strategyPtr->GetInputStage());
MS_LOG(INFO) << "Extract stage id for op " << prim->name() << " is " << (*operator_).stage_id();
if (ValidStageCheck(stages, (*operator_).stage_id()) == FAILED) {
@ -2854,7 +2930,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
}
// extract shape and strategy, set operator_info
ExtractInformation(all_nodes);
ExtractInformation(all_nodes, root->has_flag(TRAINING));
ReshapeInit(all_nodes);
}

@ -118,7 +118,7 @@ void CoverSliceShape(const FuncGraphPtr &root);
void SetVirtualDatasetStrategy(const CNodePtr &node);
// Creat parallel operator for primitive node(has strategy)
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes);
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training = true);
TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int> &node_pair);

@ -59,6 +59,7 @@ class Grad(nn.Cell):
def compile_net(net, x, y):
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y)

@ -48,6 +48,7 @@ class GradWrap(nn.Cell):
def compile_net(net, x, y, b):
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, b)
@ -649,6 +650,7 @@ def test_assign_sub():
def compile_sub_net(net, x):
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)
context.set_auto_parallel_context(device_num=64, global_rank=15)
@ -696,6 +698,7 @@ def test_assign_add():
def compile_sub_net(net, x):
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)
context.set_auto_parallel_context(device_num=64, global_rank=15)
@ -743,6 +746,7 @@ def test_assign():
def compile_sub_net(net, x):
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)
context.set_auto_parallel_context(device_num=64, global_rank=15)

@ -73,4 +73,5 @@ def test_auto_parallel_bn_with_prelu():
net = GradWrap(NetWithLoss(Net()))
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)

@ -43,6 +43,7 @@ def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()

@ -52,6 +52,7 @@ class GradWrap(nn.Cell):
def compile_net(net, x, y, b, phase):
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, b, phase=phase)

@ -61,6 +61,7 @@ def test_auto_parallel_assign_sub_with_ref_key():
net.set_auto_parallel()
reset_op_id()
net.set_train()
_executor.compile(net, x, phase="train")
strategies = _executor._get_shard_strategy(net)
for (k, v) in strategies.items():

@ -81,6 +81,7 @@ def test_double_star_graph():
net.set_auto_parallel()
reset_op_id()
net.set_train()
_executor.compile(net, x, y, z, w, phase='train')
strategies = _executor._get_shard_strategy(net)
expected_strategies = {'Default/network-Net/Cast-op0': [[8, 1]],

@ -72,4 +72,5 @@ def test_common_parameter():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, z)

@ -79,6 +79,7 @@ def test_double_source_graph():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, z, w, a)
@ -114,4 +115,5 @@ def test_double_source_complex_graph():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, z, w, a)

@ -83,4 +83,5 @@ def test_double_star_graph():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, z, w, a, b, c)

@ -113,6 +113,7 @@ def test_double_subgraphs():
x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
reset_op_id()
net.set_train()
_executor.compile(net, x, phase='train')
strategies = _executor._get_shard_strategy(net)
for (k, v) in strategies.items():

@ -70,4 +70,5 @@ def test_two_matmul():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, b)

@ -49,6 +49,7 @@ class GradWrap(nn.Cell):
def compile_net(net, x, y, z, w, b):
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, z, w, b)
# model_parallel test

@ -73,4 +73,5 @@ def test_auto_parallel_l2normalize():
x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
b = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
net.set_train()
_executor.compile(net, x, y, b, phase='train')

@ -70,4 +70,5 @@ def test_two_matmul_dropout():
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
net.set_train()
_executor.compile(net, x, y, b)

@ -74,6 +74,7 @@ def test_matmul_prelu():
net.set_auto_parallel()
reset_op_id()
net.set_train()
_executor.compile(net, x, y, b, phase='train')
strategies = _executor._get_shard_strategy(net)
for (k, v) in strategies.items():

@ -58,6 +58,7 @@ def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, inputs_, label_)
context.reset_auto_parallel_context()

@ -99,6 +99,7 @@ def test_auto_parallel_arithmetic():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64]), dtype=ms.int32)
net.set_train()
_executor.compile(net, x, y, b)

@ -68,6 +68,7 @@ def test_common_parameter():
net.set_auto_parallel()
reset_op_id()
net.set_train()
_executor.compile(net, x, y, phase='train')
strategies = _executor._get_shard_strategy(net)
for (k, v) in strategies.items():

@ -77,4 +77,5 @@ def test_four_matmul_linear():
net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, z, w, b)

@ -49,6 +49,7 @@ class GradWrap(nn.Cell):
def compile_net(net, x, y, b):
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, b)

@ -68,6 +68,7 @@ def test_reshape_matmul():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)
def test_reshape_reshape():
@ -90,6 +91,7 @@ def test_reshape_reshape():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)
@ -115,6 +117,7 @@ def test_reshape_auto_1():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)
@ -143,6 +146,7 @@ def test_reshape_auto_2():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)
@ -168,6 +172,7 @@ def test_reshape_auto_3():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)
@ -194,6 +199,7 @@ def test_reshape_auto_4():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)
@ -244,6 +250,7 @@ def test_reshape_auto_5():
net = GradWrap5(NetWithLoss5(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y)
def test_reshape_auto_6():
@ -291,6 +298,7 @@ def test_reshape_auto_6():
net = GradWrap6(NetWithLoss6(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y)
def test_reshape_auto_7():
@ -313,4 +321,5 @@ def test_reshape_auto_7():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x)

@ -49,6 +49,7 @@ class GradWrap(nn.Cell):
def compile_net(net, x, y, b):
net.set_auto_parallel()
net.set_train()
_executor.compile(net, x, y, b)

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save