[auto-monad] Support side effects by auto-monad

The basic idea is to exploit data dependencies to control the execution order
of side-effect operations, while keeping the semantics of ANF unchanged.

The ControlDepend primitive is removed and two primitives are added:

1. UpdateState:
```
  a = Assign(para, value)
```
becomes:
```
  a = Assign(para, value, u)
  u = UpdateState(u, a)
```

2. Load:
```
  x = Add(para, value)
```
becomes:
```
  p = Load(para, u)
  x = Add(p, value)
  u = UpdateState(u, p)
```
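
Taken together, the two primitives turn effect ordering into ordinary data flow: Assign and Load take the current monad `u`, and UpdateState produces the next `u` that later effects must wait for. A minimal Python sketch of the idea (illustrative names only, not the MindSpore implementation):

```
# Sketch: the monad is plain data, so data dependencies alone
# serialize writes and reads; no explicit control edges are needed.
class UMonad:
    """Stand-in for the universal monad value threaded between effects."""

def assign(para, value, u):
    para["value"] = value     # the side effect
    return value              # Assign returns the written value

def load(para, u):
    return para["value"]      # the read depends on u, i.e. on earlier writes

def update_state(u, effect):
    return UMonad()           # fresh state token depending on the effect

u = UMonad()
para = {"value": 0}
a = assign(para, 1, u)        # a = Assign(para, 1, u)
u = update_state(u, a)        # u = UpdateState(u, a)
p = load(para, u)             # p = Load(para, u): ordered after the Assign
x = p + 3                     # x = Add(p, value)
u = update_state(u, p)        # later writes must wait for this read
```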

@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.14.1)
+cmake_minimum_required(VERSION 3.14.0)
 project(MindSpore)
 if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0)
@@ -14,18 +14,25 @@ endif()
 if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
     set(CMAKE_OSX_SYSROOT "")
-    set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
+    set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings \
+        -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare \
+        -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move \
+        -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
 else()
-    set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
+    set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined \
+        -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
 endif()
 if(ENABLE_PYTHON)
     add_compile_definitions(ENABLE_PYTHON)
 endif()
-set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp")
+set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer \
+    -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 \
+    -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp")
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -Wno-deprecated-declarations -fPIC")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 \
+    -Werror -Wall -Wno-deprecated-declarations -fPIC")
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 set(PYBIND11_CPP_STANDARD -std=c++17)

@@ -132,6 +132,16 @@ def Depend(value, expr):
     return value
+
+
+def UpdateState(monad, expr):
+    """Implement `UpdateState`."""
+    return monad
+
+
+def Load(value, u=None):
+    """Implement `Load`."""
+    return value
+
+
 # only used in PyNative mode
 def make_ref(key, value, ref):
     return value
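
These eager-mode fallbacks are pure pass-throughs: the ordering they encode lives in the graph edges, not in the function bodies. A quick check of the stubbed behavior, assuming the two functions above are importable as defined:

```
u = object()                          # any monad stand-in
assert UpdateState(u, "effect") is u  # identity on the monad
assert Load(42) == 42                 # identity on the value
```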

@@ -42,14 +42,16 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<
   std::vector<std::string> inputs_format{};
   std::vector<TypeId> inputs_type{};
   if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) {
-    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    for (size_t input_index = 0; input_index < input_num; ++input_index) {
       inputs_format.emplace_back(kOpFormat_DEFAULT);
       inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
     }
   }
   std::vector<std::string> outputs_format;
   std::vector<TypeId> outputs_type;
-  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  for (size_t output_index = 0; output_index < output_num; ++output_index) {
     outputs_format.emplace_back(kOpFormat_DEFAULT);
     outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
   }
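
Many hunks in this commit repeat the mechanical cleanup seen here: the input/output count is hoisted out of the loop condition so the counting helper runs once rather than on every iteration. In miniature (an illustrative Python stand-in for AnfAlgo::GetOutputTensorNum, not the C++ helper):

```
def get_output_tensor_num(node):
    # Hypothetical stand-in for the C++ helper.
    return len(node["outputs"])

node = {"outputs": ["out0", "out1"]}
output_num = get_output_tensor_num(node)  # hoisted: evaluated once
for output_index in range(output_num):
    print(output_index)
```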

@@ -139,9 +139,9 @@ bool CheckCache(const std::string &kernel_name) {
   std::string kernel_json = bin_map->Search(kernel_name);
   bool ret = (!kernel_json.empty());
   if (ret) {
-    MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed.";
+    MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
   } else {
-    MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed.";
+    MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
   }
   return ret;
 }
@@ -730,30 +730,6 @@ bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann:
   return false;
 }
-
-void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(node_list);
-  auto output = func_graph->output();
-  MS_EXCEPTION_IF_NULL(output);
-  if (AnfAlgo::IsRealKernel(output)) {
-    // single output.
-    node_list->push_back(std::make_pair(output, 0));
-    return;
-  } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
-    auto output_cnode = output->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(output_cnode);
-    // multi output.
-    auto &inputs = output_cnode->inputs();
-    for (size_t i = 1; i < inputs.size(); ++i) {
-      auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
-      node_list->push_back(in_with_idx);
-    }
-    return;
-  }
-  MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
-                              << " of graph: " << func_graph->ToString();
-}
-
 bool IsWeightBoundary(const AnfNodePtr &node) {
   if (node->isa<ValueNode>()) {
     return true;
@@ -776,7 +752,7 @@ std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto axis_attr = primitive->GetAttr(kAxis);
   if (axis_attr == nullptr) {
-    MS_LOG(ERROR) << "This node does't have axie attr.";
+    MS_LOG(ERROR) << "This node doesn't have axie attr.";
     return std::vector<int64_t>();
   }
   std::vector<int64_t> axis_list;

@@ -181,7 +181,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
   std::vector<size_t> out_shape;
   out_shape.emplace_back(miss_count);
   std::vector<TypeId> dtypes;
-  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
+  for (size_t i = 0; i < output_num; i++) {
     dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
   }
   AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},

@@ -69,7 +69,8 @@ void SubAndFilterCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
   std::vector<size_t> out_shape;
   out_shape.emplace_back(count);
   std::vector<TypeId> dtypes;
-  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
+  for (size_t i = 0; i < output_num; i++) {
     dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
   }
   AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape, out_shape}, node_.get());

@@ -29,5 +29,8 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   AssignGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  AssignGpuKernel, int64_t)
 }  // namespace kernel
 }  // namespace mindspore

@@ -63,13 +63,15 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
   for (const auto &type : kHcclSupportTypes) {
     std::vector<std::string> inputs_format{};
     std::vector<TypeId> inputs_type{};
-    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    for (size_t input_index = 0; input_index < input_num; ++input_index) {
       inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
       inputs_type.push_back(type);
     }
     std::vector<std::string> outputs_format;
     std::vector<TypeId> outputs_type;
-    for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    for (size_t output_index = 0; output_index < output_num; ++output_index) {
       if (op_name == kReduceScatter && AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrFusion) > 0) {
         outputs_format.emplace_back(GetKernelFormat(kernel_node, 0));
       } else {

@@ -31,7 +31,8 @@ bool IsPyNativeMode() {
 bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_intput_shape_list) {
   MS_EXCEPTION_IF_NULL(anf_node);
   MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list);
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
+  for (size_t i = 0; i < input_num; ++i) {
     std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
     hccl_kernel_intput_shape_list->emplace_back(shape_i);
   }
@@ -42,7 +43,8 @@ bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<siz
 bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_output_shape_list) {
   MS_EXCEPTION_IF_NULL(anf_node);
   MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list);
-  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node); ++i) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
+  for (size_t i = 0; i < output_num; ++i) {
     std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
     hccl_kernel_output_shape_list->emplace_back(shape_i);
   }
@@ -53,11 +55,12 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<si
 bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) {
   MS_EXCEPTION_IF_NULL(anf_node);
   MS_EXCEPTION_IF_NULL(data_type_list);
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
+  for (size_t i = 0; i < input_num; ++i) {
     auto type_ptr = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, i);
     auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type_ptr);
     if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
-      MS_LOG(EXCEPTION) << "HcomDataType cann't support Current Ascend Data Type : " << type_ptr;
+      MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_ptr;
     }
     data_type_list->emplace_back(iter->second);
   }

@@ -37,13 +37,15 @@ void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
   std::vector<std::string> inputs_format{};
   std::vector<TypeId> inputs_type{};
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t input_index = 0; input_index < input_num; ++input_index) {
     inputs_format.emplace_back(kOpFormat_DEFAULT);
     inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
   }
   std::vector<std::string> outputs_format;
   std::vector<TypeId> outputs_type;
-  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  for (size_t output_index = 0; output_index < output_num; ++output_index) {
     outputs_format.emplace_back(kOpFormat_DEFAULT);
     outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
   }

@@ -30,7 +30,7 @@ std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
 std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
   if (output_index >= outputs_format_.size()) {
-    MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
     return kInvalidFormat;
   }
   return outputs_format_[output_index];

@@ -86,6 +86,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernel
     builder.SetProcessor(AICORE);
     builder.SetKernelType(RT_KERNEL);
     builder.SetFusionType(OPAQUE);
+    // LabelSwitch always return UMonad.
+    builder.SetOutputsFormat({kOpFormat_DEFAULT});
+    builder.SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
     label_switch_build_info.emplace_back(builder.Build());
   }
   return label_switch_build_info;

@@ -74,11 +74,10 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
     input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
   }
   kernel_build_info_builder->SetInputsDeviceType(input_types);
-  // set output info
-  auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
-  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
-  // set ohter info
+  // Kernel ops in while-list such as 'LabelSet' always return UMonad.
+  kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
+  kernel_build_info_builder->SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
+  // set other info
   kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
   kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
   kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);

@@ -1052,10 +1052,16 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
   auto node_name = AnfAlgo::GetCNodeName(cnode);
   auto op_info = tbe::TbeDynamicShapeUtil::FindOp(node_name, cnode);
   MS_EXCEPTION_IF_NULL(cnode);
-  if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) {
+  auto node_inputs_size = cnode->inputs().size();
+  for (auto &input : cnode->inputs()) {
+    if (HasAbstractMonad(input)) {
+      node_inputs_size--;
+    }
+  }
+  if (op_info->inputs_ptr().size() < (node_inputs_size - 1)) {
     MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope();
   }
-  return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size());
+  return (op_info->inputs_ptr().size() + 1 - node_inputs_size);
 }

 std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
@@ -1103,6 +1109,9 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
   bool is_dynamic_input = IsDynamicInput(cnode);
   for (size_t i = 1; i < cnode->inputs().size(); ++i) {
     auto input = cnode->input(i);
+    if (HasAbstractMonad(input)) {
+      continue;
+    }
     auto kernel_idx = AnfAlgo::VisitKernel(input, 0);
     auto real_node = kernel_idx.first;
     size_t real_idx = kernel_idx.second;
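
Both hunks above apply the same rule: monad inputs exist only to order side effects and are never real kernel operands, so they are excluded when checking the op-info input count and when emitting the fusion JSON. A minimal sketch of the filtering (hypothetical `has_abstract_monad`, standing in for HasAbstractMonad):

```
def real_inputs(inputs, has_abstract_monad):
    # Drop monad operands; the kernel only sees tensor inputs.
    return [inp for inp in inputs if not has_abstract_monad(inp)]

inputs = ["x", "y", "umonad"]
print(real_inputs(inputs, lambda i: i == "umonad"))  # ['x', 'y']
```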

@@ -112,6 +112,10 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
                                 const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(node);
   auto input_node = AnfAlgo::GetInputNode(node, index);
+  if (HasAbstractMonad(input_node)) {
+    // No transfer for monad inputs.
+    return input_node;
+  }
   auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
   MS_EXCEPTION_IF_NULL(node_with_index.first);
   auto real_input = node_with_index.first;
@@ -330,8 +334,9 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
-  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
+  size_t in_num = AnfAlgo::GetInputNum(cnode);  // include monads.
   for (size_t input_index = 0; input_index < in_num; ++input_index) {
+    // Monad inputs keep unchanged from GetTransInputNodePtr().
     AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
     MS_EXCEPTION_IF_NULL(input_node);
     new_inputs.push_back(input_node);
@@ -352,12 +357,18 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt
 CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
-  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
+  size_t in_num = AnfAlgo::GetInputNum(cnode);  // include monads.
   for (size_t input_index = 0; input_index < in_num; ++input_index) {
+    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
+    if (HasAbstractMonad(cur_input)) {
+      // No cast for monad inputs.
+      new_inputs.push_back(cur_input);
+      continue;
+    }
     auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index);
     const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
     TypeId origin_type(kTypeUnknown);
-    auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
     auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
     auto real_input_node = kernel_with_index.first;
     if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {

@@ -244,6 +244,7 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
         if (auto in = cnode->input(idx); std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(),
                                                    (*buffer_fusion_infos)[fusion_id].inputs_list.end(),
                                                    in) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) {
+          if (!HasAbstractMonad(in)) {
           (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in);
           }
         }
@@ -251,6 +252,7 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph,
       }
     }
   }
+}

 bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
   MS_EXCEPTION_IF_NULL(node1);

@@ -40,62 +40,6 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) {
   return real_node->isa<ValueNode>();
 }
-
-void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
-              const std::vector<AnfNodePtr> &memcpy_async_list) {
-  MS_EXCEPTION_IF_NULL(control_depend);
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(hccl_node);
-  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
-  make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
-  make_tuple_inputs.emplace_back(hccl_node);
-  auto make_tuple = graph->NewCNode(make_tuple_inputs);
-  MS_EXCEPTION_IF_NULL(make_tuple);
-  control_depend->set_input(IntToSize(index), make_tuple);
-}
-
-void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
-                           const std::vector<AnfNodePtr> &memcpy_async_list) {
-  MS_EXCEPTION_IF_NULL(tuple_getitem);
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto &node_users = manager->node_users();
-  auto iter = node_users.find(tuple_getitem);
-  if (iter == node_users.end()) {
-    MS_LOG(EXCEPTION) << "node has no output in manager"
-                      << " trace: " << trace::DumpSourceLines(hccl_node);
-  }
-  for (const auto &node_index : iter->second) {
-    AnfNodePtr output = node_index.first;
-    MS_EXCEPTION_IF_NULL(output);
-    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
-      SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
-    }
-  }
-}
-
-void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
-                     const FuncGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(hccl_node);
-  MS_EXCEPTION_IF_NULL(graph);
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto &node_users = manager->node_users();
-  auto iter = node_users.find(hccl_node);
-  if (iter == node_users.end()) {
-    MS_LOG(EXCEPTION) << "node has no output in manager"
-                      << " trace: " << trace::DumpSourceLines(hccl_node);
-  }
-  // find hccl_node's output which is a control depend
-  for (const auto &node_index : iter->second) {
-    AnfNodePtr output = node_index.first;
-    MS_EXCEPTION_IF_NULL(output);
-    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
-      SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
-    } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) {
-      DealControlForGetitem(output->cast<CNodePtr>(), graph, hccl_node, memcpy_async_list);
-    }
-  }
-}
-
 // NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
 bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
   if (node_users.size() == 1) {
@@ -155,7 +99,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
 void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(hccl_node);
-  std::vector<AnfNodePtr> memcpy_async_list;
+  bool need_memcpy_async = false;
   std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
   for (size_t i = 1; i < hccl_node->size(); ++i) {
     auto input = hccl_node->input(i);
@@ -164,17 +108,17 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
       if (memcpy_async == nullptr) {
         MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
       }
-      if (AnfAlgo::IsNodeDynamicShape(input)) {
+      if (input->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(input)) {
         AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
       }
       new_inputs.push_back(memcpy_async);
-      memcpy_async_list.push_back(memcpy_async);
+      need_memcpy_async = true;
     } else {
       new_inputs.push_back(input);
     }
   }
-  if (!memcpy_async_list.empty()) {
+  if (need_memcpy_async) {
     CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
     new_hccl_node->set_inputs(new_inputs);
     auto manager = graph->manager();
@@ -182,9 +126,6 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
     MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node";
     (void)manager->Replace(hccl_node, new_hccl_node);
     MS_LOG(DEBUG) << "end replace";
-    // transer hccl op's control to the memcpy_async
-    TransferControl(new_hccl_node, memcpy_async_list, graph);
   }
 }

@@ -57,7 +57,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
     return nullptr;
   }
   std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
-  for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) {
+  for (size_t input_idx = 0; input_idx < input_num; input_idx++) {
     auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx);
     auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx);
     auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx);

@@ -40,7 +40,8 @@ const AnfNodePtr ConvertCastFormat::Process(const FuncGraphPtr &func_graph, cons
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
+  for (size_t input_index = 0; input_index < input_num; ++input_index) {
     auto input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, input_index), 0).first;
     MS_EXCEPTION_IF_NULL(input_node);
     if (!input_node->isa<CNode>()) {
@@ -77,7 +78,8 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr
   MS_EXCEPTION_IF_NULL(node_info.first);
   auto cast_out_node = node_info.first->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cast_out_node);
-  for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cast_out_node); ++index) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node);
+  for (size_t index = 0; index < input_num; ++index) {
     if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first !=
         cast_node) {
       continue;

@@ -162,7 +162,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_
   std::vector<AnfNodePtr> make_tuple_inputs;
   AbstractBasePtrList abstract_list;
   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
-  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
+  for (size_t output_index = 0; output_index < output_num; ++output_index) {
     CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index);
     // deal with ref output
     if (ref_infos.count(output_index) != 0) {

@@ -37,7 +37,7 @@ const AnfNodePtr DynamicRNNGradReformat::Process(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(node);
   auto split_v = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(split_v);
-  auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), 3);
+  auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), kMatMulInputTensorNum);
   MS_EXCEPTION_IF_NULL(matmul);
   auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(matmul, 0);
   auto input_node = input_node_with_idx.first;

@@ -129,9 +129,21 @@ AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr
   auto mng = sub_graph->manager();
   MS_EXCEPTION_IF_NULL(mng);
   std::vector<AnfNodePtr> todo;
-  std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
   kernel::GetValidKernelNodes(sub_graph, &todo);
-  kernel::GetGraphRealOutput(sub_graph, &graph_rets);
+  auto outputs = AnfAlgo::GetAllOutput(sub_graph->output(), {prim::kPrimTupleGetItem});
+  std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
+  for (auto &output : outputs) {
+    size_t index = 0;
+    if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
+      ValuePtr tuple_index_value = GetValueNode(output->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
+      MS_EXCEPTION_IF_NULL(tuple_index_value);
+      if (!tuple_index_value->isa<Int64Imm>()) {
+        MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
+      }
+      index = tuple_index_value->cast<Int64ImmPtr>()->value();
+    }
+    graph_rets.emplace_back(std::pair<AnfNodePtr, size_t>(output, index));
+  }
   for (auto &t : todo) {
     AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
     // process input
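
This inlines what the removed GetGraphRealOutput helper did: each graph output is paired with an output index, 0 for a direct output and the constant operand of a tuple_getitem otherwise. A toy model of the extraction (plain Python tuples, not ANF nodes):

```
def graph_real_outputs(outputs):
    # Each entry is either a plain node or ("tuple_getitem", node, index).
    rets = []
    for out in outputs:
        index = 0
        if isinstance(out, tuple) and out[0] == "tuple_getitem":
            index = out[2]
            if not isinstance(index, int):
                raise TypeError("The index of tuple_getitem is not an integer")
        rets.append((out, index))
    return rets

print(graph_real_outputs([("tuple_getitem", "matmul", 1), "relu"]))
# -> [(('tuple_getitem', 'matmul', 1), 1), ('relu', 0)]
```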

@@ -33,7 +33,8 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) {
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
+    for (size_t index = 0; index < input_num; ++index) {
       auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
       auto prev_node_out_infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
       auto input_format = AnfAlgo::GetInputFormat(cnode, index);

@@ -28,8 +28,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-const size_t kCastInputNum = 2;
-const size_t kTupleGetitemInputNum = 3;
 bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, const size_t change_idx,
                                    const std::shared_ptr<kernel::KernelBuildInfo> &candidate_kernel_info) {
   if (node == nullptr || node->kernel_info() == nullptr || candidate_kernel_info == nullptr) {
@@ -126,7 +124,8 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
   auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0);
   std::vector<Shape> shapes;
   std::vector<TypeId> types;
-  for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
+  for (size_t index = 0; index < output_num; ++index) {
     if (cast_index == index) {
       shapes.emplace_back(cast_shape);
       types.emplace_back(cast_dtype);
@@ -175,7 +174,7 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
                 << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
                 << (*alternative_kernel_info)->ToString();
   AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
-  if (node->inputs().size() < kCastInputNum) {
+  if (AnfAlgo::GetInputTensorNum(node) < kCastInputTensorNum) {
     MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:";
   }
   return node->input(1);
@@ -188,9 +187,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
   *prior_op = x_cnode;
   // when x_node is tuple_getitem
   if (AnfAlgo::GetCNodeName(x_node) == prim::kPrimTupleGetItem->name()) {
-    if (x_cnode->inputs().size() < kTupleGetitemInputNum) {
-      MS_LOG(EXCEPTION) << "tuple getitem node has wrong input num" << x_cnode->inputs().size();
-    }
+    CheckCNodeInputSize(x_cnode, kTupleGetItemInputTensorNum);
     MS_EXCEPTION_IF_NULL(output_idx);
     AnfNodePtr input1 = x_cnode->input(1);
     MS_EXCEPTION_IF_NULL(input1);
@@ -214,9 +211,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
 AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr kernel_query) {
   MS_EXCEPTION_IF_NULL(cur_node);
   MS_EXCEPTION_IF_NULL(kernel_query);
-  if (cur_node->inputs().size() < kCastInputNum) {
-    MS_LOG(EXCEPTION) << "op[Cast] has wrong input num:";
-  }
+  CheckCNodeInputSize(cur_node, kCastInputTensorNum);
   AnfNodePtr x_node = cur_node->input(1);
   if (IsUsedByOthers(graph, x_node)) {
     return nullptr;

@@ -69,7 +69,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  CheckCNodeInputSize(cnode, kTransOpInputNum);
+  CheckCNodeInputSize(cnode, kTransOpInputTensorNum);
   auto input_node = cnode->input(1);
   if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) {
     kernel_graph->ReplaceInternalOutput(node, input_node);
