diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc index 22574b98c4..22c05f03cb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc @@ -34,8 +34,8 @@ void Cast(const S *in, T *out, size_t size) { template void CastCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); - source_dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, 0); - target_dtype = AnfAlgo::GetOutputInferDataType(kernel_node, 0); + source_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + target_dtype = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0); } template @@ -45,7 +45,6 @@ bool CastCPUKernel::Launch(const std::vector &inputs, S *input = reinterpret_cast(inputs[0]->addr); T *output = reinterpret_cast(outputs[0]->addr); MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name(); - size_t lens = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(T)) : 1; Cast(input, output, lens); return true; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/maximum_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/maximum_grad_cpu_kernel.cc index 1dd67b378e..287c68d1ab 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/maximum_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/maximum_grad_cpu_kernel.cc @@ -27,7 +27,7 @@ void MaximumGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); dx_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); dy_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1); - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); if (!x_shape_.size() || !y_shape_.size() || !dout_shape.size()) { MS_LOG(EXCEPTION) << "Input NULL"; } diff --git a/mindspore/ccsrc/backend/optimizer/CMakeLists.txt b/mindspore/ccsrc/backend/optimizer/CMakeLists.txt index 194ae07ee0..e1cfce0605 100644 --- a/mindspore/ccsrc/backend/optimizer/CMakeLists.txt +++ b/mindspore/ccsrc/backend/optimizer/CMakeLists.txt @@ -36,6 +36,11 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") -Wno-overloaded-virtual -Wno-unused-const-variable -Wno-pessimizing-move") endif() +if(ENABLE_CPU) + file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc") + list(APPEND _PREACTIVATE_SRC_LIST ${_CPU_SRC_LIST}) +endif() + set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT) add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST}) diff --git a/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc b/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc new file mode 100644 index 0000000000..ff7b42e5c4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc @@ -0,0 +1,174 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/cpu/insert_cast_cpu.h" + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "utils/utils.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, + const TypeId &input_type, const TypeId &output_type, + const std::vector &origin_shape, const TypeId &origin_type) { + MS_EXCEPTION_IF_NULL(func_graph); + std::string input_format = format; + std::string output_format = format; + CNodePtr cast = func_graph->NewCNode({NewValueNode(std::make_shared(prim::kPrimCast->name())), input}); + MS_EXCEPTION_IF_NULL(cast); + // set kernel build info + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetInputsFormat({input_format}); + builder.SetOutputsFormat({output_format}); + builder.SetInputsDeviceType({input_type}); + builder.SetOutputsDeviceType({output_type}); + + // if kernel info is null , it remarks this function is running ut + if (cast->kernel_info() == nullptr) { + auto kernel_info = std::make_shared(); + cast->set_kernel_info(kernel_info); + } + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); + AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); + return cast; +} + +AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::vector &need_insert_cast) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cnode); + auto kernel_graph = func_graph->cast(); + size_t out_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t output_idx = 0; output_idx < out_num; ++output_idx) { + AnfNodePtr replace_node = nullptr; + const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); + const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); + auto idx = NewValueNode(SizeToLong(output_idx)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(output_idx); + idx->set_abstract(std::make_shared(imm)); + auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); + AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get()); + if (need_insert_cast[output_idx]) { + const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); + const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); + if (infer_type != device_type) { + replace_node = + AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, infer_type, origin_shape, infer_type); + MS_EXCEPTION_IF_NULL(replace_node); + replace_node->set_scope(cnode->scope()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) { + kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0); + } + } + } + } + return cnode; +} + +void InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + size_t in_num = AnfAlgo::GetInputTensorNum(cnode); + auto kernel_graph = func_graph->cast(); + auto mng = kernel_graph->manager(); + for (size_t input_index = 0; input_index < in_num; ++input_index) { + auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index); + const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second); + auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); + + const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); + const std::vector origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second); + + if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); infer_type != device_type) { + auto cast = + AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, infer_type, device_type, origin_shape, device_type); + MS_EXCEPTION_IF_NULL(cast); + cast->set_scope(cnode->scope()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); + mng->Replace(cur_input, cast); + } + } +} + +AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::vector &need_insert_cast) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { + return cnode; + } + MS_EXCEPTION_IF_NULL(cnode->Type()); + auto kernel_graph = func_graph->cast(); + // Single output + if (!cnode->Type()->isa()) { + if (!need_insert_cast[0]) { + return cnode; + } + const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0); + const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0); + + const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); + AnfNodePtr replace_node = cnode; + if (infer_type != device_type) { + replace_node = + AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, infer_type, origin_shape, infer_type); + MS_EXCEPTION_IF_NULL(replace_node); + replace_node->set_scope(cnode->scope()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, 0)) { + kernel_graph->ReplaceInternalOutput(cnode, replace_node); + } + } + return replace_node; + } + // Multiple output + return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast); +} +} // namespace + +const BaseRef InsertCastCPU::DefinePattern() const { + VarPtr V = std::make_shared(UnVisited); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +const AnfNodePtr InsertCastCPU::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { + return nullptr; + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + // process input + CNodePtr cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + InsertCastForInput(func_graph, cnode); + // process output + return InsertCastForOutput(func_graph, cnode, std::vector(AnfAlgo::GetOutputTensorNum(cnode), true)); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.h b/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.h new file mode 100644 index 0000000000..fcecd63393 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H + +#include +#include "backend/optimizer/common/optimizer.h" +#include "ir/anf.h" + +namespace mindspore { +namespace opt { +class InsertCastCPU : public PatternProcessPass { + public: + explicit InsertCastCPU(bool multigraph = true) : PatternProcessPass("insert_cast_cpu", multigraph) {} + ~InsertCastCPU() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_CAST_CPU_H diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc index d7a44dbdaf..fb32f8030b 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.cc +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -27,7 +27,9 @@ #include "runtime/device/cpu/kernel_select_cpu.h" #include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/cpu/insert_cast_cpu.h" #include "backend/optimizer/pass/replace_node_by_proxy.h" +#include "backend/optimizer/pass/erase_visit_attr.h" #include "debug/anf_ir_dump.h" #include "debug/dump_proto.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) @@ -61,9 +63,21 @@ void CPUSession::Reorder(std::vector *node_list) { AnfAlgo::ReorderPos void CPUSession::Optimize(const std::shared_ptr &kernel_graph) { auto optimizer = std::make_shared(); auto pm = std::make_shared(); - std::string pass_name = "replace_node_by_proxy"; - pass_name.append(std::to_string(graph_sum_)); - pm->AddPass(std::make_shared(pass_name)); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) { + AssignParamKey(kernel_graph); + if (ps::PSContext::instance()->is_worker()) { + std::string pass_name = "replace_node_by_proxy"; + pass_name.append(std::to_string(graph_sum_)); + pm->AddPass(std::make_shared(pass_name)); + } + } +#endif + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + MS_LOG(INFO) << "insert cast pass"; optimizer->AddPassManager(pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); @@ -77,14 +91,8 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr graph->UpdateGraphDynamicAttr(); MS_LOG(INFO) << "Set kernel info"; SetKernelInfo(graph.get()); -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (ps::PSContext::instance()->is_ps_mode()) { - AssignParamKey(graph); - if (ps::PSContext::instance()->is_worker()) { - Optimize(graph); - } - } -#endif + MS_LOG(INFO) << "Set kernel info end"; + Optimize(graph); MS_LOG(INFO) << "Build kernel"; BuildKernel(graph.get()); @@ -168,6 +176,7 @@ void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); MS_EXCEPTION_IF_NULL(kernel_graph); SetKernelInfo(kernel_graph.get()); + Optimize(kernel_graph); BuildKernel(kernel_graph.get()); run_op_graphs_[graph_info] = kernel_graph; } diff --git a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc index 88bf20e2d0..8eec841d39 100644 --- a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc +++ b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc @@ -35,21 +35,6 @@ bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) { return false; } -void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector &input_not_cnode_indexes, - const CNodePtr kernel_node) { - for (auto &input_index : input_not_cnode_indexes) { - auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; - MS_EXCEPTION_IF_NULL(input_node); - std::vector output_types; - output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first); - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - builder->SetOutputsFormat({kOpFormat_DEFAULT}); - builder->SetOutputsDeviceType(output_types); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get()); - } -} - void GetOutputInferFormatsAndDtypes(const CNodePtr &kernel_node, std::vector *output_formats, std::vector *output_types) { size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); @@ -142,35 +127,11 @@ std::pair GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, int format_matched_num = 0; auto input_num = input_types.size(); for (size_t i = 0; i < input_num; ++i) { - bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), - [i](size_t index) { return index == i; }); - bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); - if (have_cnode_input && is_not_cnode_idx) { - data_type_matched_num++; - format_matched_num++; - continue; - } - if (is_not_cnode_idx) { - if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) { - MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first - << ", actual input dtype:" << input_types[i]; - } else { - data_type_matched_num++; - } - format_matched_num++; - continue; - } - if (kernel_attr.GetInputAttr(i).first != input_types[i]) { + if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) { MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first << ", actual input dtype:" << input_types[i]; } else { data_type_matched_num++; - } - - if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { - MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second - << ", actual input format:" << input_formats[i]; - } else { format_matched_num++; } } @@ -320,9 +281,8 @@ void SetKernelInfo(const CNodePtr &kernel_node) { (matched.first || input_types.size() == input_not_cnode_indexes.size())) { MS_LOG(INFO) << "Input format and dtype is matched"; GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types); - UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node); - for (auto &input_index : input_not_cnode_indexes) { - input_types[input_index] = selected_kernel_attr.GetInputAttr(input_index).first; + for (size_t i = 0; i < selected_kernel_attr.GetInputSize(); ++i) { + input_types[SizeToInt(i)] = selected_kernel_attr.GetInputAttr(i).first; } } SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());