support const input in graph_ir convertor, add value inference in Concat

pull/8789/head
chenhaozhe 4 years ago
parent b2f52c881b
commit b3add83bf0

@ -593,9 +593,17 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
AnfNodePtr anf_out = node;
AnfNodePtr pre_node = nullptr;
// trace Parameter node
// Trace value node
if (node->isa<ValueNode>()) {
auto op = Convert(anf_out);
graph_outputs_.emplace_back(std::make_pair(*op, ""));
AddGraphConstInput(op);
return;
}
// Trace Parameter node
TraceOutputFromParameter(anf_out);
// then trace cnode
// Then trace cnode
if (!node->isa<CNode>()) {
return;
}
@ -869,7 +877,12 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
}
}
MS_LOG(DEBUG) << "trace output";
graph_outputs_.clear();
TraceOutput(anf_graph_->get_return()->input(1));
// Add const nodes as graph input for some operator work with constant
MS_LOG(INFO) << "graph const input size: " << graph_const_inputs_.size();
std::transform(graph_const_inputs_.begin(), graph_const_inputs_.end(), std::back_inserter(inputs),
[](OperatorPtr x) { return *x; });
@ -879,8 +892,6 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
// set graph output
// set the value of finale return apply node as the output of dataflow graph
MS_LOG(DEBUG) << "set output";
graph_outputs_.clear();
TraceOutput(anf_graph_->get_return()->input(1));
MS_LOG(INFO) << "set graph output num: " << graph_outputs_.size();
(void)df_graph_->SetOutputs(graph_outputs_);
@ -1036,7 +1047,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
}
void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) {
if (op->GetOpType() == "Constant") {
if (op->GetOpType() == "Constant" || op->GetOpType() == "Const") {
graph_const_inputs_.push_back(op);
}
}

@ -1965,10 +1965,13 @@ class Concat(PrimitiveWithInfer):
self.add_prim_attr('T', x_type[0].element_type())
self.add_prim_attr('inputNums', len(x_shp))
ret_shp = x_shp[0].copy()
value = None
if input_x['value'] is not None:
value = Tensor(np.concatenate([x.asnumpy() for x in input_x['value']], axis=axis))
ret_shp[axis] = all_shp
out = {'shape': ret_shp,
'dtype': x_type[0],
'value': None}
'value': value}
return out

@ -42,13 +42,13 @@ def test_convert_tuple_input_to_dynamic_input(tag):
@fns
def before(x):
res = concat((t1, t2))
res = concat((x, x))
res = add(x, res)
return res
@fns
def after(x):
res = concat(t1, t2)
res = concat(x, x)
res = add(x, res)
res = make_tuple(res)
return res

Loading…
Cancel
Save