|
|
@ -370,15 +370,12 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
|
|
|
|
if (prim == nullptr) {
|
|
|
|
if (prim == nullptr) {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto attrs = prim->attrs();
|
|
|
|
|
|
|
|
if (IsInBlackList(prim)) {
|
|
|
|
if (IsInBlackList(prim)) {
|
|
|
|
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
|
|
|
|
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if ((prim->name() == CAST)) {
|
|
|
|
if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) {
|
|
|
|
if ((!attrs.count(STRATEGY)) && (cnode->operator_info() == nullptr)) {
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return cnode->in_forward_flag();
|
|
|
|
return cnode->in_forward_flag();
|
|
|
@ -648,6 +645,13 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) {
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_node);
|
|
|
|
|
|
|
|
|
|
|
|
LossNodeInfo node_info;
|
|
|
|
LossNodeInfo node_info;
|
|
|
|
|
|
|
|
// return -> cast
|
|
|
|
|
|
|
|
auto pre_cnode = pre_node->cast<CNodePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_cnode);
|
|
|
|
|
|
|
|
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
|
|
|
if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
|
|
|
|
|
|
|
|
pre_node = pre_cnode->input(1);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// return -> loss
|
|
|
|
// return -> loss
|
|
|
|
if (pre_node == loss_node) {
|
|
|
|
if (pre_node == loss_node) {
|
|
|
@ -1943,6 +1947,13 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) {
|
|
|
|
MS_EXCEPTION_IF_NULL(current_value);
|
|
|
|
MS_EXCEPTION_IF_NULL(current_value);
|
|
|
|
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
|
|
|
|
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(current_prim);
|
|
|
|
MS_EXCEPTION_IF_NULL(current_prim);
|
|
|
|
|
|
|
|
// return -> cast
|
|
|
|
|
|
|
|
if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
|
|
|
|
|
|
|
|
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_cnode);
|
|
|
|
|
|
|
|
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// notice: the GetNext op has not input
|
|
|
|
// notice: the GetNext op has not input
|
|
|
|
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
|
|
|
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
|
|
|
MS_LOG(INFO) << "The loss is: " << current_prim->name();
|
|
|
|
MS_LOG(INFO) << "The loss is: " << current_prim->name();
|
|
|
|