test=develop, optimize distributedstrategy (#22677)

* test=develop, optimize distributedstrategy
revert-22710-feature/integrated_ps_api
123malin 5 years ago committed by GitHub
parent 5ee29c67b8
commit 0f9d40816e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -52,6 +52,15 @@ class TestStrategyFactor(unittest.TestCase):
self.assertRaises(Exception, strategy.set_program_config,
program_config_illegal)
trainer_runtime_config = strategy.get_trainer_runtime_config()
trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = '50'
runtime_configs = trainer_runtime_config.get_communicator_flags()
self.assertIn('communicator_send_queue_size', runtime_configs)
self.assertNotIn('communicator_independent_recv_thread',
runtime_configs)
self.assertEqual(runtime_configs['communicator_send_queue_size'], '2')
def test_geo_strategy(self):
strategy = StrategyFactory.create_geo_strategy(5)
self.assertEqual(strategy._program_config.sync_mode, False)
@ -82,6 +91,14 @@ class TestStrategyFactor(unittest.TestCase):
self.assertRaises(Exception, strategy.set_build_strategy,
build_strategy_illegal)
os.environ["CPU_NUM"] = '100'
trainer_runtime_config = strategy.get_trainer_runtime_config()
runtime_configs = trainer_runtime_config.get_communicator_flags()
self.assertIn('communicator_thread_pool_size', runtime_configs)
self.assertIn('communicator_send_wait_times', runtime_configs)
self.assertNotIn('communicator_independent_recv_thread',
runtime_configs)
def test_async_strategy(self):
os.environ["CPU_NUM"] = '100'
@ -164,6 +181,16 @@ class TestStrategyFactor(unittest.TestCase):
self.assertRaises(Exception, strategy.set_server_runtime_config,
server_runtime_config_illegal)
os.environ["CPU_NUM"] = '100'
trainer_runtime_config = strategy.get_trainer_runtime_config()
trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = '50'
runtime_configs = trainer_runtime_config.get_communicator_flags()
self.assertIn('communicator_send_queue_size', runtime_configs)
self.assertNotIn('communicator_independent_recv_thread',
runtime_configs)
self.assertEqual(runtime_configs['communicator_send_queue_size'], '100')
class TestCreateDefaultStrategy(unittest.TestCase):
def test_default_strategy(self):

Loading…
Cancel
Save