# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import os
import six
from .. import compat as cpt

from . import core

# Aliases for the strategy classes exposed by the C++ ParallelExecutor, so
# callers can configure compilation without importing core directly.
ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy

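# A minimal strategy-configuration sketch (illustrative only; the attributes
# shown -- num_threads, num_trainers -- are the ones this module itself reads
# below, but the concrete values are assumptions, not recommendations):
#
#     exec_strategy = ExecutionStrategy()
#     exec_strategy.num_threads = 0   # 0 means "let the compiler pick" below
#     build_strategy = BuildStrategy()
#     build_strategy.num_trainers = 1
#
# Both objects are accepted by _ProgramCompiler._with_data_parallel below.
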
def _place_obj(place):
    """Wrap a concrete place (CPUPlace or CUDAPlace) in a generic core.Place."""
    p = core.Place()
    p.set_place(place)
    return p

class _ProgramCompiler(object):
    """Compiles a Program for execution, optionally with data parallelism."""

    def __init__(self, program):
        self._program = program
        self._compiled = False
        self._is_data_parallel = False

    def _with_data_parallel(self,
                            loss_name=None,
                            build_strategy=None,
                            exec_strategy=None):
        """Mark this program for data-parallel compilation. Strategies left as
        None are filled with defaults at compile time."""
        assert not self._is_data_parallel, "Already compiled with parallel."
        self._is_data_parallel = True
        self._build_strategy = build_strategy
        self._exec_strategy = exec_strategy
        self._loss_name = loss_name
        return self

    def _compile_data_parallel(self):
        self._places = []
        self._local_scopes = []

        if self._exec_strategy is None:
            self._exec_strategy = ExecutionStrategy()
        if self._build_strategy is None:
            self._build_strategy = BuildStrategy()

        self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace)
        if self._exec_strategy.use_cuda:
            # Honor FLAGS_selected_gpus if set; otherwise use every visible GPU.
            gpus_env = os.getenv("FLAGS_selected_gpus")
            if gpus_env:
                gpus = [int(s) for s in gpus_env.split(",")]
            else:
                gpus = [
                    i for i in six.moves.range(core.get_cuda_device_count())
                ]
            self._places = [core.CUDAPlace(i) for i in gpus]
        else:
            # CPU_NUM controls how many parallel CPU places are created;
            # default to the machine's core count.
            cpu_num = int(
                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
            self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
        assert self._places, "no place for execution"

        if self._exec_strategy.num_threads == 0:
            if self._exec_strategy.use_cuda:
                # Experiments on se-resnext show that too many threads hurt
                # performance. Worth tuning for other models in the future.
                self._exec_strategy.num_threads = len(self._places) * 4
            else:
                cpu_num = int(
                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
                self._exec_strategy.num_threads = cpu_num * 2

        trainers_endpoints = self._program._trainers_endpoints
        if self._build_strategy.num_trainers > 1 and trainers_endpoints:
            assert self._build_strategy.num_trainers == len(
                trainers_endpoints
            ), "num_trainers must equal len(trainers_endpoints)"
            self._build_strategy.trainers_endpoints = trainers_endpoints

        # Persistable (non-RAW) variables live in the global scope and are
        # shared across the per-place local scopes.
        self._persistable_vars = set([
            cpt.to_text(v.name)
            for v in [
                var for var in self._program.list_vars()
                if var.persistable and var.type != core.VarDesc.VarType.RAW
            ]
        ])

        places = list(map(_place_obj, self._places))
        return core.ParallelExecutor(
            places, self._persistable_vars, self._program.desc,
            cpt.to_text(self._loss_name)
            if self._loss_name else six.u(''), self._scope, self._local_scopes,
            self._exec_strategy, self._build_strategy)

    def _compile(self, scope, place):
        # Compilation is idempotent: repeated calls return the cached result.
        if self._compiled:
            return self
        self._compiled = True

        self._scope = scope
        self._place = place

        if self._is_data_parallel:
            self._executor = self._compile_data_parallel()
        else:
            p = _place_obj(self._place)
            self._executor = core.Executor(p)
        return self
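
# A minimal usage sketch (assumption: the names `program`, `scope`, and `loss`
# are placeholders, and this private API is normally driven by fluid's
# Executor rather than called directly):
#
#     compiler = _ProgramCompiler(program)
#     compiler._with_data_parallel(loss_name=loss.name)
#     compiler._compile(scope, core.CUDAPlace(0))
#     # compiler._executor is now a core.ParallelExecutor that runs on all
#     # visible GPUs, or on those listed in FLAGS_selected_gpus.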