|
|
|
@ -1061,13 +1061,13 @@ FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) {
|
|
|
|
|
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
|
|
|
|
|
MS_LOG(DEBUG) << "Process ast For, create an if else statement";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(block);
|
|
|
|
|
// create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT'
|
|
|
|
|
// create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
|
|
|
|
|
AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
|
|
|
|
|
py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
|
|
|
|
|
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
|
|
|
|
|
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
|
|
|
|
|
CNodePtr bool_node = block->func_graph()->NewCNode(
|
|
|
|
|
{NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)});
|
|
|
|
|
CNodePtr bool_node =
|
|
|
|
|
block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(MAX_FOR_LOOP_COUNT)});
|
|
|
|
|
|
|
|
|
|
// create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
|
|
|
|
@ -1191,7 +1191,12 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|
|
|
|
py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
|
|
|
|
|
AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(iter_node);
|
|
|
|
|
CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
|
|
|
|
|
// Generate node for loop count and convert it to tensor, to make the loop not unroll
|
|
|
|
|
CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node});
|
|
|
|
|
auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
|
|
|
|
|
auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)});
|
|
|
|
|
|
|
|
|
|
CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len});
|
|
|
|
|
|
|
|
|
|
FunctionBlockPtr header_block =
|
|
|
|
|
GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
|
|
|
|
@ -1199,7 +1204,9 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o
|
|
|
|
|
// create loop variable 'i'
|
|
|
|
|
ParameterPtr loop_var = header_block->func_graph()->add_parameter();
|
|
|
|
|
// create loop condition 'i < len(xs)'
|
|
|
|
|
CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter});
|
|
|
|
|
auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
|
|
|
|
|
auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)});
|
|
|
|
|
CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter});
|
|
|
|
|
|
|
|
|
|
// generate the body of the for statement
|
|
|
|
|
FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
|
|
|
|
|