Control flow not split graph

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
pull/2931/head
zhoufeng 5 years ago
parent 3e691e54f5
commit 439d6d618f

@ -40,6 +40,9 @@ using kernel::KernelBuildInfoPtr;
using kernel::KernelMod;
using kernel::KernelModPtr;
namespace {
constexpr size_t kNopNodeInputSize = 2;
constexpr size_t kNopNodeRealInputIndex = 1;
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
std::vector<size_t> shape_size_t;
@ -48,6 +51,26 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
}
} // namespace
// Return the real input of a TupleGetItem node, i.e. the node whose output
// the item is extracted from (input slot kRealInputNodeIndexInTupleGetItem).
// Raises an exception when the node does not have the expected input count.
AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  const bool arity_ok = (tuple_get_item->size() == kTupleGetItemInputSize);
  if (!arity_ok) {
    MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
  }
  return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
}
// Return the output index encoded in a TupleGetItem node. The index lives in
// the node's second input (kInputNodeOutputIndexInTupleGetItem), which must be
// a ValueNode holding an int.
size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
  MS_EXCEPTION_IF_NULL(tuple_get_item);
  if (tuple_get_item->size() != kTupleGetItemInputSize) {
    MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
  }
  const auto &index_input = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
  MS_EXCEPTION_IF_NULL(index_input);
  auto index_value_node = index_input->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(index_value_node);
  return IntToSize(GetValue<int>(index_value_node->value()));
}
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);
if (anf_node->isa<ValueNode>()) {
@ -83,49 +106,47 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz
}
}
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, int index,
bool visit_nop_node,
const std::vector<PrimitivePtr> &return_types) {
MS_EXCEPTION_IF_NULL(anf_node);
for (const auto &prim_type : return_types) {
if (CheckPrimitiveType(anf_node, prim_type)) {
return std::make_pair(anf_node, index);
}
if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool {
return CheckPrimitiveType(anf_node, prim_type);
})) {
return KernelWithIndex(anf_node, index);
}
if (anf_node->isa<ValueNode>()) {
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<Parameter>()) {
return std::make_pair(anf_node, 0);
} else if (anf_node->isa<CNode>()) {
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input0 = cnode->input(0);
MS_EXCEPTION_IF_NULL(input0);
if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
if (cnode->inputs().size() != kTupleGetItemInputSize) {
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
}
auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(input2);
auto value_node = input2->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
visit_nop_node, return_types);
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
} else if (opt::IsNopNode(cnode) && visit_nop_node) {
if (cnode->inputs().size() == 2) {
return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
} else {
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
if (!anf_node->isa<CNode>()) {
return KernelWithIndex(anf_node, 0);
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode),
GetTupleGetItemOutIndex(cnode), visit_nop_node, return_types);
if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
const std::vector<AnfNodePtr> &make_tuple_inputs = make_tuple->inputs();
size_t make_tuple_input_index = item_with_index_tmp.second + 1;
if (make_tuple_input_index >= make_tuple_inputs.size()) {
MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size()
<< "].";
}
} else {
return std::make_pair(anf_node, index);
return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, visit_nop_node, return_types);
}
} else {
MS_LOG(EXCEPTION) << "The input is invalid";
return item_with_index_tmp;
}
if (CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types);
}
if (opt::IsNopNode(cnode) && visit_nop_node) {
if (cnode->size() != kNopNodeInputSize) {
MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString();
}
return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, visit_nop_node, return_types);
}
return KernelWithIndex(anf_node, index);
}
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
@ -591,7 +612,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
if (opt::IsNopNode(node) && visit_nop_node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() == 2) {
if (cnode->size() == kNopNodeInputSize) {
return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
} else {
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
@ -613,7 +634,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
if (opt::IsNopNode(node) && visit_nop_node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() == 2) {
if (cnode->inputs().size() == kNopNodeInputSize) {
return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
} else {
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
@ -806,7 +827,7 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
IsPrimitive(input, prim::kPrimReturn);
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
return !is_virtual_node;
}
@ -1117,5 +1138,14 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s
}
return GetCNodeOutputPrecision(kernel_with_index.first);
}
// A conditional-control kernel is a LabelGoto or LabelSwitch node; these are
// the primitives that transfer control between label positions in the
// flattened (non-split) control-flow graph.
bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->inputs().empty()) {
    MS_LOG(EXCEPTION) << "Illegal null input of cnode.";
  }
  const auto &prim_input = node->input(kAnfPrimitiveIndex);
  return IsPrimitive(prim_input, prim::kPrimLabelGoto) || IsPrimitive(prim_input, prim::kPrimLabelSwitch);
}
} // namespace session
} // namespace mindspore

@ -42,9 +42,12 @@ using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
class AnfRuntimeAlgorithm {
public:
// get real input node of tuple_get_item
static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
// get input_anf_node's real kernel recursively
static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, int output_index,
bool visit_nop_node = false,
const std::vector<PrimitivePtr> &return_types = {
prim::kPrimMakeTuple});
@ -205,6 +208,7 @@ class AnfRuntimeAlgorithm {
static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
// get the fixed output precision from the previous node; input_idx is the input index of the current node related to the previous node.
static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
static bool IsCondControlKernel(const CNodePtr &node);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;

File diff suppressed because it is too large Load Diff

@ -20,6 +20,8 @@
#include <map>
#include <vector>
#include <tuple>
#include <utility>
#include <functional>
#include "backend/session/kernel_graph.h"
#include "utils/base_ref.h"
#include "utils/contract.h"
@ -29,16 +31,23 @@ namespace mindspore {
namespace session {
class AscendControlParser {
public:
static void ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map);
static void LinkGraph(NotNull<KernelGraphPtr> kg);
static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node);
static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
NotNull<AnfNodePtr> second_node);
static void ExecutorValidate(NotNull<KernelGraphPtr> root_graph);
static void UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, const AnfNodePtr &jump_node,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
private:
class ReferenceCounter;
static void EraseParameter(NotNull<KernelGraphPtr> root_graph, const std::set<KernelGraphPtr> &graph_list);
static void EraseLabel(NotNull<KernelGraphPtr> root_graph);
static void ChildGraphDataAssign(NotNull<KernelGraphPtr> kg,
const NotNull<std::vector<std::pair<AnfNodePtr, AnfNodePtr>> *> link_list,
const NotNull<std::set<KernelGraphPtr> *> memo);
static NotNull<CNodePtr> GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label);
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
@ -53,11 +62,10 @@ class AscendControlParser {
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label);
static KernelGraphPtr ParsePartial(NotNull<AnfNodePtr> node);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static std::vector<std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>> ParseCallNode(NotNull<CNodePtr> call_node);
static std::tuple<KernelGraphPtr, std::vector<AnfNodePtr>> ParsePartial(NotNull<AnfNodePtr> node);
// root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
@ -65,6 +73,19 @@ class AscendControlParser {
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo);
};
// Bookkeeper for per-node read/write counts, used when erasing parameters
// that are no longer referenced. `predicate_` decides whether a
// (read_count, write_count) pair still counts as a "valid" element.
class AscendControlParser::ReferenceCounter {
 public:
  // Sink parameter: take the predicate by value and move it into place to
  // avoid an extra std::function copy (which may allocate).
  explicit ReferenceCounter(std::function<bool(int32_t, int32_t)> func) : predicate_(std::move(func)), count_() {}
  // Increase the read (resp. write) count of `key` by `num`.
  void AddReadCount(const AnfNodePtr &key, int32_t num);
  void AddWriteCount(const AnfNodePtr &key, int32_t num);
  // Drop `key` from the counter entirely.
  void EraseElem(const AnfNodePtr &key);
  // True when at least one element satisfies predicate_.
  bool HasValidElem() const;
  // Return one element satisfying predicate_ as (node, read_count, write_count).
  std::tuple<AnfNodePtr, int32_t, int32_t> GetOneValidElem() const;

 private:
  std::function<bool(int32_t, int32_t)> predicate_;
  std::map<AnfNodePtr, std::pair<int32_t, int32_t>> count_;
};
} // namespace session
} // namespace mindspore

File diff suppressed because it is too large Load Diff

@ -151,6 +151,15 @@ class AscendSession : public SessionBasic {
// sync initial tensors' data to device
void SyncInitialTenosrToDevice();
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
// create parameter to receive data from multiple branch output
void CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
void SelectKernel(NotNull<KernelGraphPtr> root_graph);
void RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> const memo,
size_t *const raise_precision_count, size_t *const reduce_precision_count) const;
void IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
void HardwareOptimize(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void AssignStaticMemory(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
void UpdateRefOutputMap(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) const;
// member variables
// key is final_graph_id,value is child graph execute order of final graph

@ -616,8 +616,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
depend_mode = AnfAlgo::GetNodeAttr<int>(cnode, kControlDependMode);
}
MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
<< "], depend_mode :" << depend_mode << ".";
MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
<< "], depend_mode :" << depend_mode << ".";
if (prior_node->isa<Parameter>() && depend_mode == 1) {
prior_nodes = GetOutputNodes(prior_node);
}
@ -647,7 +647,8 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
}
MS_EXCEPTION_IF_NULL(first_node);
MS_EXCEPTION_IF_NULL(second_node);
MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString()
<< ",second node:" << second_node->DebugString();
AddDependEdge(second_node, first_node, 1);
}
}
@ -991,6 +992,30 @@ bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
return false;
}
void KernelGraph::UpdateChildGraphOrder() {
MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
SetExecOrderByDefault();
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
std::vector<KernelGraphPtr> child_graph_order;
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
for (const auto &child_graph : call_child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
if (child_graph != parent_graph_) {
auto shared_this = std::dynamic_pointer_cast<KernelGraph>(shared_from_this());
MS_EXCEPTION_IF_NULL(shared_this);
child_graph->set_parent_graph(shared_this);
}
child_graph_order.push_back(child_graph);
}
}
for (size_t i = 0; i < child_graph_order.size(); ++i) {
MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]";
}
child_graph_order_ = child_graph_order;
}
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }

@ -156,6 +156,12 @@ class KernelGraph : public FuncGraph {
bool IsFinalOutputKernel(const AnfNodePtr &node) const;
uint32_t current_epoch() const { return current_epoch_; }
void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
void UpdateChildGraphOrder();
const std::vector<AnfNodePtr> &child_graph_result() const { return child_graph_result_; }
void AddChildGraphResult(const AnfNodePtr &parameter) { child_graph_result_.push_back(parameter); }
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
child_graph_result_ = child_graph_result;
}
private:
// remove value node form graph
@ -173,6 +179,7 @@ class KernelGraph : public FuncGraph {
void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
std::vector<AnfNodePtr> child_graph_result_;
std::vector<CNodePtr> execution_order_;
uint32_t graph_id_;
uint32_t stream_distinction_label_;

@ -74,7 +74,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
return input_tensors[input_idx];
}
}
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
}
}
// if the process reaches here, it means item_with_index is a real node (Parameter, or executable CNode)
@ -107,8 +107,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
return tensor;
}
BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(anf);
MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
@ -120,7 +120,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
MS_EXCEPTION_IF_NULL(cnode);
VectorRef ret;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors);
ret.push_back(out);
}
return ret;
@ -133,25 +133,6 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
}
BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(anf);
if (!AnfAlgo::IsRealKernel(anf)) {
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
}
if (anf->isa<ValueNode>()) {
return CreateOneTensor(anf, 0, graph, input_tensors);
}
VectorRef ret;
if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
auto out = CreateOneTensor(anf, i, graph, input_tensors);
ret.emplace_back(out);
}
}
return ret;
}
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
MS_EXCEPTION_IF_NULL(graph);
@ -880,20 +861,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
const std::vector<tensor::TensorPtr> &input_tensors) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs);
if (!kernel_graph->child_graph_order().empty()) {
// use the last child graph output as the root graph output
UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
return;
}
auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
continue;
}
outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors));
}
}

@ -294,6 +294,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(mem_manager_);
auto graph_inputs = graph->inputs();
auto graph_valid_input = graph->valid_inputs();
graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
std::vector<AnfNodePtr> need_alloc_nodes;
for (size_t i = 0; i < graph_inputs.size(); ++i) {
auto item = graph_inputs[i];

@ -240,6 +240,7 @@ constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";
constexpr auto kAttrPsKey = "ps_key";
constexpr auto kAttrOptimizerType = "optim_type";
constexpr auto kAttrChildGraph = "child_graph";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";

Loading…
Cancel
Save