@@ -406,83 +406,115 @@ ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract
  return new_parameter;
}

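// Split a Parameter with a tuple abstract into one Parameter per tuple element and
// splice the new Parameters into the graph inputs in place of the original one.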
std::vector<AnfNodePtr> KernelGraph::SplitTupleParameterToNodeList(const ParameterPtr &parameter) {
  MS_EXCEPTION_IF_NULL(parameter);
  std::vector<AnfNodePtr> convert_nodes_list;
  auto abstract = parameter->abstract();
  MS_EXCEPTION_IF_NULL(abstract);
  if (!abstract->isa<abstract::AbstractTuple>()) {
    MS_LOG(EXCEPTION) << "Multiple output Parameter's output must be a tuple abstract but got " << abstract->ToString();
  }
  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abstract);
  for (size_t index = 0; index < tuple_abstract->size(); ++index) {
    auto new_parameter = this->NewParameter((*tuple_abstract)[index]);
    SetKernelInfoForNode(new_parameter);
    convert_nodes_list.emplace_back(new_parameter);
  }
  auto new_inputs = std::make_shared<std::vector<AnfNodePtr>>();
  auto old_inputs = inputs();
  for (const auto &input_node : old_inputs) {
    if (input_node != parameter) {
      new_inputs->emplace_back(input_node);
      continue;
    }
    std::copy(convert_nodes_list.begin(), convert_nodes_list.end(), std::back_inserter(*new_inputs));
  }
  inputs_ = new_inputs;
  return convert_nodes_list;
}

ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>();
  AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  return new_value_node;
}

ValueNodePtr KernelGraph::NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(abstract);
  MS_EXCEPTION_IF_NULL(value);
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(value);
  new_value_node->set_abstract(abstract);
  SetKernelInfoForNode(new_value_node);
  AnfAlgo::SetGraphId(graph_id(), new_value_node.get());
  return new_value_node;
}

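// Recursively convert a tuple value into a MakeTuple of ValueNodes; a non-tuple
// abstract is converted into a single ValueNode registered to the graph.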
AnfNodePtr KernelGraph::TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(abstract);
  MS_EXCEPTION_IF_NULL(value);
  if (!abstract->isa<abstract::AbstractTuple>()) {
    auto new_value_node = NewValueNode(abstract, value);
    AddValueNodeToGraph(new_value_node);
    return new_value_node;
  }
  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  auto value_tuple = value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abstract);
  MS_EXCEPTION_IF_NULL(value_tuple);
  if (tuple_abstract->size() != value_tuple->size()) {
    MS_LOG(EXCEPTION) << "Abstract size:" << tuple_abstract->size()
                      << " is not equal to value size:" << value_tuple->size();
  }
  std::vector<AnfNodePtr> make_tuple_inputs = {
    mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  for (size_t index = 0; index < tuple_abstract->size(); ++index) {
    make_tuple_inputs.push_back(TransValueNodeTuple((*tuple_abstract)[index], (*value_tuple)[index]));
  }
  auto make_tuple = NewCNode(make_tuple_inputs);
  make_tuple->set_abstract(tuple_abstract);
  return make_tuple;
}

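// Dispatch the split by node type: Parameters and ValueNodes are supported, CNodes are rejected.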
std::vector<AnfNodePtr> KernelGraph::SplitTupleOutputNodeToNodeList(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode but got " << node->DebugString();
  }
  if (node->isa<Parameter>()) {
    return SplitTupleParameterToNodeList(node->cast<ParameterPtr>());
  }
  return SplitTupleValueNodeToNodeList(node->cast<ValueNodePtr>());
}

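// Split a tuple ValueNode into one ValueNode per element, register the new nodes to the
// graph and drop the original tuple ValueNode.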
std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  auto node_value = value_node->value();
  std::vector<AnfNodePtr> convert_inputs;
  if (!node_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
  }
  auto value_tuple = node_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  auto abstract = value_node->abstract();
  MS_EXCEPTION_IF_NULL(abstract);
  if (!abstract->isa<abstract::AbstractTuple>()) {
    MS_LOG(EXCEPTION) << "Split node's output abstract is not of type tuple";
  }
  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abstract);
  if (tuple_abstract->size() != value_tuple->size()) {
    MS_LOG(EXCEPTION) << "The node output index [" << value_tuple->size() << "] is out of range "
                      << tuple_abstract->size();
  }
  for (size_t index = 0; index < value_tuple->value().size(); ++index) {
    auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
    new_value_node->set_abstract((*tuple_abstract)[index]);
    AddValueNodeToGraph(new_value_node);
    SetKernelInfoForNode(new_value_node);
    AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
    convert_inputs.emplace_back(new_value_node);
  }
  if (!RemoveValueNodeFromGraph(value_node)) {
    MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
  }
  return convert_inputs;
}

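// Recursively expand a tuple abstract into new Parameters combined by MakeTuple;
// a non-tuple abstract maps to a single new Parameter.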
AnfNodePtr KernelGraph::TransParameterTuple(const AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(abstract);
  if (!abstract->isa<abstract::AbstractTuple>()) {
    return NewParameter(abstract);
  }
  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abstract);
  std::vector<AnfNodePtr> make_tuple_inputs = {
    mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  for (size_t index = 0; index < tuple_abstract->size(); ++index) {
    make_tuple_inputs.push_back(TransParameterTuple((*tuple_abstract)[index]));
  }
  auto make_tuple = NewCNode(make_tuple_inputs);
  make_tuple->set_abstract(tuple_abstract);
  return make_tuple;
}

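// Build a TupleGetItem(node, output_idx) CNode and copy the indexed output's inferred
// type and shape onto it.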
AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx) {
  auto idx = mindspore::NewValueNode(SizeToInt(output_idx));
  MS_EXCEPTION_IF_NULL(idx);
  auto imm = std::make_shared<Int32Imm>(SizeToInt(output_idx));
  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  idx->set_abstract(abstract_scalar);
  AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx});
  MS_EXCEPTION_IF_NULL(tuple_getitem);
  tuple_getitem->set_scope(node->scope());
  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
  TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
  return tuple_getitem;
}

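// Expand a multi-output CNode into a MakeTuple of TupleGetItem nodes, one per output.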
AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<TypeId> types;
  std::vector<std::vector<size_t>> shapes;
  std::vector<AnfNodePtr> make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)};
  for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(node); ++tuple_out_index) {
    make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(node, tuple_out_index));
    types.push_back(AnfAlgo::GetOutputInferDataType(node, tuple_out_index));
    shapes.emplace_back(AnfAlgo::GetOutputInferShape(node, tuple_out_index));
  }
  auto make_tuple = NewCNode(make_tuple_inputs_list);
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
  return make_tuple;
}

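// Entry point: convert any tuple-output node (Parameter, ValueNode or CNode) into an
// equivalent MakeTuple representation.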
AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfAlgo::IsTupleOutput(node)) {
    return node;
  }
  if (node->isa<Parameter>()) {
    return TransParameterTuple(node->abstract());
  } else if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value());
    if (!RemoveValueNodeFromGraph(value_node)) {
      MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
    }
    return make_tuple;
  } else if (node->isa<CNode>()) {
    return TransCNodeTuple(node->cast<CNodePtr>());
  }
  MS_LOG(EXCEPTION) << "Unexpected node:" << node->DebugString();
}

const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
@@ -782,6 +814,23 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
  return false;
}

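// Replace a single graph input parameter in place; no-op when the old and new nodes are the same.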
void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) {
  // update graph inputs
  MS_EXCEPTION_IF_NULL(old_parameter);
  MS_EXCEPTION_IF_NULL(new_parameter);
  if (old_parameter == new_parameter) {
    return;
  }
  for (size_t i = 0; i < inputs_->size(); i++) {
    if ((*inputs_)[i] == old_parameter) {
      MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_parameter->DebugString()
                   << ", new graph input: " << new_parameter->DebugString();
      (*inputs_)[i] = new_parameter;
      break;
    }
  }
}

void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) {
  MS_EXCEPTION_IF_NULL(inputs_);
  {
@@ -805,15 +854,7 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
        output_cnode->set_input(i, new_anf_node);
      }
    }
    // update graph inputs
    for (size_t i = 0; i < inputs_->size(); i++) {
      if ((*inputs_)[i] == old_anf_node.get()) {
        MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
                     << ", new graph input: " << new_anf_node->DebugString();
        (*inputs_)[i] = new_anf_node.get();
        break;
      }
    }
    ReplaceGraphInput(old_anf_node, new_anf_node);
  }
  // update front to backend map
  FrontBackendlMapUpdate(old_anf_node, new_anf_node);