complete the selfOp func for matrix slice parse like mat[1,:] += 1

pull/7652/head
yepei6 4 years ago committed by Payne
parent 4df56b6c1e
commit c33e7c6092

@ -1016,36 +1016,36 @@ AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &no
return block->func_graph()->NewCNode({make_dict_op, keys_tuple, values_tuple});
}
// process a augment assign such as a += b;
// process a augment assign such as a += b or mat[stride_slice] += b.
FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast AugAssign";
py::object op = python_adapter::GetPyObjAttr(node, "op");
MS_EXCEPTION_IF_NULL(block);
// resolve the op
AnfNodePtr op_node = block->MakeResolveAstOp(op);
py::object target_node = python_adapter::GetPyObjAttr(node, "target");
MS_EXCEPTION_IF_NULL(ast_);
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_node)));
AnfNodePtr read_node = nullptr;
py::object target_obj = python_adapter::GetPyObjAttr(node, "target");
py::object op_obj = python_adapter::GetPyObjAttr(node, "op");
py::object value_obj = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr target_node = nullptr;
AnfNodePtr op_node = block->MakeResolveAstOp(op_obj);
AnfNodePtr value_node = ParseExprNode(block, value_obj);
auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_obj)));
if (ast_type == AST_SUB_TYPE_NAME) {
read_node = ParseName(block, target_node);
} else if (ast_->IsClassMember(target_node)) {
read_node = ParseAttribute(block, target_node);
target_node = ParseName(block, target_obj);
} else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
target_node = ParseSubscript(block, target_obj);
} else if (ast_->IsClassMember(target_obj)) {
target_node = ParseAttribute(block, target_obj);
} else {
MS_LOG(EXCEPTION) << "Not supported augassign";
}
if (read_node == nullptr) {
if (target_node == nullptr) {
MS_LOG(EXCEPTION) << "Can not get target node ";
}
py::object value = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr value_node = ParseExprNode(block, value);
CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, read_node, value_node});
WriteAssignVars(block, target_node, augassign_app);
CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, target_node, value_node});
WriteAssignVars(block, target_obj, augassign_app);
return block;
}
// process global declaration such as 'global x';
FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Global";

Loading…
Cancel
Save