# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils


class Shard(object):
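    """Maps every trainable parameter (and its optimizer-state variables)
    to the worker that owns it, so each worker keeps only its shard of the
    optimizer states.
    """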
    def __init__(self):
        self.global_params = set([])
        self.worker_idx = -1
        self.worker_num = -1
        self.global_param2device = {}
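
    # Called once per worker with the full (param, grad) list and this
    # worker's rank; records the global param -> device partition.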
    def setup(self, params_grads, worker_idx, worker_num):
        # param names of all devices
        self.global_params = set([x[0].name for x in params_grads])
        self.worker_idx = worker_idx
        self.worker_num = worker_num
        # global_param2device maps param name (str) -> device_id (int) and
        # contains both fp32 params and fp16 params
        self.global_param2device = self._split_params(params_grads, worker_idx,
                                                      worker_num)
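
    # The predicates below classify a variable name relative to this shard:
    # has_param   -- a parameter assigned to this worker
    # has_opt_var -- an optimizer variable whose owning device is this worker
    # has_var     -- visible here: unsharded (-1) or owned by this worker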
    def has_param(self, var_name):
        return var_name in self.global_param2device and \
            self._var_device_id(var_name) == self.worker_idx

    def has_opt_var(self, var_name):
        return self._var_device_id(var_name) == self.worker_idx

    def has_var(self, var_name):
        return self._var_device_id(var_name) == -1 or \
            self._var_device_id(var_name) == self.worker_idx
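
    # Greedy, size-balanced split: walk the params in order, advancing to
    # the next device once the accumulated size passes that device's even
    # share of the total, so each worker gets a contiguous ~1/worker_num slice.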
    def _split_params(self, params_grads, worker_idx, worker_num):
        param2device = {}
        total_param_mem = 0.0
        param2mem = []
        for param in [x[0] for x in params_grads]:
            mem = get_var_size(param)
            total_param_mem += mem
            param2mem.append((param.name, mem))
        device2params = {x: [] for x in range(worker_num)}
        device_idx = 0
        mem_accu = 0.0
        for param_name, mem in param2mem:
            if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num:
                device_idx += 1
            device2params[device_idx].append(param_name)
            param2device[param_name] = device_idx
            mem_accu += mem
        return param2device
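
    # Optimizer-state variables are named <param><suffix> (e.g. Adam moments,
    # Momentum velocity); stripping a known suffix recovers the owning param.
    # Returns -1 for variables that are not sharded at all.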
    def _var_device_id(self, var_name):
        if var_name in self.global_param2device:
            return self.global_param2device[var_name]
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_param2device:
                return self.global_param2device[base_name]
        return -1
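
    # Two passes over the block: first count how often each param feeds a
    # non-optimizer op; then, for each fp16 cast of a param, broadcast the
    # fp16 copy instead and record its device. Params still consumed in
    # fp32 form keep their own broadcast.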
    def find_broadcast_params(self, block):
        broadcast_vars = set([])
        fp16_params = set([])
        fp16_to_fp32 = {}

        param_usage = {x: 0 for x in self.global_params}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.global_params:
                    param_usage[input_name] += 1

        for op in block.ops:
            if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            fp16_params.add(output_name)
            fp16_to_fp32[output_name] = input_name
            param_usage[input_name] -= 1
            self.global_param2device[output_name] = self.global_param2device[
                input_name]

        for param, usage in param_usage.items():
            if usage > 0:
                broadcast_vars.add(param)
        return broadcast_vars

    def device(self, var_name):
        return self._var_device_id(var_name)

    def is_param(self, var_name):
        return var_name in self.global_params
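
    # A var is an "opti var" if it is a param itself or one of a param's
    # optimizer-state vars, identified by the same suffix list as above.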
    def is_opti_var(self, var_name):
        if var_name in self.global_params:
            return True
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_params:
                return True
        return False
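
    # Gradient variables are named "<param>@GRAD"; keep only gradients whose
    # parameter belongs to this worker's shard.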
    def filter_grads(self, grads):
        grads_in_shard = []
        for grad in grads:
            param = grad.split("@")[0]
            if self.has_param(param):
                grads_in_shard.append(grad)
        return grads_in_shard
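

# A minimal usage sketch (hypothetical names; `params_grads` would be the
# list of (param, grad) pairs produced by an optimizer's backward pass):
#
#     shard = Shard()
#     shard.setup(params_grads, worker_idx=0, worker_num=4)
#     shard.has_param("fc_0.w_0")             # owned by worker 0?
#     shard.device("fc_0.w_0_moment1_0")      # device holding the Adam moment
#     shard.filter_grads(["fc_0.w_0@GRAD", "fc_1.w_0@GRAD"])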


class ProgramSegment(object):
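    """Bookkeeping for one sub-program (segment) of the sharded program:
    the op index range it spans and the broadcast / allreduce / cast /
    fill_constant variables that belong to it.
    """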
    def __init__(self, block):
        self._block = block
        self._allreduce_vars = []
        # sub program start idx
        self._start_idx = -1
        # sub program end idx
        self._end_idx = -1
        # param name to broadcast name
        self._param2broadcast = {}
        self._broadcast_vars = []
        # cast op pairs, fp16 name (str) -> fp32 name (str)
        self._cast_ops = {}
        # fill constant vars
        self._fill_constant_vars = []
        # parameter mems
        self._param_mem = 0.0