|
|
|
@ -52,6 +52,15 @@ class TestStrategyFactor(unittest.TestCase):
|
|
|
|
|
self.assertRaises(Exception, strategy.set_program_config,
|
|
|
|
|
program_config_illegal)
|
|
|
|
|
|
|
|
|
|
trainer_runtime_config = strategy.get_trainer_runtime_config()
|
|
|
|
|
trainer_runtime_config.runtime_configs[
|
|
|
|
|
'communicator_send_queue_size'] = '50'
|
|
|
|
|
runtime_configs = trainer_runtime_config.get_communicator_flags()
|
|
|
|
|
self.assertIn('communicator_send_queue_size', runtime_configs)
|
|
|
|
|
self.assertNotIn('communicator_independent_recv_thread',
|
|
|
|
|
runtime_configs)
|
|
|
|
|
self.assertEqual(runtime_configs['communicator_send_queue_size'], '2')
|
|
|
|
|
|
|
|
|
|
def test_geo_strategy(self):
|
|
|
|
|
strategy = StrategyFactory.create_geo_strategy(5)
|
|
|
|
|
self.assertEqual(strategy._program_config.sync_mode, False)
|
|
|
|
@ -82,6 +91,14 @@ class TestStrategyFactor(unittest.TestCase):
|
|
|
|
|
self.assertRaises(Exception, strategy.set_build_strategy,
|
|
|
|
|
build_strategy_illegal)
|
|
|
|
|
|
|
|
|
|
os.environ["CPU_NUM"] = '100'
|
|
|
|
|
trainer_runtime_config = strategy.get_trainer_runtime_config()
|
|
|
|
|
runtime_configs = trainer_runtime_config.get_communicator_flags()
|
|
|
|
|
self.assertIn('communicator_thread_pool_size', runtime_configs)
|
|
|
|
|
self.assertIn('communicator_send_wait_times', runtime_configs)
|
|
|
|
|
self.assertNotIn('communicator_independent_recv_thread',
|
|
|
|
|
runtime_configs)
|
|
|
|
|
|
|
|
|
|
def test_async_strategy(self):
|
|
|
|
|
os.environ["CPU_NUM"] = '100'
|
|
|
|
|
|
|
|
|
@ -164,6 +181,16 @@ class TestStrategyFactor(unittest.TestCase):
|
|
|
|
|
self.assertRaises(Exception, strategy.set_server_runtime_config,
|
|
|
|
|
server_runtime_config_illegal)
|
|
|
|
|
|
|
|
|
|
os.environ["CPU_NUM"] = '100'
|
|
|
|
|
trainer_runtime_config = strategy.get_trainer_runtime_config()
|
|
|
|
|
trainer_runtime_config.runtime_configs[
|
|
|
|
|
'communicator_send_queue_size'] = '50'
|
|
|
|
|
runtime_configs = trainer_runtime_config.get_communicator_flags()
|
|
|
|
|
self.assertIn('communicator_send_queue_size', runtime_configs)
|
|
|
|
|
self.assertNotIn('communicator_independent_recv_thread',
|
|
|
|
|
runtime_configs)
|
|
|
|
|
self.assertEqual(runtime_configs['communicator_send_queue_size'], '100')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCreateDefaultStrategy(unittest.TestCase):
|
|
|
|
|
def test_default_strategy(self):
|
|
|
|
|