diff --git a/build.sh b/build.sh
index 0c4d1ff90c..dfed66aadf 100755
--- a/build.sh
+++ b/build.sh
@@ -49,7 +49,7 @@ usage()
   echo "    -Q Enable dump memory, default off"
   echo "    -D Enable dumping of function graph ir, default on"
   echo "    -z Compile dataset & mindrecord, default on"
-  echo "    -M Enable MPI and NCCL for GPU training, default on"
+  echo "    -M Enable MPI and NCCL for GPU training, gpu default on"
   echo "    -V Specify the minimum required cuda version, default CUDA 9.2"
   echo "    -I Compile predict, default off"
   echo "    -K Compile with AKG, default off"
diff --git a/mindspore/ccsrc/device/CMakeLists.txt b/mindspore/ccsrc/device/CMakeLists.txt
index 9f64d0e02b..7178a01ce6 100644
--- a/mindspore/ccsrc/device/CMakeLists.txt
+++ b/mindspore/ccsrc/device/CMakeLists.txt
@@ -14,17 +14,19 @@ endif ()
 
 if (ENABLE_CPU)
     file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc")
-    if (ENABLE_MPI)
-        # _ms_mpi
-        set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
-            PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
-        pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
-        target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
-    else ()
+    if (NOT ENABLE_MPI)
         list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc")
     endif ()
 endif ()
 
+if (ENABLE_MPI)
+    # _ms_mpi
+    set_property(SOURCE "gpu/mpi/mpi_initializer.cc"
+        PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
+    pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc")
+    target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
+endif ()
+
 # gpu
 if (ENABLE_GPU)
     file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc" "gpu/*.cu")
diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
index 9ceb8676b2..b743ecaf4f 100644
--- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
@@ -25,6 +25,7 @@
 #include "device/ascend/ascend_device_address.h"
 #include "device/cpu/mpi/mpi_adapter.h"
 #include "utils/context/ms_context.h"
+#include "utils/mpi/mpi_config.h"
 #include "device/ascend/profiling/profiling_manager.h"
 #include "hccl/hcom.h"
 #include "common/trans.h"
@@ -510,19 +511,35 @@ bool AscendKernelRuntime::HcclInit() {
     MS_LOG(ERROR) << "file path " << config_path_str << " does not exist";
     return false;
   }
+  const char *identify = nullptr;
 #ifdef ENABLE_MPI
-  int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
-  const char *offset = std::getenv("RANK_OFFSET");
-  if (offset != nullptr) {
-    int rank_offset = std::stoi(offset);
-    rank_id += rank_offset;
+  std::string rank_id_tmp;
+  auto mpi_config_ptr = MpiConfig::GetInstance();
+  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
+  if (mpi_config_ptr->enable_mpi()) {
+    int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
+    const char *offset = std::getenv("RANK_OFFSET");
+    if (offset != nullptr) {
+      try {
+        int rank_offset = std::stoi(offset);
+        rank_id += rank_offset;
+      } catch (std::invalid_argument) {
+        MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset;
+      } catch (std::out_of_range) {
+        MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset;
+      }
+    }
+    rank_id_tmp = std::to_string(rank_id);
+    identify = rank_id_tmp.c_str();
+  } else {
+    identify = std::getenv("RANK_ID");
   }
-  const char *identify = reinterpret_cast<const char *>(std::to_string(rank_id).c_str());
 #else
-  const char *identify = std::getenv("RANK_ID");
+  identify = std::getenv("RANK_ID");
 #endif
   if (identify == nullptr) {
     MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
+    free(full_path);
     return false;
   }
   MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << identify;
diff --git a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc b/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc
index 283885e9ba..1160ba57b7 100644
--- a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc
+++ b/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc
@@ -16,6 +16,7 @@
 #include "device/cpu/mpi/mpi_adapter.h"
 #include
+#include "utils/mpi/mpi_config.h"
 #include "utils/log_adapter.h"
 
 namespace mindspore {
@@ -35,6 +36,20 @@ MPI_Op GetMpiOp(const std::string &op_type) {
   MS_LOG(EXCEPTION) << "unsupport op_type:" << op_type;
   return MPI_SUM;
 }
+
+int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
+  int scatter_index = -1;
+  for (size_t i = 0; i < ranks_group.size(); ++i) {
+    if (ranks_group[i] == rankid) {
+      scatter_index = static_cast<int>(i);
+      break;
+    }
+  }
+  if (scatter_index == -1) {
+    MS_LOG(EXCEPTION) << "process rankid " << rankid << " is not in the input rank group!";
+  }
+  return scatter_index;
+}
 }  // namespace
 
 MPIAdapter::MPIAdapter() : rank_id_(0), rank_size_(0), comm_group_world_(MPI_GROUP_NULL) { Init(); }
@@ -65,6 +80,11 @@ void MPIAdapter::Init() {
   if (init) {
     return;
   }
+  auto mpi_config_ptr = MpiConfig::GetInstance();
+  MS_EXCEPTION_IF_NULL(mpi_config_ptr);
+  if (!mpi_config_ptr->enable_mpi()) {
+    MS_LOG(EXCEPTION) << "MPI is disabled now! Please enable mpi with mpi config first.";
+  }
   int init_flag = 0;
   if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
     MS_LOG(EXCEPTION) << "Check mpi initialized fail!";
@@ -123,7 +143,7 @@ MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
   return group;
 }
 
-bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
+bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                                const std::string &op_type) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";
@@ -159,6 +179,51 @@ bool MPIAdapter::ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group,
+bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
+                                             const std::string &op_type, float *output) {
+  int scatter_index = GetScatterIndex(rank_id_, ranks_group);
+  auto group = AddGroup(ranks_group);
+  if (group == MPI_GROUP_NULL) {
+    MS_LOG(EXCEPTION) << "Get mpi group fail!rankid:" << rank_id_;
+  }
+  MPI_Comm comm;
+  MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm);
+  if (comm == MPI_COMM_NULL) {
+    MS_LOG(EXCEPTION) << "create mpi comm fail!rankid:" << rank_id_;
+  }
+
+  MPI_Win window;
+  auto ret = MPI_Win_create(input, data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window);
+  if (ret != MPI_SUCCESS) {
+    MS_LOG(ERROR) << "mpi window create fail! ret = " << ret;
+    return false;
+  }
+  MPI_Win_fence(0, window);
+  for (size_t i = 0; i < ranks_group.size(); ++i) {
+    int remote_rank = ranks_group[i];
+    if (rank_id_ == remote_rank) {
+      continue;
+    }
+    auto op = GetMpiOp(op_type);
+    ret = MPI_Accumulate(input + i * data_num, data_num, MPI_FLOAT, remote_rank, i * data_num, data_num, MPI_FLOAT, op,
+                         window);
+    if (ret != MPI_SUCCESS) {
+      MS_LOG(EXCEPTION) << "mpi accumulate " << op_type << " fail!ret = " << ret;
+    }
+  }
+  MPI_Win_fence(0, window);
+  if (output != nullptr) {
+    auto data_size = data_num * sizeof(float);
+    auto copy_ret = memcpy_s(output, data_size, input + scatter_index * data_num, data_size);
+    if (copy_ret != 0) {
+      MS_LOG(EXCEPTION) << "copy output memory fail!";
+    }
+  }
+  MPI_Win_free(&window);
+  MPI_Comm_free(&comm);
+  return true;
+}
+
 bool MPIAdapter::AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num) {
   if (ranks_group.empty()) {
     MS_LOG(ERROR) << "input rank group is empty!";
diff --git a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h b/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h
index c2ed3192a9..398ea1a364 100644
--- a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h
+++ b/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h
@@ -32,8 +32,10 @@ class MPIAdapter {
   ~MPIAdapter();
   static MPIAdapter &Instance();
   int GetRankId() const;
-  bool ReduceScatter(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
+  bool ReduceScatter(const float *input, float *output, const std::vector<int> &ranks_group, size_t data_num,
                      const std::string &op_type = kOpTypeSum);
+  bool ReduceScatterOverwriteInput(float *input, const std::vector<int> &ranks_group, size_t data_num,
+                                   const std::string &op_type = kOpTypeSum, float *output = nullptr);
   bool AllGather(float *input, float *output, const std::vector<int> &ranks_group, size_t data_num);
 
  private:
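Note (not part of the patch): `ReduceScatterOverwriteInput` builds the reduce-scatter by hand with a one-sided MPI window (`MPI_Win_create`, `MPI_Accumulate` between two fences), so that after the exchange each rank holds the reduced chunk at its own index in `input`. The sketch below shows the same collective semantics with mpi4py's built-in `Reduce_scatter_block` instead of the one-sided window; it is a standalone illustration, not MindSpore code, and assumes mpi4py is available.

```python
# Sketch only: the semantics of ReduceScatterOverwriteInput, via a built-in collective.
# Launch with e.g. `mpirun -np 3 python this_sketch.py`.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()                                     # 3 with the launch above

chunk = 4                                                  # elements each rank keeps
send = np.arange(size * chunk, dtype=np.float32) * 0.1     # same input on every rank here
recv = np.empty(chunk, dtype=np.float32)

# Every rank contributes its whole buffer; rank i receives the element-wise
# sum of chunk i across all ranks (what the MPI_Accumulate loop achieves).
comm.Reduce_scatter_block(send, recv, op=MPI.SUM)
print(rank, recv)                                          # rank i gets np.arange(4*i, 4*i + 4) * 0.3
```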
diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc
index 0992b7fa66..998c530cf8 100644
--- a/mindspore/ccsrc/pipeline/init.cc
+++ b/mindspore/ccsrc/pipeline/init.cc
@@ -26,6 +26,7 @@
 #include "pipeline/parse/python_adapter.h"
 #include "utils/summary/event_writer.h"
 #include "utils/config_manager.h"
+#include "utils/mpi/mpi_config.h"
 #include "parallel/context.h"
 #include "parallel/device_manager.h"
 #include "parallel/costmodel_context.h"
@@ -147,6 +148,11 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.")
     .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.");
 
+  (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
+    .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
+    .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
+    .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi.");
+
   (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
     .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
     .def("get_device_num", &ParallelContext::device_num, "Get device num.")
diff --git a/mindspore/ccsrc/utils/mpi/mpi_config.cc b/mindspore/ccsrc/utils/mpi/mpi_config.cc
new file mode 100644
index 0000000000..5831d34f9f
--- /dev/null
+++ b/mindspore/ccsrc/utils/mpi/mpi_config.cc
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/mpi/mpi_config.h"
+
+namespace mindspore {
+
+std::shared_ptr<MpiConfig> MpiConfig::instance_ = nullptr;
+
+std::shared_ptr<MpiConfig> MpiConfig::GetInstance() {
+  if (instance_ == nullptr) {
+    MS_LOG(DEBUG) << "Create new mpi config instance.";
+    instance_.reset(new (std::nothrow) MpiConfig());
+  }
+  return instance_;
+}
+
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/utils/mpi/mpi_config.h b/mindspore/ccsrc/utils/mpi/mpi_config.h
new file mode 100644
index 0000000000..044e767762
--- /dev/null
+++ b/mindspore/ccsrc/utils/mpi/mpi_config.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
+#define MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
+#include <memory>
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+class MpiConfig {
+ public:
+  ~MpiConfig() = default;
+  MpiConfig(const MpiConfig &) = delete;
+  MpiConfig &operator=(const MpiConfig &) = delete;
+
+  static std::shared_ptr<MpiConfig> GetInstance();
+
+  void set_enable_mpi(bool flag) { enable_mpi_ = flag; }
+  bool enable_mpi() const { return enable_mpi_; }
+
+ private:
+  MpiConfig() : enable_mpi_(false) {}
+
+  static std::shared_ptr<MpiConfig> instance_;
+  bool enable_mpi_;
+};
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
diff --git a/mindspore/context.py b/mindspore/context.py
index 35f671a1c6..ef2be24636 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -25,6 +25,7 @@ from mindspore._c_expression import MSContext
 from mindspore._checkparam import args_type_check
 from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
     _reset_auto_parallel_context
+from mindspore.parallel.mpi._mpi_config import _set_mpi_config, _get_mpi_config
 
 __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
            'get_auto_parallel_context', 'reset_auto_parallel_context']
@@ -566,3 +567,40 @@ def get_context(attr_key):
     if not hasattr(_context(), attr_key):
         raise ValueError("Get context keyword %s is not recognized!" % attr_key)
     return getattr(_context(), attr_key)
+
+@args_type_check(enable_mpi=bool)
+def set_mpi_config(**kwargs):
+    """
+    Sets mpi config for running environment.
+
+    mpi config should be configured before running your program. If there is no configuration,
+    the mpi module will be disabled by default.
+
+    Note:
+        Attribute name is required for setting attributes.
+
+    Args:
+        enable_mpi (bool): Whether to enable mpi. Default: False.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+
+    Examples:
+        >>> context.set_mpi_config(enable_mpi=True)
+    """
+    _set_mpi_config(**kwargs)
+
+def get_mpi_config(attr_key):
+    """
+    Gets mpi config attribute value according to the input key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Object, the value of the given attribute key.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+    """
+    return _get_mpi_config(attr_key)
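Note (not part of the patch): a minimal usage sketch of the Python surface added above, assuming a CPU host-collective setup like the one in the test at the end of this change; the script itself is hypothetical.

```python
# Sketch only: enable MPI before any operator that reaches MPIAdapter;
# otherwise MPIAdapter::Init() raises "MPI is disabled now! ...".
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
context.set_mpi_config(enable_mpi=True)

print(context.get_mpi_config("enable_mpi"))   # True

# Launch under mpirun so each process gets its own rank, e.g.:
#   mpirun -np 3 python this_script.py
```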
diff --git a/mindspore/parallel/mpi/__init__.py b/mindspore/parallel/mpi/__init__.py
new file mode 100644
index 0000000000..e30774307c
--- /dev/null
+++ b/mindspore/parallel/mpi/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/mindspore/parallel/mpi/_mpi_config.py b/mindspore/parallel/mpi/_mpi_config.py
new file mode 100644
index 0000000000..e43305fb76
--- /dev/null
+++ b/mindspore/parallel/mpi/_mpi_config.py
@@ -0,0 +1,111 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+The MPI config, used to configure the MPI environment.
+"""
+import threading
+from mindspore._c_expression import MpiConfig
+from mindspore._checkparam import args_type_check
+
+class _MpiConfig:
+    """
+    _MpiConfig is the config tool for controlling MPI.
+
+    Note:
+        Creating a config by instantiating the _MpiConfig object is not recommended;
+        use _mpi_config() to get the config, since _MpiConfig is a singleton.
+    """
+    _instance = None
+    _instance_lock = threading.Lock()
+
+    def __init__(self):
+        self._mpiconfig_handle = MpiConfig.get_instance()
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance_lock.acquire()
+            cls._instance = object.__new__(cls)
+            cls._instance_lock.release()
+        return cls._instance
+
+    def __getattribute__(self, attr):
+        value = object.__getattribute__(self, attr)
+        if attr == "_mpiconfig_handle" and value is None:
+            raise ValueError("mpiconfig handle is none in MpiConfig!!!")
+        return value
+
+    @property
+    def enable_mpi(self):
+        return self._mpiconfig_handle.get_enable_mpi()
+
+    @enable_mpi.setter
+    def enable_mpi(self, enable_mpi):
+        self._mpiconfig_handle.set_enable_mpi(enable_mpi)
+
+_k_mpi_config = None
+def _mpi_config():
+    """
+    Get the global mpi config; if the mpi config is not created, create a new one.
+
+    Returns:
+        _MpiConfig, the global mpi config.
+    """
+    global _k_mpi_config
+    if _k_mpi_config is None:
+        _k_mpi_config = _MpiConfig()
+    return _k_mpi_config
+
+@args_type_check(enable_mpi=bool)
+def _set_mpi_config(**kwargs):
+    """
+    Sets mpi config for running environment.
+
+    mpi config should be configured before running your program. If there is no configuration,
+    the mpi module will be disabled by default.
+
+    Note:
+        Attribute name is required for setting attributes.
+
+    Args:
+        enable_mpi (bool): Whether to enable mpi. Default: False.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+
+    Examples:
+        >>> context.set_mpi_config(enable_mpi=True)
+    """
+    for key, value in kwargs.items():
+        if not hasattr(_mpi_config(), key):
+            raise ValueError("Set mpi config keyword %s is not recognized!" % key)
+        setattr(_mpi_config(), key, value)
+
+
+def _get_mpi_config(attr_key):
+    """
+    Gets mpi config attribute value according to the input key.
+
+    Args:
+        attr_key (str): The key of the attribute.
+
+    Returns:
+        Object, the value of the given attribute key.
+
+    Raises:
+        ValueError: If input key is not an attribute in mpi config.
+    """
+    if not hasattr(_mpi_config(), attr_key):
+        raise ValueError("Get mpi config keyword %s is not recognized!" % attr_key)
+    return getattr(_mpi_config(), attr_key)
diff --git a/tests/st/ops/cpu/test_reduce_scatter.py b/tests/st/ops/cpu/test_reduce_scatter.py
index 2e1f019af9..6b21efe89c 100644
--- a/tests/st/ops/cpu/test_reduce_scatter.py
+++ b/tests/st/ops/cpu/test_reduce_scatter.py
@@ -23,9 +23,10 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 import mindspore._ms_mpi as mpi
 
 # run comand:
-# mpirun -np 3 python test_reduce_scatter.py
+# mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_reduce_scatter.py
 
 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+context.set_mpi_config(enable_mpi=True)
 
 class Net(nn.Cell):
     def __init__(self):
@@ -46,14 +47,19 @@ class AllGatherNet(nn.Cell):
         return self.hostallgather(x)
 
 def test_net_reduce_scatter():
-    x = np.ones(12).astype(np.float32) * 0.1
+    x = np.arange(12).astype(np.float32) * 0.1
     reducescatter = Net()
     rankid = mpi.get_rank_id()
     print("self rankid:", rankid)
     output = reducescatter(Tensor(x, mstype.float32))
     print("output:\n", output)
-    expect_result = np.ones(4).astype(np.float32) * 0.3
+    if rankid == 0:
+        expect_result = np.arange(4).astype(np.float32) * 0.3
+    if rankid == 1:
+        expect_result = np.arange(4, 8).astype(np.float32) * 0.3
+    if rankid == 2:
+        expect_result = np.arange(8, 12).astype(np.float32) * 0.3
     diff = abs(output.asnumpy() - expect_result)
     error = np.ones(shape=expect_result.shape) * 1.0e-6
     assert np.all(diff < error)
@@ -61,7 +67,7 @@ def test_net_reduce_scatter():
     allgather = AllGatherNet()
     allgather_output = allgather(output)
     print("allgather result:\n", allgather_output)
-    expect_allgather_result = np.ones(12).astype(np.float32) * 0.3
+    expect_allgather_result = np.arange(12).astype(np.float32) * 0.3
     diff = abs(allgather_output.asnumpy() - expect_allgather_result)
     error = np.ones(shape=expect_allgather_result.shape) * 1.0e-6
     assert np.all(diff < error)
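Note (not part of the patch): a numpy-only check of the arithmetic behind the updated expectations, assuming the 3-process mpirun launch from the comment at the top of the test.

```python
# Sketch only: why each rank expects an arange slice scaled by 0.3.
import numpy as np

nproc = 3
x = np.arange(12).astype(np.float32) * 0.1      # identical input on every rank
reduced = nproc * x                             # element-wise sum across ranks == arange(12) * 0.3

for rankid in range(nproc):
    chunk = reduced[rankid * 4:(rankid + 1) * 4]   # what rank `rankid` keeps after reduce-scatter
    assert np.allclose(chunk, np.arange(rankid * 4, rankid * 4 + 4, dtype=np.float32) * 0.3)

# The allgather step stitches the three chunks back into the full reduced tensor.
assert np.allclose(np.concatenate([reduced[i * 4:(i + 1) * 4] for i in range(nproc)]),
                   np.arange(12, dtype=np.float32) * 0.3)
```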