supplement bug fix of parameter server (#26217)

* fix fluid.embedding
Chengmo committed via GitHub
parent ad6e3dd69c
commit d0962abd20

@@ -25,25 +25,32 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Ids"),
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs("Outputs"),
"Output(Outs) of LookupTableOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInputs("Ids"), true,
platform::errors::InvalidArgument(
"Input(Ids) of LookupTableOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
platform::errors::InvalidArgument(
"Input(W) of LookupTableOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutputs("Outputs"), true,
platform::errors::InvalidArgument(
"Output(Outs) of LookupTableOp should not be null."));
auto ids_dims = ctx->GetInputsDim("Ids");
auto table_dims = ctx->GetInputDim("W");
PADDLE_ENFORCE_EQ(table_dims.size(), 2,
"Only 2 dimensions of the 'Embedding' is supported.");
PADDLE_ENFORCE_EQ(
table_dims.size(), 2,
platform::errors::InvalidArgument(
"Only 2 dimensions of the 'Embedding' is supported."));
for (auto &ids_dim : ids_dims) {
PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
"The dimension of the 'Ids' tensor must be 2.");
platform::errors::InvalidArgument(
"The dimension of the 'Ids' tensor must be 2."));
}
auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
// for fluid.embedding
auto lookup_table_version =
ctx->Attrs().Get<std::string>("lookup_table_version");

@@ -35,9 +35,30 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
auto is_distributed = context.Attr<bool>("is_distributed");
auto lookup_table_version =
context.Attr<std::string>("lookup_table_version");
operators::distributed::prefetchs(id_names, out_names, embedding_name,
is_distributed, lookup_tables, endpoints,
context, context.scope());
if (lookup_table_version == "lookup_table_v2") {
auto &scope = context.scope();
auto emb_dim =
scope.FindVar(embedding_name)->Get<framework::LoDTensor>().dims()[1];
for (size_t i = 0; i < id_names.size(); ++i) {
auto *id_var = scope.FindVar(id_names[i]);
auto *out_var = scope.FindVar(out_names[i]);
auto *id_tensor = id_var->GetMutable<framework::LoDTensor>();
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
auto id_dims = id_tensor->dims();
out_tensor->Resize(framework::make_ddim(
{static_cast<int64_t>(id_dims[0]), static_cast<int64_t>(id_dims[1]),
static_cast<int64_t>(emb_dim)}));
}
}
}
};
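The reshape above mirrors the shape contract of the two Python front ends: fluid.layers.embedding (which emits a lookup_table op) produces a 2-D output, while fluid.embedding (which emits lookup_table_v2) appends the embedding dimension to the ids shape, so for the 2-D ids this op accepts the prefetched result must come back 3-D. A minimal sketch of that contract, assuming a Paddle 1.x fluid API and arbitrary sizes, not code from this commit:

import paddle.fluid as fluid

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    # fluid.layers.embedding -> lookup_table op: ids end with a trailing
    # dimension of 1 and the output is 2-D, [batch, emb_dim].
    ids_v1 = fluid.data(name="ids_v1", shape=[None, 1], dtype="int64")
    emb_v1 = fluid.layers.embedding(input=ids_v1, size=[100, 8])

    # fluid.embedding -> lookup_table_v2 op: ids keep their own rank and
    # the output appends emb_dim, so 2-D ids give a 3-D output.
    ids_v2 = fluid.data(name="ids_v2", shape=[None, 3], dtype="int64")
    emb_v2 = fluid.embedding(input=ids_v2, size=[100, 8])

print(emb_v1.shape)  # 2-D: (-1, 8)
print(emb_v2.shape)  # 3-D: (-1, 3, 8)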

@@ -154,15 +154,16 @@ class ParameterServerRuntime(RuntimeBase):
kwargs["sparse_attrs"] = get_sparse_attrs()
return kwargs
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_lr_ops
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_lr_ops, _has_global_step
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import \
SyncStrategy, GeoStrategy
trainer_config = self.async_strategy.get_trainer_runtime_config()
lrs = _get_lr_ops(self.origin_main_program)
if len(lrs) > 0:
lrs = _has_global_step(_get_lr_ops(self.origin_main_program))
if lrs:
kwargs = {"need_global_step": "1"}
else:
kwargs = {"need_global_step": "0"}

@@ -42,6 +42,9 @@ op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"]
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
def _get_lr_ops(program):
lr_ops = []
@@ -66,7 +69,7 @@ def _has_global_step(lr_ops):
def is_sparse_op(op):
if op.type == "lookup_table" and op.attr('is_sparse') is True and op.attr(
if op.type in SPARSE_OP_LIST and op.attr('is_sparse') is True and op.attr(
'is_distributed') is False:
return True
@@ -78,7 +81,7 @@ def is_sparse_op(op):
def is_distributed_sparse_op(op):
if op.type == "lookup_table" and op.attr('is_distributed') is True:
if op.type in SPARSE_OP_LIST and op.attr('is_distributed') is True:
return True
if op.type == "distributed_lookup_table" and op.attr(
@@ -802,11 +805,10 @@ class CompileTimeStrategy(object):
def _get_sparse_varnames():
varnames = []
op_types = {"lookup_table": "W"}
for op in origin_program.global_block().ops:
if op.type in op_types.keys() \
if op.type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(op_types[op.type])[0]
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
varnames.append(param_name)
return list(set(varnames))
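The hard-coded {"lookup_table": "W"} map missed ops emitted by fluid.embedding; with SPARSE_OP_TYPE_DICT both front ends resolve to the same "W" input. A quick, hedged check of that behaviour with an assumed standalone program (the real pass additionally filters on remote_prefetch, and parameter names are whatever the program generates):

import paddle.fluid as fluid

SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}

main = fluid.Program()
with fluid.program_guard(main, fluid.Program()):
    ids_a = fluid.data(name="ids_a", shape=[None, 1], dtype="int64")
    fluid.layers.embedding(input=ids_a, size=[100, 8], is_sparse=True)  # lookup_table
    ids_b = fluid.data(name="ids_b", shape=[None, 1], dtype="int64")
    fluid.embedding(input=ids_b, size=[100, 8], is_sparse=True)         # lookup_table_v2

sparse_params = [
    op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
    for op in main.global_block().ops
    if op.type in SPARSE_OP_TYPE_DICT
]
print(sparse_params)  # both embedding parameters, e.g. ['embedding_0.w_0', 'embedding_1.w_0']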

@@ -40,6 +40,8 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'
@@ -81,11 +83,10 @@ def distributed_ops_pass(program, config):
def _get_pull_sparse_ops(_program):
pull_sparse_ops = {}
op_types = {"lookup_table": "W"}
for op in _program.global_block().ops:
if op.type in op_types.keys() \
if op.type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(op_types[op.type])[0]
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
ops = pull_sparse_ops.get(param_name, [])
ops.append(op)
pull_sparse_ops[param_name] = ops
@@ -101,6 +102,7 @@ def distributed_ops_pass(program, config):
w = program.global_block().vars[ops[0].input("W")[0]]
padding_idx = ops[0].attr("padding_idx")
is_distributed = ops[0].attr("is_distributed")
op_type = ops[0].type
outputs = [
program.global_block().vars[op.output("Out")[0]] for op in ops
@@ -149,7 +151,8 @@ def distributed_ops_pass(program, config):
"is_distributed": is_distributed,
"pserver_num": len(pserver_endpoints),
"padding_idx": padding_idx,
"trainer_id": trainer_id
"trainer_id": trainer_id,
"lookup_table_version": op_type
})
else:
raise ValueError(

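The new lookup_table_version attribute records which front end produced the original sparse lookups, so the DistributedLookupTable kernel patched above knows whether to restore a 3-D output. A hedged helper for inspecting the result of the pass, assuming it receives a program the pass has already transformed:

def fused_lookup_versions(program):
    # Each fused distributed_lookup_table op should carry the type of the
    # lookup ops it replaced: "lookup_table" or "lookup_table_v2".
    return [
        op.attr("lookup_table_version")
        for op in program.global_block().ops
        if op.type == "distributed_lookup_table"
    ]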
@@ -432,8 +432,6 @@ if(WITH_DISTRIBUTE)
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_lars")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_train")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_save_load")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_simnet_bow")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_ctr")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_text_classification")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_train")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_word2vec")

@@ -196,8 +196,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
fleet.stop_worker()
def do_dataset_training(self, fleet):
dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
)
train_file_list = ctr_dataset_reader.prepare_fake_data()
exe = fluid.Executor(fluid.CPUPlace())
@@ -206,9 +205,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
thread_num = 2
batch_size = 128
filelist = []
for _ in range(thread_num):
filelist.append(train_file_path)
filelist = train_file_list
# config dataset
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset()

@@ -0,0 +1,33 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import logging
import tarfile
import random
import paddle
import paddle.fluid.incubate.data_generator as data_generator
logging.basicConfig()
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
class DatasetSimnetReader(data_generator.MultiSlotDataGenerator):
def generate_sample(self, line):
pass
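The new reader subclasses MultiSlotDataGenerator; its generate_sample shows only a pass stub here and the rest of the 33-line file is truncated in this view. For reference, a generator of this kind typically returns a callable that yields (slot_name, values) pairs per input line. The class name, slot names, and line format below are illustrative assumptions, not the actual contents of the new file:

import paddle.fluid.incubate.data_generator as data_generator

class ExampleSimnetReader(data_generator.MultiSlotDataGenerator):
    def generate_sample(self, line):
        def reader():
            # assumed tab-separated line: query ids, positive-title ids, negative-title ids
            query, pos_title, neg_title = line.rstrip("\n").split("\t")[:3]
            yield [
                ("query_ids", [int(x) for x in query.split(" ")]),
                ("pos_title_ids", [int(x) for x in pos_title.split(" ")]),
                ("neg_title_ids", [int(x) for x in neg_title.split(" ")]),
            ]
        return reader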

@@ -156,40 +156,5 @@ class TestDistCtrHalfAsync2x2(TestFleetBase):
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistCtrPsGpuPyreaderAsync2x2(TestFleetBase):
def _setup_config(self):
self._mode = "async"
self._reader = "pyreader"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "30000", # 5sec to fail fast
"http_proxy": "",
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"CPU_NUM": "2",
"SAVE_MODEL": "1"
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_ctr_ps_gpu.py", delta=1e-5, check_error_log=True)
if __name__ == "__main__":
unittest.main()

@@ -21,7 +21,7 @@ import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
from test_dist_fleet_base import TestFleetBase
from dist_simnet_bow import train_network
from dist_fleet_simnet_bow import train_network
class TestDistGeoCtr_2x2(TestFleetBase):
@@ -72,7 +72,7 @@ class TestGeoSgdTranspiler(unittest.TestCase):
strategy = StrategyFactory.create_geo_strategy(5)
avg_cost, _, _ = train_network(batch_size, is_distribute, is_sparse)
avg_cost, _, _, _ = train_network(batch_size, is_distribute, is_sparse)
optimizer = fluid.optimizer.SGD(0.1)
optimizer = fleet.distributed_optimizer(optimizer, strategy)

@@ -21,7 +21,7 @@ import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from test_dist_fleet_base import TestFleetBase
from dist_simnet_bow import train_network
from dist_fleet_simnet_bow import train_network
@unittest.skip(reason="Skip unstable ut, add it after PR 22957 merged")
@@ -44,7 +44,7 @@ class TestDistGeoClipByGlobalNormTranspiler(unittest.TestCase):
strategy.geo_sgd_mode = True
strategy.geo_sgd_need_push_nums = 5
avg_cost, _, _ = train_network(batch_size, is_distribute, is_sparse)
avg_cost, _, _, _ = train_network(batch_size, is_distribute, is_sparse)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(2.0))

@@ -0,0 +1,56 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import tempfile
from test_dist_fleet_base import TestFleetBase
class TestDistSimnetASync2x2(TestFleetBase):
def _setup_config(self):
self._mode = "async"
self._reader = "pyreader"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "2"
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_simnet_bow.py", delta=1e-5, check_error_log=True)
if __name__ == "__main__":
unittest.main()

@@ -1,161 +0,0 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
from test_dist_base import TestDistBase
import os
flag_name = os.path.splitext(__file__)[0]
class TestDistSimnetBowDense2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1'
}
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=True,
need_envs=need_envs,
log_name=flag_name)
class TestDistSimnetBow2x2DenseAsync(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._enforce_place = "CPU"
# FIXME(typhoonzero): fix async tests later
def notest_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1',
}
self.check_with_place(
"dist_simnet_bow.py",
delta=100,
check_error_log=True,
need_envs=need_envs,
log_name=flag_name)
class TestDistSimnetBowSparse2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '1'
}
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=True,
need_envs=need_envs,
log_name=flag_name)
class TestDistSimnetBow2x2SparseAsync(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '1'
}
self.check_with_place(
"dist_simnet_bow.py",
delta=100,
check_error_log=True,
need_envs=need_envs,
log_name=flag_name)
# FIXME(tangwei): Learningrate variable is not created on pserver.
class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '1'
}
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=True,
need_envs=need_envs,
log_name=flag_name)
class TestDistSimnetBow2x2LookupTableAsync(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '1'
}
self.check_with_place(
"dist_simnet_bow.py",
delta=100,
check_error_log=True,
need_envs=need_envs,
log_name=flag_name)
class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_simnet_bow(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '1',
'IS_SELF_CONTAINED_LR': '0'
}
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=True,
need_envs=need_envs,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()