local_add_cudnn_lstm
peizhilin 7 years ago
parent 6d0d5a76eb
commit d1a1fafc4c

@ -29,7 +29,7 @@ namespace platform {
void SetNumThreads(int num_threads) { void SetNumThreads(int num_threads) {
#ifdef PADDLE_USE_OPENBLAS #ifdef PADDLE_USE_OPENBLAS
// windows has no support for openblas multi-thread // windows has no support for openblas multi-thread
#ifdef _WIN32 #ifdef _WIN32
if (num_threads > 1) { if (num_threads > 1) {
num_threads = 1; num_threads = 1;

@ -29,7 +29,6 @@ if os.name != 'nt':
ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy BuildStrategy = core.ParallelExecutor.BuildStrategy
class ParallelExecutor(object): class ParallelExecutor(object):
""" """
ParallelExecutor is designed for data parallelism, which focuses on distributing ParallelExecutor is designed for data parallelism, which focuses on distributing
@ -161,7 +160,8 @@ if os.name != 'nt':
for p in main.global_block().iter_parameters() for p in main.global_block().iter_parameters()
if not p.stop_gradient if not p.stop_gradient
]), ]),
set(cpt.to_text(var) for var in self.persistable_vars), main.desc, set(cpt.to_text(var)
for var in self.persistable_vars), main.desc,
cpt.to_text(loss_name) cpt.to_text(loss_name)
if loss_name else six.u(''), scope, local_scopes, exec_strategy, if loss_name else six.u(''), scope, local_scopes, exec_strategy,
build_strategy, num_trainers, trainer_id) build_strategy, num_trainers, trainer_id)

Loading…
Cancel
Save