parent
b096383386
commit
6034f9c1e2
@ -0,0 +1,31 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "utils/mpi/mpi_config.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
std::shared_ptr<MpiConfig> MpiConfig::instance_ = nullptr;
|
||||||
|
|
||||||
|
std::shared_ptr<MpiConfig> MpiConfig::GetInstance() {
|
||||||
|
if (instance_ == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "Create new mpi config instance.";
|
||||||
|
instance_.reset(new (std::nothrow) MpiConfig());
|
||||||
|
}
|
||||||
|
return instance_;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,42 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
|
||||||
|
#define MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
|
||||||
|
#include <memory>
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
class MpiConfig {
|
||||||
|
public:
|
||||||
|
~MpiConfig() = default;
|
||||||
|
MpiConfig(const MpiConfig &) = delete;
|
||||||
|
MpiConfig &operator=(const MpiConfig &) = delete;
|
||||||
|
|
||||||
|
static std::shared_ptr<MpiConfig> GetInstance();
|
||||||
|
|
||||||
|
void set_enable_mpi(bool flag) { enable_mpi_ = flag; }
|
||||||
|
bool enable_mpi() const { return enable_mpi_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
MpiConfig() : enable_mpi_(false) {}
|
||||||
|
|
||||||
|
static std::shared_ptr<MpiConfig> instance_;
|
||||||
|
bool enable_mpi_;
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_UTILS_MPI_MS_CONTEXT_H_
|
@ -0,0 +1,14 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
@ -0,0 +1,111 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
The MPI config, used to configure the MPI environment.
|
||||||
|
"""
|
||||||
|
import threading
|
||||||
|
from mindspore._c_expression import MpiConfig
|
||||||
|
from mindspore._checkparam import args_type_check
|
||||||
|
|
||||||
|
class _MpiConfig:
|
||||||
|
"""
|
||||||
|
_MpiConfig is the config tool for controlling MPI
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Create a config through instantiating MpiConfig object is not recommended.
|
||||||
|
should use MpiConfig() to get the config since MpiConfig is singleton.
|
||||||
|
"""
|
||||||
|
_instance = None
|
||||||
|
_instance_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._mpiconfig_handle = MpiConfig.get_instance()
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance_lock.acquire()
|
||||||
|
cls._instance = object.__new__(cls)
|
||||||
|
cls._instance_lock.release()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __getattribute__(self, attr):
|
||||||
|
value = object.__getattribute__(self, attr)
|
||||||
|
if attr == "_mpiconfig_handle" and value is None:
|
||||||
|
raise ValueError("mpiconfig handle is none in MpiConfig!!!")
|
||||||
|
return value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enable_mpi(self):
|
||||||
|
return self._mpiconfig_handle.get_enable_mpi()
|
||||||
|
|
||||||
|
@enable_mpi.setter
|
||||||
|
def enable_mpi(self, enable_mpi):
|
||||||
|
self._mpiconfig_handle.set_enable_mpi(enable_mpi)
|
||||||
|
|
||||||
|
_k_mpi_config = None
|
||||||
|
def _mpi_config():
|
||||||
|
"""
|
||||||
|
Get the global mpi config, if mpi config is not created, create a new one.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_MpiConfig, the global mpi config.
|
||||||
|
"""
|
||||||
|
global _k_mpi_config
|
||||||
|
if _k_mpi_config is None:
|
||||||
|
_k_mpi_config = _MpiConfig()
|
||||||
|
return _k_mpi_config
|
||||||
|
|
||||||
|
@args_type_check(enable_mpi=bool)
|
||||||
|
def _set_mpi_config(**kwargs):
|
||||||
|
"""
|
||||||
|
Sets mpi config for running environment.
|
||||||
|
|
||||||
|
mpi config should be configured before running your program. If there is no configuration,
|
||||||
|
mpi moudle will be disabled by default.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Attribute name is required for setting attributes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enable_mpi (bool): Whether to enable mpi. Default: False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If input key is not an attribute in mpi config.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> mpiconfig.set_mpi_config(enable_mpi=True)
|
||||||
|
"""
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if not hasattr(_mpi_config(), key):
|
||||||
|
raise ValueError("Set mpi config keyword %s is not recognized!" % key)
|
||||||
|
setattr(_mpi_config(), key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mpi_config(attr_key):
|
||||||
|
"""
|
||||||
|
Gets mpi config attribute value according to the input key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attr_key (str): The key of the attribute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Object, The value of given attribute key.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If input key is not an attribute in context.
|
||||||
|
"""
|
||||||
|
if not hasattr(_mpi_config(), attr_key):
|
||||||
|
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
|
||||||
|
return getattr(_mpi_config(), attr_key)
|
Loading…
Reference in new issue