@@ -201,7 +201,7 @@ class DistributedStrategy(object):
                     f.name).extend(getattr(strategy, f.name))
 
     @property
-    def async_update(self):
+    def a_sync(self):
         """
         Indicating whether we are using asynchronous stochastic gradient descent updates
         for training. This property is valid when we are using parameter server training,
@@ -216,29 +216,29 @@ class DistributedStrategy(object):
             fleet.init(role_maker)
 
             strategy = fleet.DistributedStrategy()
-            strategy.async_update = True  # by default this is True
+            strategy.a_sync = True  # by default this is True
 
             # code block for defining loss and local optimizer
             # sgd = fleet.distributed_optimizer(optimizer, strategy)
         """
-        return self.strategy.async
+        return self.strategy.a_sync
 
-    @async_update.setter
-    def async_update(self, flag):
+    @a_sync.setter
+    def a_sync(self, flag):
         if isinstance(flag, bool):
-            self.strategy.async = flag
+            self.strategy.a_sync = flag
         else:
-            print("WARNING: async_update should have value of bool type")
+            print("WARNING: a_sync should have value of bool type")
 
     @property
-    def async_update_configs(self):
+    def a_sync_configs(self):
         """
-        Set async update configurations. In general, asynchronous parameter server
+        Set a_sync update configurations. In general, asynchronous parameter server
         training has several configurable settings that can be configured through
         a dict.
 
         **Notes**:
-            **Detailed arguments for async_update_configs**
+            **Detailed arguments for a_sync_configs**
             **k_step**: number of local optimization updates before communication
             **max_merge_var_num**: maximum number of merged gradients before communication
             **send_queue_size**: a buffer size of worker communication
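The docstring above leaves the optimizer block as a comment. A minimal end-to-end sketch of how the renamed a_sync property would plausibly be used, assuming the Paddle 2.x fleet API (fleet.init, fleet.distributed_optimizer, paddle.optimizer.SGD) and the usual parameter-server launch environment; the loss is a placeholder for the user's own network:

    import paddle
    import paddle.distributed.fleet as fleet

    # Parameter-server mode; assumes the environment prepared by fleet's launcher.
    fleet.init(is_collective=False)

    strategy = fleet.DistributedStrategy()
    strategy.a_sync = True  # asynchronous SGD updates (the default for PS training)

    optimizer = paddle.optimizer.SGD(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer, strategy)
    # optimizer.minimize(loss)  # loss comes from the user-defined network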
@@ -255,19 +255,20 @@ class DistributedStrategy(object):
             fleet.init(role_maker)
 
             strategy = fleet.DistributedStrategy()
-            strategy.async_update = True  # by default this is True
+            strategy.a_sync = True  # by default this is True
             configs = {"k_step": 10000, "send_queue_size": 32}
-            strategy.async_update_configs = configs
+            strategy.a_sync_configs = configs
 
             # code block for defining loss and local optimizer
             # sgd = fleet.distributed_optimizer(optimizer, strategy)
         """
-        return get_msg_dict(self.strategy.async_configs)
+        return get_msg_dict(self.strategy.a_sync_configs)
 
-    @async_update_configs.setter
-    def async_update_configs(self, configs):
-        check_configs_key(self.strategy.async_configs, configs, "async_configs")
-        assign_configs_value(self.strategy.async_configs, configs)
+    @a_sync_configs.setter
+    def a_sync_configs(self, configs):
+        check_configs_key(self.strategy.a_sync_configs, configs,
+                          "a_sync_configs")
+        assign_configs_value(self.strategy.a_sync_configs, configs)
 
     @property
     def amp(self):
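The setter above delegates validation and assignment to check_configs_key and assign_configs_value, whose implementations live elsewhere in this module. A rough sketch of the kind of protobuf-descriptor checks their names and call sites suggest (scalar fields only; the real helpers are the authoritative versions):

    def check_configs_key(msg, configs, field_name):
        # Every user-supplied key must be a field of the protobuf message.
        valid_keys = msg.DESCRIPTOR.fields_by_name.keys()
        for key in configs:
            assert key in valid_keys, "key {} not found in {}".format(key, field_name)

    def assign_configs_value(msg, configs):
        # Copy matching values onto the protobuf message, field by field.
        for f in msg.DESCRIPTOR.fields:
            if f.name in configs:
                setattr(msg, f.name, configs[f.name])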
@@ -584,4 +585,7 @@ class DistributedStrategy(object):
             print("WARNING: auto should have value of bool type")
 
     def __repr__(self):
+        fields = self.strategy.DESCRIPTOR.fields
+        for f in fields:
+            print("{}: {}".format(f.name, f.default_value))
         return str(self.strategy)
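With __repr__ in place, a strategy can be inspected directly. A small hypothetical usage example; the exact protobuf text depends on which options have been set, and the loop added above also prints each field's default value as a side effect:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.a_sync = True

    # With no __str__ defined, print() falls back to __repr__, which returns
    # str(self.strategy), i.e. the text form of the underlying protobuf message.
    print(strategy)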