@@ -14,7 +14,7 @@
 import paddle
 from paddle.distributed.fleet.proto import distributed_strategy_pb2
-from paddle.fluid.framework import Variable
+from paddle.fluid.framework import Variable, set_flags, core
 import google.protobuf.text_format
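
The two names added to this import, set_flags and core, are what the new _enable_env method in the second hunk relies on. As a rough sketch, not part of the patch and assuming a standard Paddle build, both touch Paddle's process-global FLAGS at runtime:

    from paddle.fluid.framework import set_flags, core

    # set_flags accepts a dict of FLAGS values.
    set_flags({"FLAGS_cudnn_exhaustive_search": True})

    # core.globals() exposes the same flag registry as a mapping;
    # _enable_env below uses is_public() to skip flags that are not exposed.
    if core.globals().is_public("FLAGS_conv_workspace_size_limit"):
        core.globals()["FLAGS_conv_workspace_size_limit"] = 4000
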
@@ -810,6 +810,68 @@ class DistributedStrategy(object):
         else:
             print("WARNING: auto should have value of bool type")

+    @property
+    def cudnn_exhaustive_search(self):
+        return self.strategy.cudnn_exhaustive_search
+
+    @cudnn_exhaustive_search.setter
+    def cudnn_exhaustive_search(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.cudnn_exhaustive_search = flag
+        else:
+            print(
+                "WARNING: cudnn_exhaustive_search should have value of bool type"
+            )
+
+    @property
+    def conv_workspace_size_limit(self):
+        return self.strategy.conv_workspace_size_limit
+
+    @conv_workspace_size_limit.setter
+    def conv_workspace_size_limit(self, value):
+        if isinstance(value, int):
+            self.strategy.conv_workspace_size_limit = value
+        else:
+            print(
+                "WARNING: conv_workspace_size_limit should have value of int type"
+            )
+
+    @property
+    def cudnn_batchnorm_spatial_persistent(self):
+        return self.strategy.cudnn_batchnorm_spatial_persistent
+
+    @cudnn_batchnorm_spatial_persistent.setter
+    def cudnn_batchnorm_spatial_persistent(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.cudnn_batchnorm_spatial_persistent = flag
+        else:
+            print(
+                "WARNING: cudnn_batchnorm_spatial_persistent should have value of bool type"
+            )
+
+    def _enable_env(self):
+        strategy = self.strategy
+        keys = [
+            "FLAGS_cudnn_batchnorm_spatial_persistent",
+            "FLAGS_conv_workspace_size_limit",
+            "FLAGS_cudnn_exhaustive_search",
+            "FLAGS_sync_nccl_allreduce",
+            "FLAGS_fuse_parameter_memory_size",
+            "FLAGS_fuse_parameter_groups_size",
+        ]
+        values = [
+            bool(strategy.cudnn_batchnorm_spatial_persistent),
+            int(strategy.conv_workspace_size_limit),
+            bool(strategy.cudnn_exhaustive_search),
+            bool(strategy.sync_nccl_allreduce),
+            int(strategy.fuse_grad_size_in_MB),
+            int(strategy.fuse_grad_size_in_TFLOPS),
+        ]
+
+        for i, key in enumerate(keys):
+            if core.globals().is_public(key):
+                core.globals()[key] = values[i]
+
     def __repr__(self):
         fields = self.strategy.DESCRIPTOR.fields
         for f in fields:
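
A minimal usage sketch of the options this hunk adds (not part of the patch; it assumes the class above is exposed as paddle.distributed.fleet.DistributedStrategy, the usual fleet entry point):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.cudnn_exhaustive_search = True             # bool; anything else only prints a warning
    strategy.conv_workspace_size_limit = 4096           # int
    strategy.cudnn_batchnorm_spatial_persistent = True  # bool
    # _enable_env() (added above) later copies these values into the
    # corresponding FLAGS_* entries through core.globals().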