enable PS mode init in server

pull/7598/head
lizhenyu 4 years ago
parent 49f02bfbed
commit fd1d61eaf7

@@ -1556,8 +1556,7 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
     auto input_node = input_nodes[i];
     MS_EXCEPTION_IF_NULL(input_node);
     if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
-      auto pk_node = input_node->cast<ParameterPtr>();
-      ps::worker.InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor);
+      ps::worker.InitPSParamAndOptim(input_node, tensor);
     }
   }
 }

@@ -52,7 +52,7 @@ class Worker {
   void SetOptimInputShapes(size_t key, const ShapeVector &shape);
   void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
   void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const ShapeVector &sizes);
-  void InitPSParamAndOptim(const std::string &param_name, const tensor::TensorPtr &tensor);
+  void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
   void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                            const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int cmd);
   void Finalize();
@@ -321,14 +321,17 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto
 }
 template <typename T>
-void Worker<T>::InitPSParamAndOptim(const std::string &param_name, const tensor::TensorPtr &tensor) {
+void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) {
   MS_EXCEPTION_IF_NULL(tensor);
+  MS_EXCEPTION_IF_NULL(input_node);
+  auto pk_node = input_node->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(pk_node);
+  const std::string &param_name = pk_node->fullname_with_scope();
   void *param_data = tensor->data_c();
   size_t param_size = LongToSize(tensor->data().nbytes());
   if (param_size > INT_MAX) {
     MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size;
   }
   ShapeVector param_shape = tensor->shape_c();
   size_t param_key = GetParamKey(param_name);
   if (param_key == kInvalidKey) {
@@ -336,8 +339,8 @@ void Worker<T>::InitPSParamAndOptim(const std::string &param_name, const tensor:
     return;
   }
   bool init_in_server = false;
-  ShapeVector shape_init_in_server = {1};
-  if (param_shape == shape_init_in_server) {
+  auto param_info_ptr = pk_node->param_info();
+  if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
     init_in_server = true;
   }
   SetParamInitInServer(param_name, init_in_server);

@@ -26,6 +26,7 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
     .def("clone", &ParamInfo::Clone)
     .def_property("name", &ParamInfo::name, &ParamInfo::set_name)
     .def_property("requires_grad", &ParamInfo::requires_grad, &ParamInfo::set_requires_grad)
+    .def_property("init_in_server", &ParamInfo::init_in_server, &ParamInfo::set_init_in_server)
     .def_property("layerwise_parallel", &ParamInfo::layerwise_parallel,
                   &ParamInfo::set_layerwise_parallel)
     .def(py::pickle(

@@ -144,7 +144,7 @@ class Parameter(MetaTensor_):
         if isinstance(data, bool):
             raise ValueError('Parameter data can not be `bool`')
         if isinstance(data, MetaTensor):
-            if _is_in_parallel_mode():
+            if _is_in_parallel_mode() or _is_role_worker():
                 # do not init data while in auto parallel.
                 return (MetaTensor_, data.dtype, data.shape)
             data = data.to_tensor()
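
For context, a hedged Python-side sketch of what this worker check changes: a Parameter built from a MetaTensor (an initializer) is now left un-materialized on a PS worker, exactly as it already was under auto parallel, and values are only produced later by init_data(). The shape and the parameter name below are made up for illustration.

from mindspore import Parameter
from mindspore.common.initializer import initializer

# Built from a MetaTensor: on a PS worker (or in auto-parallel mode) no
# 16x4 buffer is allocated here; the initializer is kept as-is.
table = Parameter(initializer('normal', [16, 4]), name='embedding_table')
# Materialization is deferred to init_data(); for a parameter flagged
# init_in_server the worker later gets only a 1-element placeholder
# (see the init_data hunk further down).
table.init_data()
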
@@ -174,8 +174,12 @@ class Parameter(MetaTensor_):
     def set_param_ps(self, init_in_server=False):
         if _is_role_worker() or _is_role_pserver() or _is_role_sched():
+            if init_in_server and (not self.name.endswith("embedding_table")):
+                raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of \
+                    sparse operator support initialization in server.".format(self.name))
             self.is_param_ps = True
             self.init_in_server = init_in_server
+            self._param_info.init_in_server = init_in_server
         else:
             raise RuntimeError("Must complete following two steps before calling set_param_ps: \
                 1. set_ps_context(enable_ps=True) \
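
A minimal usage sketch for the new flag, assuming the process was already launched in PS mode (enable_ps set and a worker/server/scheduler role assigned through the environment); outside PS mode set_param_ps raises the RuntimeError shown above. EmbeddingLookup is used because only parameters whose name ends with "embedding_table" may be initialized in the server.

import mindspore.nn as nn
from mindspore import context

context.set_ps_context(enable_ps=True)

net = nn.EmbeddingLookup(16, 4)
# Keep the table on the parameter server and let the server initialize it;
# the flag is mirrored into the C++ ParamInfo through the new pybind property.
net.embedding_table.set_param_ps(init_in_server=True)
assert net.embedding_table._param_info.init_in_server  # internal attribute, named as in the diff
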
@@ -270,6 +274,8 @@ class Parameter(MetaTensor_):
         x._param_info = self._param_info.clone()
         x._param_info.name = prefix + '.' + self._param_info.name
         x.is_init = False
+        x.is_param_ps = self.is_param_ps
+        x.init_in_server = self.init_in_server
         if init != 'same':
             shape = self.shape
             dtype = self.dtype
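
Why clone() now carries the two flags: optimizers create their slot variables (for example Adam's moment accumulators) by cloning the trained parameter, and those clones must follow the original onto the parameter server. A hedged sketch, continuing the net from the previous example:

table = net.embedding_table
moments = table.clone(prefix="moments")
# The clone inherits the PS placement flags, so optimizer state for a
# server-initialized table is also created and held on the server.
assert moments.is_param_ps == table.is_param_ps
assert moments.init_in_server == table.init_in_server
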
@@ -403,12 +409,18 @@ class Parameter(MetaTensor_):
                 raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
             slice_index = int(_get_slice_index(layout[0], layout[1]))
             if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
-                data = self.init_mode.to_tensor(0, [1])
+                if _is_role_worker():
+                    data = self.init_mode.to_tensor(0, [1])
+                else:
+                    data = self.init_mode.to_tensor(slice_index, layout[2])
             else:
                 data = self.init_mode.to_tensor(slice_index, layout[2])
         else:
             if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
-                data = self.init_mode.to_tensor(0, [1])
+                if _is_role_worker():
+                    data = self.init_mode.to_tensor(0, [1])
+                else:
+                    data = self.init_mode.to_tensor()
             else:
                 data = self.init_mode.to_tensor()
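
To gather the branching above in one place, here is a hedged paraphrase (illustration only, not library code) of what init_data now materializes, ignoring the isinstance(self.init_mode, MetaTensor) guard:

def materialized_shape(init_in_server, is_param_ps, on_worker, layout=None, full_shape=None):
    # A worker never materializes a server-initialized parameter.
    if init_in_server and is_param_ps and on_worker:
        return [1]              # 1-element placeholder on the worker
    if layout is not None:
        return layout[2]        # parallel slice, as before
    return full_shape           # full tensor (servers, non-PS runs)
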

@@ -42,6 +42,9 @@ class ParamInfo {
   bool requires_grad() const { return requires_grad_; }
   void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
+  bool init_in_server() const { return init_in_server_; }
+  void set_init_in_server(bool init_in_server) { init_in_server_ = init_in_server; }
   bool layerwise_parallel() const { return layerwise_parallel_; }
   void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }
@@ -68,12 +71,14 @@ class ParamInfo {
     clone->cloned_index_ = index;
     this->be_cloned_ = true;
     this->be_cloned_index_.push_back(index);
+    clone->init_in_server_ = this->init_in_server_;
     return clone;
   }
  private:
   std::string name_{"Parameter"};
   bool requires_grad_{true};
+  bool init_in_server_{false};
   bool layerwise_parallel_{false};
   bool be_cloned_{false};
   bool cloned_{false};

@@ -26,7 +26,6 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Adam
 from mindspore.common import set_seed
 from mindspore.ops import operations as P
-from mindspore.common.initializer import TruncatedNormal
 from mindspore.parallel._ps_context import _is_role_pserver, _is_role_worker
 
 parser = argparse.ArgumentParser(description="test_sparse_embedding")
@@ -39,18 +38,6 @@ context.set_context(
 context.set_ps_context(enable_ps=True)
 
 
-def fc_with_initialize(input_channels, out_channels):
-    """weight initial for fc layer"""
-    weight = weight_variable()
-    bias = weight_variable()
-    return nn.Dense(input_channels, out_channels, weight, bias)
-
-
-def weight_variable():
-    """weight initial"""
-    return TruncatedNormal(0.02)
-
-
 class LeNet5(nn.Cell):
     def __init__(self, num_class=10):
         super(LeNet5, self).__init__()
@@ -58,7 +45,7 @@ class LeNet5(nn.Cell):
         self.flatten = nn.Flatten()
         self.embedding = nn.EmbeddingLookup(16, 4)
         self.relu = nn.ReLU()
-        self.fc = fc_with_initialize(12, num_class)
+        self.fc = nn.Dense(12, num_class)
 
     def construct(self, x):
         x = self.cast(x, mstype.int32)
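
For completeness, a hedged sketch of how a PS test like this is usually launched: the same script is started once per role, with the role and the scheduler address passed through environment variables. The variable names and role strings below follow MindSpore's parameter-server documentation of this era, and the script name is inferred from the argparse description; treat all of them as assumptions.

import os
import subprocess

base_env = dict(os.environ,
                MS_SCHED_HOST="127.0.0.1", MS_SCHED_PORT="8081",
                MS_WORKER_NUM="1", MS_SERVER_NUM="1")
procs = [subprocess.Popen(["python", "test_sparse_embedding.py"],  # assumed file name
                          env=dict(base_env, MS_ROLE=role))
         for role in ("MS_SCHED", "MS_PSERVER", "MS_WORKER")]
for p in procs:
    p.wait()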
