From 7765d44b76f741e58e045f60690156670a7aba6b Mon Sep 17 00:00:00 2001
From: kingfo
Date: Fri, 28 Aug 2020 17:36:22 +0800
Subject: [PATCH] support parameter tuple input in pynative mode

---
 .../pipeline/pynative/pynative_execute.cc     | 50 ++++++++++--
 .../pipeline/pynative/pynative_execute.h      | 15 +++-
 .../pynative_mode/test_tuple_parameter.py     | 77 +++++++++++++++++++
 3 files changed, 134 insertions(+), 8 deletions(-)
 create mode 100644 tests/ut/python/pynative_mode/test_tuple_parameter.py

diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index a4a4870927..ee6d2742cb 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -793,6 +793,20 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
   return node;
 }
 
+AnfNodePtr PynativeExecutor::GetParamNode(const py::object &obj) {
+  auto id = GetId(obj);
+  auto &param = graph_info_map_[curr_g_].param_map[id];
+  if (param.second.size() == 1 && param.second[0] == -1) {
+    return param.first;
+  }
+  auto para_node = param.first;
+  for (auto &idx : param.second) {
+    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), para_node, NewValueNode(idx)};
+    para_node = curr_g_->NewCNode(tuple_get_item_inputs);
+  }
+  return para_node;
+}
+
 std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &args) {
   auto cell_id = GetId(cell);
   for (size_t i = 0; i < args.size(); i++) {
@@ -995,9 +1009,18 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
     graph_info_map_[g] = GraphInfo();
   }
   for (size_t i = 0; i < args.size(); i++) {
+    auto param = args[i];
     auto new_param = g->add_parameter();
-    std::string param_obj = GetId(args[i]);
-    graph_info_map_[g].param_map[param_obj] = new_param;
+    std::string param_obj = GetId(param);
+    if (py::isinstance<py::tuple>(param)) {
+      auto tuple = param.cast<py::tuple>();
+      auto tuple_size = static_cast<int>(tuple.size());
+      for (int j = 0; j < tuple_size; j++) {
+        set_param_map(curr_g_, GetId(tuple[j]), new_param, j);
+        SetTupleParam(tuple[j], new_param, std::vector<int>{j});
+      }
+    }
+    set_param_map(curr_g_, param_obj, new_param);
   }
 }
 
@@ -1028,16 +1051,16 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
       auto value = py::cast<tensor::TensorPtr>(obj);
       free_param->set_default_param(value);
       MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
-      graph_info_map_[df_builder_].param_map[obj_id] = free_param;
+      set_param_map(df_builder_, obj_id, free_param);
       return free_param;
     }
-    return graph_info_map_[df_builder_].param_map[obj_id];
+    return graph_info_map_[df_builder_].param_map[obj_id].first;
   }
 
   // if input is graph output
   if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) {
     // op(x, y)
-    node = graph_info_map_[curr_g_].param_map[obj_id];
+    node = GetParamNode(obj);
   } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) {
     // out = op(op1(x, y))
     // out = op(cell1(x, y))
@@ -1085,6 +1108,19 @@ void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode,
   }
 }
 
+// for param ((a, (b, c)), d) need multi getitem
+void PynativeExecutor::SetTupleParam(const py::object &obj, const AnfNodePtr &para_node, std::vector<int> idx) {
+  if (py::isinstance<py::tuple>(obj)) {
+    auto tuple = obj.cast<py::tuple>();
+    for (int i = 0; i < static_cast<int>(tuple.size()); i++) {
+      std::vector<int> tmp = idx;
+      tmp.push_back(i);
+      set_param_map(curr_g_, GetId(tuple[i]), para_node, tmp);
+      SetTupleParam(tuple[i], para_node, tmp);
+    }
+  }
+}
+
 void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); }
 
 void PynativeExecutor::Popp() {
@@ -1132,7 +1168,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &out,
                                        const py::args &args) {
   AnfNodePtr output_node;
   if (graph_info_map_[curr_g_].param_map.count(out_id)) {
-    output_node = graph_info_map_[curr_g_].param_map[out_id];
+    output_node = GetParamNode(out);
   } else {
     output_node = GetObjNode(out);
   }
@@ -1186,7 +1222,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights,
     auto param_id = GetId(param);
     AnfNodePtr para_node = nullptr;
     if (graph_info_map_[df_builder_].param_map.count(param_id)) {
-      para_node = graph_info_map_[df_builder_].param_map[param_id];
+      para_node = graph_info_map_[df_builder_].param_map[param_id].first;
     } else {
       auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name");
       if (py::isinstance<py::none>(name_attr)) {
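
Note on the pynative_execute.cc changes: for every object nested inside a tuple
argument, NewGraphInner and SetTupleParam record the path of tuple indices from
the enclosing Parameter down to that object, with the one-element path {-1} as a
sentinel for "the argument itself"; GetParamNode later replays a recorded path
as a chain of TupleGetItem CNodes. Below is a minimal Python sketch of that
bookkeeping (illustrative only; the names, the string parameter "P0", and the
tuple stand-ins for CNodes are not MindSpore API):

    param_map = {}

    def set_param_map(obj, param, path=(-1,)):
        param_map[id(obj)] = (param, list(path))

    def record_tuple_param(obj, param, path):
        # Mirrors SetTupleParam: recurse into nested tuples, extending the path.
        if isinstance(obj, tuple):
            for i, item in enumerate(obj):
                set_param_map(item, param, path + [i])
                record_tuple_param(item, param, path + [i])

    def get_param_node(obj):
        # Mirrors GetParamNode: replay the path as chained "TupleGetItem"s.
        param, path = param_map[id(obj)]
        if path == [-1]:
            return param
        node = param
        for idx in path:
            node = ("TupleGetItem", node, idx)  # stand-in for a real CNode
        return node

    a, b, c, d = object(), object(), object(), object()
    arg = ((a, (b, c)), d)
    set_param_map(arg, "P0")           # whole argument -> sentinel path [-1]
    record_tuple_param(arg, "P0", [])  # leaves get paths like [0, 1, 0]
    # b sits at arg[0][1][0], so three getitems are chained onto the parameter:
    assert get_param_node(b) == (
        "TupleGetItem", ("TupleGetItem", ("TupleGetItem", "P0", 0), 1), 0)
    assert get_param_node(arg) == "P0"
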
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index df895df15b..ee29353c21 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -59,7 +59,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple
 void ClearPyNativeSession();
 
 struct GraphInfo {
-  std::unordered_map<std::string, AnfNodePtr> param_map;
+  std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int>>> param_map;
   std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int>>> obj_node_map;
   AnfNodePtr output;
   std::vector<std::string> objects;
@@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void set_grad_flag(bool flag) { grad_flag_ = flag; }
   AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   AnfNodePtr GetObjNode(const py::object &obj);
+  AnfNodePtr GetParamNode(const py::object &obj);
   std::string GetCellId(const py::object &obj, const py::args &args);
   FuncGraphPtr curr_g() { return curr_g_; }
   void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
@@ -104,6 +105,17 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
     graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
   }
+
+  void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) {
+    graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector<int>{-1});
+  }
+  void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
+    graph_info_map_[g].param_map[obj] = std::make_pair(node, std::vector<int>{index});
+  }
+  void set_param_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
+    graph_info_map_[g].param_map[obj] = std::make_pair(node, index);
+  }
+
   AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
                        abstract::AbstractBasePtrList *args_spec_list);
   void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
@@ -119,6 +131,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
                          size_t arg_size);
   void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx);
+  void SetTupleParam(const py::object &obj, const AnfNodePtr &para_node, std::vector<int> idx);
   AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
   py::tuple RunOpInner(const py::args &args);
   py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
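
Note on the pynative_execute.h changes: GraphInfo::param_map widens from
object-id -> AnfNodePtr to object-id -> (AnfNodePtr, index path), and the three
set_param_map overloads cover the whole-argument sentinel {-1}, a single
top-level index, and a full nested path. A hedged Python analogue of the new
entry shape (types, ids, and the string node names are illustrative, not
MindSpore API):

    from typing import Dict, List, Tuple

    # (parameter node, index path); strings stand in for AnfNodePtr here.
    ParamEntry = Tuple[str, List[int]]

    class GraphInfo:
        def __init__(self) -> None:
            self.param_map: Dict[str, ParamEntry] = {}

        # The three C++ set_param_map overloads collapse into one method:
        # no index -> sentinel [-1]; an int -> [index]; a sequence -> full path.
        def set_param_map(self, obj_id: str, node: str, index=None) -> None:
            if index is None:
                path = [-1]
            elif isinstance(index, int):
                path = [index]
            else:
                path = list(index)
            self.param_map[obj_id] = (node, path)

    info = GraphInfo()
    info.set_param_map("arg0", "P0")               # whole argument
    info.set_param_map("arg0_elt1", "P0", 1)       # top-level tuple element
    info.set_param_map("arg0_b", "P0", [0, 1, 0])  # nested element
    assert info.param_map["arg0"] == ("P0", [-1])
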
diff --git a/tests/ut/python/pynative_mode/test_tuple_parameter.py b/tests/ut/python/pynative_mode/test_tuple_parameter.py
new file mode 100644
index 0000000000..a14e8470da
--- /dev/null
+++ b/tests/ut/python/pynative_mode/test_tuple_parameter.py
@@ -0,0 +1,77 @@
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import context, Tensor
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+
+
+def setup_module(module):
+    context.set_context(mode=context.PYNATIVE_MODE)
+
+
+class Block1(nn.Cell):
+    """Define a Cell with a tuple input as parameter."""
+
+    def __init__(self):
+        super(Block1, self).__init__()
+        self.mul = P.Mul()
+
+    def construct(self, tuple_xy):
+        x, y = tuple_xy
+        z = self.mul(x, y)
+        return z
+
+class Block2(nn.Cell):
+    """Define a Cell with a tuple-in-tuple output."""
+
+    def __init__(self):
+        super(Block2, self).__init__()
+        self.mul = P.Mul()
+        self.add = P.TensorAdd()
+
+    def construct(self, x, y):
+        z1 = self.mul(x, y)
+        z2 = self.add(z1, x)
+        z3 = self.add(z1, y)
+        return (z1, (z2, z3))
+
+class Net1(nn.Cell):
+    def __init__(self):
+        super(Net1, self).__init__()
+        self.block = Block1()
+
+    def construct(self, x, y):
+        res = self.block((x, y))
+        return res
+
+
+class Net2(nn.Cell):
+    def __init__(self):
+        super(Net2, self).__init__()
+        self.add = P.TensorAdd()
+        self.block = Block2()
+
+    def construct(self, x, y):
+        z1, (z2, z3) = self.block(x, y)
+        res = self.add(z1, z2)
+        res = self.add(res, z3)
+        return res
+
+
+def test_net():
+    context.set_context(save_graphs=True)
+    x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 2)
+    y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 3)
+    net1 = Net1()
+    grad_op = C.GradOperation(get_all=True)
+    output = grad_op(net1)(x, y)
+    assert np.all(output[0].asnumpy() == y.asnumpy())
+    assert np.all(output[1].asnumpy() == x.asnumpy())
+
+    net2 = Net2()
+    output = grad_op(net2)(x, y)
+    expect_x = np.ones([1, 1, 3, 3]).astype(np.float32) * 10
+    expect_y = np.ones([1, 1, 3, 3]).astype(np.float32) * 7
+    assert np.all(output[0].asnumpy() == expect_x)
+    assert np.all(output[1].asnumpy() == expect_y)
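
The expected gradients in test_net follow from the network definitions: Net1
computes z = x * y, so dz/dx = y and dz/dy = x; Net2 computes
res = z1 + z2 + z3 = 3*x*y + x + y (since z1 = x*y, z2 = x*y + x, z3 = x*y + y),
so dres/dx = 3*y + 1 = 10 at y = 3 and dres/dy = 3*x + 1 = 7 at x = 2. A quick
NumPy check of that arithmetic (illustrative, independent of MindSpore):

    import numpy as np

    # Net2: res = 3*x*y + x + y, so d(res)/dx = 3*y + 1, d(res)/dy = 3*x + 1.
    x = np.ones([1, 1, 3, 3], np.float32) * 2
    y = np.ones([1, 1, 3, 3], np.float32) * 3
    assert np.all(3 * y + 1 == 10)  # matches expect_x in test_net
    assert np.all(3 * x + 1 == 7)   # matches expect_y in test_net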