From b3add83bf04d6e71724132b7bf93c4c9a3aea4c5 Mon Sep 17 00:00:00 2001 From: chenhaozhe Date: Thu, 19 Nov 2020 17:23:43 +0800 Subject: [PATCH] support const input in graph_ir convertor, add value inference in Concat --- mindspore/ccsrc/transform/graph_ir/convert.cc | 21 ++++++++++++++----- mindspore/ops/operations/array_ops.py | 5 ++++- ...nvert_tuple_input_to_dynamic_input_test.py | 4 ++-- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index da0dd28cc9..47e8b23a9a 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -593,9 +593,17 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { AnfNodePtr anf_out = node; AnfNodePtr pre_node = nullptr; - // trace Parameter node + // Trace value node + if (node->isa()) { + auto op = Convert(anf_out); + graph_outputs_.emplace_back(std::make_pair(*op, "")); + AddGraphConstInput(op); + return; + } + + // Trace Parameter node TraceOutputFromParameter(anf_out); - // then trace cnode + // Then trace cnode if (!node->isa()) { return; } @@ -869,7 +877,12 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { } } + MS_LOG(DEBUG) << "trace output"; + graph_outputs_.clear(); + TraceOutput(anf_graph_->get_return()->input(1)); + // Add const nodes as graph input for some operator work with constant + MS_LOG(INFO) << "graph const input size: " << graph_const_inputs_.size(); std::transform(graph_const_inputs_.begin(), graph_const_inputs_.end(), std::back_inserter(inputs), [](OperatorPtr x) { return *x; }); @@ -879,8 +892,6 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { // set graph output // set the value of finale return apply node as the output of dataflow graph MS_LOG(DEBUG) << "set output"; - graph_outputs_.clear(); - TraceOutput(anf_graph_->get_return()->input(1)); MS_LOG(INFO) << "set graph output num: " << graph_outputs_.size(); (void)df_graph_->SetOutputs(graph_outputs_); @@ -1036,7 +1047,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node } void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) { - if (op->GetOpType() == "Constant") { + if (op->GetOpType() == "Constant" || op->GetOpType() == "Const") { graph_const_inputs_.push_back(op); } } diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 4a871f6fc7..8a2a202df6 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1965,10 +1965,13 @@ class Concat(PrimitiveWithInfer): self.add_prim_attr('T', x_type[0].element_type()) self.add_prim_attr('inputNums', len(x_shp)) ret_shp = x_shp[0].copy() + value = None + if input_x['value'] is not None: + value = Tensor(np.concatenate([x.asnumpy() for x in input_x['value']], axis=axis)) ret_shp[axis] = all_shp out = {'shape': ret_shp, 'dtype': x_type[0], - 'value': None} + 'value': value} return out diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_tuple_input_to_dynamic_input_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_tuple_input_to_dynamic_input_test.py index c698143f21..d06ac1b8bb 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_tuple_input_to_dynamic_input_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/convert_tuple_input_to_dynamic_input_test.py @@ -42,13 +42,13 @@ def test_convert_tuple_input_to_dynamic_input(tag): @fns def before(x): - res = concat((t1, t2)) + res = concat((x, x)) res = add(x, res) return res @fns def after(x): - res = concat(t1, t2) + res = concat(x, x) res = add(x, res) res = make_tuple(res) return res