parent
38da103034
commit
58f7695ab2
@ -0,0 +1,47 @@
|
|||||||
|
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/pybind/communicator_py.h"
|
||||||
|
|
||||||
|
#include <Python.h>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/program_desc.h"
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/distributed/communicator.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
using paddle::framework::ProgramDesc;
|
||||||
|
using paddle::operators::distributed::Communicator;
|
||||||
|
using paddle::framework::Scope;
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace pybind {
|
||||||
|
|
||||||
|
void BindCommunicator(py::module* m) {
|
||||||
|
// Communicator is already used by nccl, change to DistCommunicator
|
||||||
|
py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
|
||||||
|
"DistCommunicator")
|
||||||
|
.def(py::init([](const ProgramDesc& program, Scope* param_scope) {
|
||||||
|
Communicator::Init(program, param_scope);
|
||||||
|
return Communicator::GetInstantcePtr();
|
||||||
|
}))
|
||||||
|
.def("stop", &Communicator::Stop)
|
||||||
|
.def("start", &Communicator::Start);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace pybind
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,27 @@
|
|||||||
|
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace pybind {
|
||||||
|
|
||||||
|
void BindCommunicator(pybind11::module* m);
|
||||||
|
|
||||||
|
} // namespace pybind
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,88 @@
|
|||||||
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .executor import global_scope
|
||||||
|
from . import core
|
||||||
|
from .framework import Program
|
||||||
|
|
||||||
|
__all__ = ['Communicator']
|
||||||
|
|
||||||
|
|
||||||
|
class Communicator(object):
|
||||||
|
def __init__(self, program):
|
||||||
|
"""
|
||||||
|
Communicator is used for async distribute training in distribute_transpiler mode.
|
||||||
|
It's a wrapper of a cpp class Communicator and should be used inside fleet API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
program(Program): the trainers program after transpile of distribute_transpiler.
|
||||||
|
It's used by communicator to extract the information to do communication.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
|
||||||
|
prog = fluid.Program()
|
||||||
|
comm = fluid.communicator.Communicator(prog)
|
||||||
|
comm.start()
|
||||||
|
comm.stop()
|
||||||
|
"""
|
||||||
|
# set all recv op to not_run mode
|
||||||
|
assert isinstance(program, Program)
|
||||||
|
for op in program.block(0).ops:
|
||||||
|
if op.type == "recv":
|
||||||
|
op._set_attr('do_not_run', True)
|
||||||
|
self.communicator_ = core.DistCommunicator(program.desc, global_scope())
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""
|
||||||
|
Start communicator. Should call before training process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
|
||||||
|
prog = fluid.Program()
|
||||||
|
comm = fluid.communicator.Communicator(prog)
|
||||||
|
comm.start()
|
||||||
|
comm.stop()
|
||||||
|
"""
|
||||||
|
self.communicator_.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
Stop communicator. Should call after training process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
|
||||||
|
prog = fluid.Program()
|
||||||
|
comm = fluid.communicator.Communicator(prog)
|
||||||
|
comm.start()
|
||||||
|
comm.stop()
|
||||||
|
"""
|
||||||
|
self.communicator_.stop()
|
@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# start pserver0
|
||||||
|
python fleet_deep_ctr.py \
|
||||||
|
--role pserver \
|
||||||
|
--endpoints 127.0.0.1:7000,127.0.0.1:7001 \
|
||||||
|
--current_endpoint 127.0.0.1:7000 \
|
||||||
|
--trainers 2 \
|
||||||
|
> pserver0.log 2>&1 &
|
||||||
|
|
||||||
|
# start pserver1
|
||||||
|
python fleet_deep_ctr.py \
|
||||||
|
--role pserver \
|
||||||
|
--endpoints 127.0.0.1:7000,127.0.0.1:7001 \
|
||||||
|
--current_endpoint 127.0.0.1:7001 \
|
||||||
|
--trainers 2 \
|
||||||
|
> pserver1.log 2>&1 &
|
||||||
|
|
||||||
|
# start trainer0
|
||||||
|
python fleet_deep_ctr.py \
|
||||||
|
--role trainer \
|
||||||
|
--endpoints 127.0.0.1:7000,127.0.0.1:7001 \
|
||||||
|
--trainers 2 \
|
||||||
|
--trainer_id 0 \
|
||||||
|
> trainer0.log 2>&1 &
|
||||||
|
|
||||||
|
# start trainer1
|
||||||
|
python fleet_deep_ctr.py \
|
||||||
|
--role trainer \
|
||||||
|
--endpoints 127.0.0.1:7000,127.0.0.1:7001 \
|
||||||
|
--trainers 2 \
|
||||||
|
--trainer_id 1 \
|
||||||
|
> trainer1.log 2>&1 &
|
@ -0,0 +1,100 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tarfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.fluid.incubate.data_generator as data_generator
|
||||||
|
|
||||||
|
logging.basicConfig()
|
||||||
|
logger = logging.getLogger("paddle")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
DATA_URL = "http://paddle-ctr-data.bj.bcebos.com/avazu_ctr_data.tgz"
|
||||||
|
DATA_MD5 = "c11df99fbd14e53cd4bfa6567344b26e"
|
||||||
|
"""
|
||||||
|
avazu_ctr_data/train.txt
|
||||||
|
avazu_ctr_data/infer.txt
|
||||||
|
avazu_ctr_data/test.txt
|
||||||
|
avazu_ctr_data/data.meta.txt
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def download_file():
|
||||||
|
file_name = "avazu_ctr_data"
|
||||||
|
path = paddle.dataset.common.download(DATA_URL, file_name, DATA_MD5)
|
||||||
|
|
||||||
|
dir_name = os.path.dirname(path)
|
||||||
|
text_file_dir_name = os.path.join(dir_name, file_name)
|
||||||
|
|
||||||
|
if not os.path.exists(text_file_dir_name):
|
||||||
|
tar = tarfile.open(path, "r:gz")
|
||||||
|
tar.extractall(dir_name)
|
||||||
|
return text_file_dir_name
|
||||||
|
|
||||||
|
|
||||||
|
def load_dnn_input_record(sent):
|
||||||
|
return list(map(int, sent.split()))
|
||||||
|
|
||||||
|
|
||||||
|
def load_lr_input_record(sent):
|
||||||
|
res = []
|
||||||
|
for _ in [x.split(':') for x in sent.split()]:
|
||||||
|
res.append(int(_[0]))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
|
||||||
|
def generate_sample(self, line):
|
||||||
|
def iter():
|
||||||
|
fs = line.strip().split('\t')
|
||||||
|
dnn_input = load_dnn_input_record(fs[0])
|
||||||
|
lr_input = load_lr_input_record(fs[1])
|
||||||
|
click = [int(fs[2])]
|
||||||
|
yield ("dnn_data", dnn_input), \
|
||||||
|
("lr_data", lr_input), \
|
||||||
|
("click", click)
|
||||||
|
|
||||||
|
return iter
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data():
|
||||||
|
"""
|
||||||
|
load data meta info from path, return (dnn_input_dim, lr_input_dim)
|
||||||
|
"""
|
||||||
|
file_dir_name = download_file()
|
||||||
|
meta_file_path = os.path.join(file_dir_name, 'data.meta.txt')
|
||||||
|
train_file_path = os.path.join(file_dir_name, 'train.txt')
|
||||||
|
with open(meta_file_path, "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
err_info = "wrong meta format"
|
||||||
|
assert len(lines) == 2, err_info
|
||||||
|
assert 'dnn_input_dim:' in lines[0] and 'lr_input_dim:' in lines[
|
||||||
|
1], err_info
|
||||||
|
res = map(int, [_.split(':')[1] for _ in lines])
|
||||||
|
res = list(res)
|
||||||
|
dnn_input_dim = res[0]
|
||||||
|
lr_input_dim = res[1]
|
||||||
|
logger.info('dnn input dim: %d' % dnn_input_dim)
|
||||||
|
logger.info('lr input dim: %d' % lr_input_dim)
|
||||||
|
return dnn_input_dim, lr_input_dim, train_file_path
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pairwise_reader = DatasetCtrReader()
|
||||||
|
pairwise_reader.run_from_stdin()
|
@ -0,0 +1,204 @@
|
|||||||
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
|
||||||
|
from paddle.fluid.incubate.fleet.parameter_server.distributed_transpiler import fleet
|
||||||
|
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
|
||||||
|
|
||||||
|
import ctr_dataset_reader
|
||||||
|
|
||||||
|
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger("fluid")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="PaddlePaddle Fleet ctr")
|
||||||
|
|
||||||
|
# the following arguments is used for distributed train, if is_local == false, then you should set them
|
||||||
|
parser.add_argument(
|
||||||
|
'--role',
|
||||||
|
type=str,
|
||||||
|
default='pserver', # trainer or pserver
|
||||||
|
help='The path for model to store (default: models)')
|
||||||
|
parser.add_argument(
|
||||||
|
'--endpoints',
|
||||||
|
type=str,
|
||||||
|
default='127.0.0.1:6000',
|
||||||
|
help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001')
|
||||||
|
parser.add_argument(
|
||||||
|
'--current_endpoint',
|
||||||
|
type=str,
|
||||||
|
default='127.0.0.1:6000',
|
||||||
|
help='The path for model to store (default: 127.0.0.1:6000)')
|
||||||
|
parser.add_argument(
|
||||||
|
'--trainer_id',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='The path for model to store (default: models)')
|
||||||
|
parser.add_argument(
|
||||||
|
'--trainers',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='The num of trainers, (default: 1)')
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def model():
|
||||||
|
dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
|
||||||
|
)
|
||||||
|
""" network definition """
|
||||||
|
dnn_data = fluid.layers.data(
|
||||||
|
name="dnn_data",
|
||||||
|
shape=[-1, 1],
|
||||||
|
dtype="int64",
|
||||||
|
lod_level=1,
|
||||||
|
append_batch_size=False)
|
||||||
|
lr_data = fluid.layers.data(
|
||||||
|
name="lr_data",
|
||||||
|
shape=[-1, 1],
|
||||||
|
dtype="int64",
|
||||||
|
lod_level=1,
|
||||||
|
append_batch_size=False)
|
||||||
|
label = fluid.layers.data(
|
||||||
|
name="click",
|
||||||
|
shape=[-1, 1],
|
||||||
|
dtype="int64",
|
||||||
|
lod_level=0,
|
||||||
|
append_batch_size=False)
|
||||||
|
|
||||||
|
datas = [dnn_data, lr_data, label]
|
||||||
|
|
||||||
|
# build dnn model
|
||||||
|
dnn_layer_dims = [128, 64, 32, 1]
|
||||||
|
dnn_embedding = fluid.layers.embedding(
|
||||||
|
is_distributed=False,
|
||||||
|
input=dnn_data,
|
||||||
|
size=[dnn_input_dim, dnn_layer_dims[0]],
|
||||||
|
param_attr=fluid.ParamAttr(
|
||||||
|
name="deep_embedding",
|
||||||
|
initializer=fluid.initializer.Constant(value=0.01)),
|
||||||
|
is_sparse=True)
|
||||||
|
dnn_pool = fluid.layers.sequence_pool(input=dnn_embedding, pool_type="sum")
|
||||||
|
dnn_out = dnn_pool
|
||||||
|
for i, dim in enumerate(dnn_layer_dims[1:]):
|
||||||
|
fc = fluid.layers.fc(
|
||||||
|
input=dnn_out,
|
||||||
|
size=dim,
|
||||||
|
act="relu",
|
||||||
|
param_attr=fluid.ParamAttr(
|
||||||
|
initializer=fluid.initializer.Constant(value=0.01)),
|
||||||
|
name='dnn-fc-%d' % i)
|
||||||
|
dnn_out = fc
|
||||||
|
|
||||||
|
# build lr model
|
||||||
|
lr_embbding = fluid.layers.embedding(
|
||||||
|
is_distributed=False,
|
||||||
|
input=lr_data,
|
||||||
|
size=[lr_input_dim, 1],
|
||||||
|
param_attr=fluid.ParamAttr(
|
||||||
|
name="wide_embedding",
|
||||||
|
initializer=fluid.initializer.Constant(value=0.01)),
|
||||||
|
is_sparse=True)
|
||||||
|
lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum")
|
||||||
|
|
||||||
|
merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)
|
||||||
|
|
||||||
|
predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')
|
||||||
|
acc = fluid.layers.accuracy(input=predict, label=label)
|
||||||
|
auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict,
|
||||||
|
label=label)
|
||||||
|
cost = fluid.layers.cross_entropy(input=predict, label=label)
|
||||||
|
avg_cost = fluid.layers.mean(x=cost)
|
||||||
|
|
||||||
|
return datas, avg_cost, predict, train_file_path
|
||||||
|
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
datas, avg_cost, predict, train_file_path = model()
|
||||||
|
|
||||||
|
endpoints = args.endpoints.split(",")
|
||||||
|
if args.role.upper() == "PSERVER":
|
||||||
|
current_id = endpoints.index(args.current_endpoint)
|
||||||
|
else:
|
||||||
|
current_id = 0
|
||||||
|
role = role_maker.UserDefinedRoleMaker(
|
||||||
|
current_id=current_id,
|
||||||
|
role=role_maker.Role.WORKER
|
||||||
|
if args.role.upper() == "TRAINER" else role_maker.Role.SERVER,
|
||||||
|
worker_num=args.trainers,
|
||||||
|
server_endpoints=endpoints)
|
||||||
|
|
||||||
|
exe = fluid.Executor(fluid.CPUPlace())
|
||||||
|
fleet.init(role)
|
||||||
|
|
||||||
|
strategy = DistributeTranspilerConfig()
|
||||||
|
strategy.sync_mode = False
|
||||||
|
|
||||||
|
optimizer = fluid.optimizer.SGD(learning_rate=0.0001)
|
||||||
|
optimizer = fleet.distributed_optimizer(optimizer, strategy)
|
||||||
|
optimizer.minimize(avg_cost)
|
||||||
|
|
||||||
|
if fleet.is_server():
|
||||||
|
logger.info("run pserver")
|
||||||
|
|
||||||
|
fleet.init_server()
|
||||||
|
fleet.run_server()
|
||||||
|
elif fleet.is_worker():
|
||||||
|
logger.info("run trainer")
|
||||||
|
|
||||||
|
fleet.init_worker()
|
||||||
|
exe.run(fleet.startup_program)
|
||||||
|
|
||||||
|
thread_num = 2
|
||||||
|
filelist = []
|
||||||
|
for _ in range(thread_num):
|
||||||
|
filelist.append(train_file_path)
|
||||||
|
|
||||||
|
# config dataset
|
||||||
|
dataset = fluid.DatasetFactory().create_dataset()
|
||||||
|
dataset.set_batch_size(128)
|
||||||
|
dataset.set_use_var(datas)
|
||||||
|
pipe_command = 'python ctr_dataset_reader.py'
|
||||||
|
dataset.set_pipe_command(pipe_command)
|
||||||
|
|
||||||
|
dataset.set_filelist(filelist)
|
||||||
|
dataset.set_thread(thread_num)
|
||||||
|
|
||||||
|
for epoch_id in range(10):
|
||||||
|
logger.info("epoch {} start".format(epoch_id))
|
||||||
|
pass_start = time.time()
|
||||||
|
dataset.set_filelist(filelist)
|
||||||
|
exe.train_from_dataset(
|
||||||
|
program=fleet.main_program,
|
||||||
|
dataset=dataset,
|
||||||
|
fetch_list=[avg_cost],
|
||||||
|
fetch_info=["cost"],
|
||||||
|
print_period=100,
|
||||||
|
debug=False)
|
||||||
|
pass_time = time.time() - pass_start
|
||||||
|
logger.info("epoch {} finished, pass_time {}".format(epoch_id,
|
||||||
|
pass_time))
|
||||||
|
fleet.stop_worker()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
train(args)
|
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
from paddle.fluid.communicator import Communicator
|
||||||
|
|
||||||
|
|
||||||
|
class TestCommunicator(unittest.TestCase):
|
||||||
|
def test_communicator_init_and_start(self):
|
||||||
|
prog = fluid.Program()
|
||||||
|
comm = Communicator(prog)
|
||||||
|
comm.start()
|
||||||
|
comm.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue