diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/cpu_kernel.h index 0836529840..5837f922b5 100644 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/kernel/cpu/cpu_kernel.h @@ -55,7 +55,7 @@ class CPUKernel : public kernel::KernelMod { public: CPUKernel() = default; ~CPUKernel() override = default; - void Init(const CNodePtr &kernel_node); + virtual void Init(const CNodePtr &kernel_node); virtual void InitKernel(const CNodePtr &kernel_node) = 0; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void * /*stream_ptr*/) override { diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h index 52eda12ba7..aebcc15d6a 100644 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h +++ b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h @@ -62,10 +62,12 @@ class CPUKernelRegistrar { static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ []() { return std::make_shared(); }); -#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \ +#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T) +#define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) +#define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \ static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ - static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_reg(#OPNAME, ATTR, \ - []() { return std::make_shared>(); }); + static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \ + #OPNAME, ATTR, []() { return std::make_shared>(); }); #define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ diff --git a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc index 16420b433a..26cc42685f 100644 --- a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc @@ -46,24 +46,10 @@ void SparseApplyFtrlPSKernel::InitKernel( if (grad_shape[0] != indices_size_) { MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; } - /* - lr_ = AnfAlgo::GetNodeAttr(kernel_node, "lr"); - if (lr_ <= 0) { - MS_LOG(EXCEPTION) << "lr should be a positive scalar"; - } - l1_ = AnfAlgo::GetNodeAttr(kernel_node, "l1"); - if (l1_ < 0) { - MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar"; - } - l2_ = AnfAlgo::GetNodeAttr(kernel_node, "l2"); - if (l2_ < 0) { - MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar"; - } - lr_power_ = AnfAlgo::GetNodeAttr(kernel_node, "lr_power"); - if (lr_power_ > 0) { - MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar"; - } - */ + lr_ = 0.01; + l1_ = 1e-8; + l2_ = 1e-8; + lr_power_ = -0.5; workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); } diff --git a/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.cc b/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.cc new file mode 100644 index 0000000000..fd342ec43c --- /dev/null +++ b/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/pass/replace_node_by_proxy.h" +#include +#include +#include "device/kernel_info.h" +#include "session/anf_runtime_algorithm.h" +#include "kernel/kernel_build_info.h" + +namespace mindspore { +namespace opt { +kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector inputs_device_format; + std::vector outputs_device_format; + std::vector inputs_device_type; + std::vector outputs_device_type; + std::vector> outputs_shape; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); + inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); + } + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); + outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); + } + builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); + builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); + builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); + + builder.SetInputsFormat(inputs_device_format); + builder.SetOutputsFormat(outputs_device_format); + builder.SetInputsDeviceType(inputs_device_type); + builder.SetOutputsDeviceType(outputs_device_type); + return builder.Build(); +} + +bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector node_list = TopoSort(func_graph->get_return()); + for (auto node : node_list) { + if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { + CNodePtr cnode = node->cast(); + auto prim = std::make_shared(kEmbeddingLookupProxyOpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector proxy_inputs = {NewValueNode(prim)}; + proxy_inputs.insert(proxy_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs); + MS_EXCEPTION_IF_NULL(proxy_node); + + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + proxy_node->set_kernel_info(kernel_info); + + AbstractBasePtrList abstract_list; + AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node); + AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node); + AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node); + abstract_list.push_back(cnode->abstract()); + auto abstract_tuple = std::make_shared(abstract_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + proxy_node->set_abstract(abstract_tuple); + + auto kernel_build_info = GenerateKernelBuildInfo(cnode); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get()); + + if (!manager->Replace(cnode, proxy_node)) { + MS_LOG(EXCEPTION) << "Replace node by proxy node failed."; + } + } + } + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.h b/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.h new file mode 100644 index 0000000000..2549501a0a --- /dev/null +++ b/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ +#include +#include +#include + +#include "pre_activate/common/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "utils/utils.h" +#include "kernel/kernel_build_info.h" + +namespace mindspore { +namespace opt { +class ReplaceNodeByProxy : public Pass { + public: + explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {} + ~ReplaceNodeByProxy() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py index 508aa2e7a9..5e1f7d06e7 100644 --- a/mindspore/communication/_comm_helper.py +++ b/mindspore/communication/_comm_helper.py @@ -14,7 +14,7 @@ # ============================================================================ """comm_helper""" - +import os from ._hccl_management import load_lib as hccl_load_lib _HCCL_AVAILABLE = False @@ -44,7 +44,7 @@ else: HCCL_WORLD_COMM_GROUP = "hccl_world_group" NCCL_WORLD_COMM_GROUP = "nccl_world_group" - +MS_ROLE = os.getenv("MS_ROLE") class Backend: """ @@ -152,6 +152,9 @@ def _get_rank_helper(group, backend): Integer. The local rank id of the calling process. """ rank_id = None + if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): + rank_id = 0 + return rank_id if backend == Backend.HCCL: if group == HCCL_WORLD_COMM_GROUP: rank_id = hccl.get_rank_id() @@ -211,6 +214,9 @@ def _get_size_helper(group, backend): Integer. The rank size of specified group. """ size = None + if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): + size = 1 + return size if backend == Backend.HCCL: if group == HCCL_WORLD_COMM_GROUP: size = hccl.get_rank_size() diff --git a/mindspore/communication/management.py b/mindspore/communication/management.py index 1cd60fe2e5..3fb4e7b947 100755 --- a/mindspore/communication/management.py +++ b/mindspore/communication/management.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """Communication management API""" +import os from mindspore.parallel._auto_parallel_context import auto_parallel_context from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ @@ -28,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP DEFAULT_BACKEND = Backend("hccl") +MS_ROLE = os.getenv("MS_ROLE") def _get_group(group): @@ -58,6 +60,8 @@ def init(backend_name="hccl"): TypeError: If backend name is not a string. RuntimeError: If backend is invalid or distributed init fails. """ + if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): + return if not isinstance(backend_name, str): raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))