parent
b096383386
commit
6034f9c1e2
@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "utils/mpi/mpi_config.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
std::shared_ptr<MpiConfig> MpiConfig::instance_ = nullptr;
|
||||
|
||||
std::shared_ptr<MpiConfig> MpiConfig::GetInstance() {
|
||||
if (instance_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "Create new mpi config instance.";
|
||||
instance_.reset(new (std::nothrow) MpiConfig());
|
||||
}
|
||||
return instance_;
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
|
||||
#define MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
|
||||
#include <memory>
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
class MpiConfig {
|
||||
public:
|
||||
~MpiConfig() = default;
|
||||
MpiConfig(const MpiConfig &) = delete;
|
||||
MpiConfig &operator=(const MpiConfig &) = delete;
|
||||
|
||||
static std::shared_ptr<MpiConfig> GetInstance();
|
||||
|
||||
void set_enable_mpi(bool flag) { enable_mpi_ = flag; }
|
||||
bool enable_mpi() const { return enable_mpi_; }
|
||||
|
||||
private:
|
||||
MpiConfig() : enable_mpi_(false) {}
|
||||
|
||||
static std::shared_ptr<MpiConfig> instance_;
|
||||
bool enable_mpi_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
|
@ -0,0 +1,14 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
@ -0,0 +1,111 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
The MPI config, used to configure the MPI environment.
|
||||
"""
|
||||
import threading
|
||||
from mindspore._c_expression import MpiConfig
|
||||
from mindspore._checkparam import args_type_check
|
||||
|
||||
class _MpiConfig:
|
||||
"""
|
||||
_MpiConfig is the config tool for controlling MPI
|
||||
|
||||
Note:
|
||||
Create a config through instantiating MpiConfig object is not recommended.
|
||||
should use MpiConfig() to get the config since MpiConfig is singleton.
|
||||
"""
|
||||
_instance = None
|
||||
_instance_lock = threading.Lock()
|
||||
|
||||
def __init__(self):
|
||||
self._mpiconfig_handle = MpiConfig.get_instance()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance_lock.acquire()
|
||||
cls._instance = object.__new__(cls)
|
||||
cls._instance_lock.release()
|
||||
return cls._instance
|
||||
|
||||
def __getattribute__(self, attr):
|
||||
value = object.__getattribute__(self, attr)
|
||||
if attr == "_mpiconfig_handle" and value is None:
|
||||
raise ValueError("mpiconfig handle is none in MpiConfig!!!")
|
||||
return value
|
||||
|
||||
@property
|
||||
def enable_mpi(self):
|
||||
return self._mpiconfig_handle.get_enable_mpi()
|
||||
|
||||
@enable_mpi.setter
|
||||
def enable_mpi(self, enable_mpi):
|
||||
self._mpiconfig_handle.set_enable_mpi(enable_mpi)
|
||||
|
||||
_k_mpi_config = None
|
||||
def _mpi_config():
|
||||
"""
|
||||
Get the global mpi config, if mpi config is not created, create a new one.
|
||||
|
||||
Returns:
|
||||
_MpiConfig, the global mpi config.
|
||||
"""
|
||||
global _k_mpi_config
|
||||
if _k_mpi_config is None:
|
||||
_k_mpi_config = _MpiConfig()
|
||||
return _k_mpi_config
|
||||
|
||||
@args_type_check(enable_mpi=bool)
|
||||
def _set_mpi_config(**kwargs):
|
||||
"""
|
||||
Sets mpi config for running environment.
|
||||
|
||||
mpi config should be configured before running your program. If there is no configuration,
|
||||
mpi moudle will be disabled by default.
|
||||
|
||||
Note:
|
||||
Attribute name is required for setting attributes.
|
||||
|
||||
Args:
|
||||
enable_mpi (bool): Whether to enable mpi. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not an attribute in mpi config.
|
||||
|
||||
Examples:
|
||||
>>> mpiconfig.set_mpi_config(enable_mpi=True)
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(_mpi_config(), key):
|
||||
raise ValueError("Set mpi config keyword %s is not recognized!" % key)
|
||||
setattr(_mpi_config(), key, value)
|
||||
|
||||
|
||||
def _get_mpi_config(attr_key):
|
||||
"""
|
||||
Gets mpi config attribute value according to the input key.
|
||||
|
||||
Args:
|
||||
attr_key (str): The key of the attribute.
|
||||
|
||||
Returns:
|
||||
Object, The value of given attribute key.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not an attribute in context.
|
||||
"""
|
||||
if not hasattr(_mpi_config(), attr_key):
|
||||
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
|
||||
return getattr(_mpi_config(), attr_key)
|
Loading…
Reference in new issue