parent
a838c9bd3d
commit
87bf2a7dcd
@ -0,0 +1,86 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/parallel/ps/ps_context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
namespace ps {
|
||||
std::shared_ptr<PSContext> PSContext::instance() {
|
||||
static std::shared_ptr<PSContext> ps_instance = nullptr;
|
||||
if (ps_instance == nullptr) {
|
||||
ps_instance.reset(new (std::nothrow) PSContext());
|
||||
}
|
||||
return ps_instance;
|
||||
}
|
||||
|
||||
void PSContext::SetPSEnable(bool enabled) {
|
||||
ps_enabled_ = enabled;
|
||||
if (ps_enabled_) {
|
||||
std::string ms_role = common::GetEnv(kEnvRole);
|
||||
MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
|
||||
if (ms_role == kEnvRoleOfWorker) {
|
||||
is_worker_ = true;
|
||||
} else if (ms_role == kEnvRoleOfPServer) {
|
||||
is_pserver_ = true;
|
||||
} else if (ms_role == kEnvRoleOfScheduler) {
|
||||
is_sched_ = true;
|
||||
} else {
|
||||
MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid.";
|
||||
}
|
||||
} else {
|
||||
MS_LOG(INFO) << "PS mode is disabled.";
|
||||
is_worker_ = false;
|
||||
is_pserver_ = false;
|
||||
is_sched_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
bool PSContext::is_ps_enabled() const { return ps_enabled_; }
|
||||
|
||||
void PSContext::Reset() {
|
||||
ps_enabled_ = false;
|
||||
is_worker_ = false;
|
||||
is_pserver_ = false;
|
||||
is_sched_ = false;
|
||||
}
|
||||
|
||||
std::string PSContext::ms_role() const {
|
||||
if (is_worker_) {
|
||||
return kEnvRoleOfWorker;
|
||||
} else if (is_pserver_) {
|
||||
return kEnvRoleOfPServer;
|
||||
} else if (is_sched_) {
|
||||
return kEnvRoleOfScheduler;
|
||||
} else {
|
||||
return kEnvRoleOfNotPS;
|
||||
}
|
||||
}
|
||||
|
||||
bool PSContext::is_role_worker() const { return is_worker_; }
|
||||
|
||||
bool PSContext::is_role_pserver() const { return is_pserver_; }
|
||||
|
||||
bool PSContext::is_role_sched() const { return is_sched_; }
|
||||
|
||||
void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; }
|
||||
|
||||
int PSContext::ps_rank_id() const { return rank_id_; }
|
||||
} // namespace ps
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
@ -0,0 +1,61 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
namespace ps {
|
||||
constexpr char kEnvRole[] = "MS_ROLE";
|
||||
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
|
||||
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
|
||||
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
|
||||
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
|
||||
|
||||
class PSContext {
|
||||
public:
|
||||
~PSContext() = default;
|
||||
PSContext(PSContext const &) = delete;
|
||||
PSContext &operator=(const PSContext &) = delete;
|
||||
static std::shared_ptr<PSContext> instance();
|
||||
|
||||
void SetPSEnable(bool enabled);
|
||||
bool is_ps_enabled() const;
|
||||
void Reset();
|
||||
std::string ms_role() const;
|
||||
bool is_role_worker() const;
|
||||
bool is_role_pserver() const;
|
||||
bool is_role_sched() const;
|
||||
void SetPSRankId(int rank_id);
|
||||
int ps_rank_id() const;
|
||||
|
||||
private:
|
||||
PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {}
|
||||
bool ps_enabled_;
|
||||
bool is_worker_;
|
||||
bool is_pserver_;
|
||||
bool is_sched_;
|
||||
int rank_id_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_CONTEXT_H_
|
@ -0,0 +1,115 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Context for parameter server training mode"""
|
||||
|
||||
from mindspore._c_expression import PSContext
|
||||
|
||||
_ps_context = None
|
||||
|
||||
|
||||
def ps_context():
|
||||
"""
|
||||
Get the global _ps_context, if it is not created, create a new one.
|
||||
|
||||
Returns:
|
||||
_ps_context, the global parameter server training mode context.
|
||||
"""
|
||||
global _ps_context
|
||||
if _ps_context is None:
|
||||
_ps_context = PSContext.get_instance()
|
||||
return _ps_context
|
||||
|
||||
_set_ps_context_func_map = {
|
||||
"enable_ps": ps_context().set_ps_enable
|
||||
}
|
||||
|
||||
_get_ps_context_func_map = {
|
||||
"enable_ps": ps_context().is_ps_enabled
|
||||
}
|
||||
|
||||
def _get_ps_mode_rank():
|
||||
ps_rank = ps_context().ps_rank_id()
|
||||
if ps_rank == -1:
|
||||
raise RuntimeError("The parameter server mode training is not enabled yet.")
|
||||
return ps_rank
|
||||
|
||||
def _set_ps_context(**kwargs):
|
||||
"""
|
||||
Set parameter server training mode context.
|
||||
|
||||
Note:
|
||||
Some other environment variables should also be set for parameter server training mode.
|
||||
These environment variables are listed below:
|
||||
MS_SERVER_NUM # Server number
|
||||
MS_WORKER_NUM # Worker number
|
||||
MS_SCHED_HOST # Scheduler IP address
|
||||
MS_SCHED_PORT # Scheduler port
|
||||
MS_ROLE # The role of this process:
|
||||
MS_SCHED represents the scheduler,
|
||||
MS_WORKER represents the worker,
|
||||
MS_PSERVER represents the Server
|
||||
|
||||
|
||||
Args:
|
||||
enable_ps (bool): Whether to enable parameter server training mode.
|
||||
Only after enable_ps is set True, the environment variables will be effective.
|
||||
Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not the attribute in parameter server training mode context.
|
||||
|
||||
Examples:
|
||||
>>> context.set_ps_context(enable_ps=True)
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if key not in _set_ps_context_func_map:
|
||||
raise ValueError("Set PS context keyword %s is not recognized!" % key)
|
||||
set_func = _set_ps_context_func_map[key]
|
||||
set_func(value)
|
||||
|
||||
def _get_ps_context(attr_key):
|
||||
"""
|
||||
Get parameter server training mode context attribute value according to the key.
|
||||
|
||||
Args:
|
||||
attr_key (str): The key of the attribute.
|
||||
|
||||
Returns:
|
||||
Returns attribute value according to the key.
|
||||
|
||||
Raises:
|
||||
ValueError: If input key is not attribute in auto parallel context.
|
||||
"""
|
||||
if key not in _get_ps_context_func_map:
|
||||
raise ValueError("Get PS context keyword %s is not recognized!" % key)
|
||||
get_func = _get_ps_context_func_map[attr_key]
|
||||
get_func(attr_key)
|
||||
|
||||
def _reset_ps_context():
|
||||
"""
|
||||
Reset parameter server training mode context attributes to the default values:
|
||||
|
||||
- enable_ps: False.
|
||||
"""
|
||||
ps_context().reset()
|
||||
|
||||
def _is_role_worker():
|
||||
return ps_context().is_role_worker()
|
||||
|
||||
def _is_role_pserver():
|
||||
return ps_context().is_role_pserver()
|
||||
|
||||
def _is_role_sched():
|
||||
return ps_context().is_role_sched()
|
@ -1,23 +0,0 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils for parameter server training mode"""
|
||||
|
||||
from mindspore._c_expression import get_ps_mode_rank
|
||||
|
||||
def _get_ps_mode_rank():
|
||||
ps_rank = get_ps_mode_rank()
|
||||
if ps_rank == -1:
|
||||
raise RuntimeError("The parameter server mode training is not launched yet.")
|
||||
return ps_rank
|
Loading…
Reference in new issue