|
|
|
@ -20,9 +20,11 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo
|
|
|
|
|
from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedRoleMaker
|
|
|
|
|
from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedCollectiveRoleMaker
|
|
|
|
|
from paddle.fluid.incubate.fleet.base.role_maker import Role
|
|
|
|
|
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
|
|
|
|
|
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
|
|
|
|
|
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer
|
|
|
|
|
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer
|
|
|
|
|
from dist_simnet_bow import train_network
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributeTranspilerConfigTest(unittest.TestCase):
|
|
|
|
@ -97,6 +99,30 @@ class FleetTest(unittest.TestCase):
|
|
|
|
|
main_program=compiled_prog)
|
|
|
|
|
self.assertRaises(Exception, fleet._transpile, "config")
|
|
|
|
|
|
|
|
|
|
def set_program(self, avg_cost, strategy):
    """Minimize ``avg_cost`` with an SGD optimizer wrapped by fleet.

    Creates ``fluid.optimizer.SGD(0.1)``, hands it to
    ``fleet.distributed_optimizer`` together with ``strategy``, and calls
    ``minimize(avg_cost)`` on the wrapped optimizer. Nothing is returned
    and no exception is caught here, so any error raised by those calls
    reaches the caller.
    """
    sgd = fluid.optimizer.SGD(0.1)
    distributed_sgd = fleet.distributed_optimizer(sgd, strategy)
    distributed_sgd.minimize(avg_cost)
|
|
|
|
|
|
|
|
|
|
def test_init_role(self):
    """Expect ``set_program`` to raise when ``fleet.init`` was not called.

    A server role maker is constructed but deliberately never passed to
    ``fleet.init``; the test then builds a geo-SGD transpiler config and a
    training network and asserts that ``self.set_program`` raises an
    ``Exception``.
    """
    endpoints = ["127.0.0.1:36011", "127.0.0.1:36012"]
    server_role = role_maker.UserDefinedRoleMaker(
        current_id=0,
        role=role_maker.Role.SERVER,
        worker_num=2,
        server_endpoints=endpoints)
    # The role is intentionally left unregistered so that the optimizer
    # is exercised without init(role):
    # fleet.init(role)

    geo_config = DistributeTranspilerConfig()
    geo_config.sync_mode = False
    geo_config.geo_sgd_mode = True
    geo_config.geo_sgd_need_push_nums = 5

    # Positional arguments are: batch_size, is_distribute, is_sparse.
    loss, _, _ = train_network(128, False, True)

    with self.assertRaises(Exception):
        self.set_program(loss, geo_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TranspilerOptimizerTest(unittest.TestCase):
|
|
|
|
|
def testInvalidInputs(self):
|
|
|
|
@ -124,7 +150,7 @@ class UserDefinedRoleMakerTest(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
def testRoleMaker(self):
|
|
|
|
|
self.createRoleMaker()
|
|
|
|
|
## test all invalid server_endpoints
|
|
|
|
|
# test all invalid server_endpoints
|
|
|
|
|
self.assertRaises(
|
|
|
|
|
Exception, self.createRoleMaker,
|
|
|
|
|
server_endpoints=None) # server_endpoints must be as list
|
|
|
|
@ -140,7 +166,7 @@ class UserDefinedRoleMakerTest(unittest.TestCase):
|
|
|
|
|
self.createRoleMaker,
|
|
|
|
|
server_endpoints=["127.0.0.1:8080", "127.0.0.1:8080"]
|
|
|
|
|
) # element in server_endpoints can't be duplicate
|
|
|
|
|
## test all invalid current_id
|
|
|
|
|
# test all invalid current_id
|
|
|
|
|
self.assertRaises(
|
|
|
|
|
Exception, self.createRoleMaker,
|
|
|
|
|
current_id="0") # current_id must be as int
|
|
|
|
@ -154,14 +180,14 @@ class UserDefinedRoleMakerTest(unittest.TestCase):
|
|
|
|
|
role=Role.SERVER,
|
|
|
|
|
server_endpoints=["127.0.0.1:8080"]
|
|
|
|
|
) # if role is server, current_id must be less than len(server_endpoints)
|
|
|
|
|
## test all invalid worker_num
|
|
|
|
|
# test all invalid worker_num
|
|
|
|
|
self.assertRaises(
|
|
|
|
|
Exception, self.createRoleMaker,
|
|
|
|
|
worker_num="1") # worker_num must be as int
|
|
|
|
|
self.assertRaises(
|
|
|
|
|
Exception, self.createRoleMaker,
|
|
|
|
|
worker_num=0) # worker_num must be greater than 0
|
|
|
|
|
## test all invalid role
|
|
|
|
|
# test all invalid role
|
|
|
|
|
self.assertRaises(
|
|
|
|
|
Exception, self.createRoleMaker,
|
|
|
|
|
role=3) # role must be as Role(Role.WORKER=1, Role.SERVER=2)
|
|
|
|
@ -174,7 +200,7 @@ class UserDefinedCollectiveRoleMakerTest(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
def testRoleMaker(self):
|
|
|
|
|
self.createRoleMaker()
|
|
|
|
|
## test all invalid worker_endpoints
|
|
|
|
|
# test all invalid worker_endpoints
|
|
|
|
|
self.assertRaises(
|
|
|
|
|
Exception, self.createRoleMaker,
|
|
|
|
|
worker_endpoints=None) # worker_endpoints must be as list
|
|
|
|
@ -190,7 +216,7 @@ class UserDefinedCollectiveRoleMakerTest(unittest.TestCase):
|
|
|
|
|
self.createRoleMaker,
|
|
|
|
|
worker_endpoints=["127.0.0.1:8080", "127.0.0.1:8080"]
|
|
|
|
|
) # element in worker_endpoints can't be duplicate
|
|
|
|
|
## test all invalid current_id
|
|
|
|
|
# test all invalid current_id
|
|
|
|
|
self.assertRaises(
|
|
|
|
|
Exception, self.createRoleMaker,
|
|
|
|
|
current_id="0") # current_id must be as int
|
|
|
|
|