parent 4877bd5944
commit 81244fbfab
@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,154 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from paddle.fluid import core


class FP16Utils(object):
    def __init__(self):
        pass

    @staticmethod
    def is_fp16_cast_op(block, op, params):
        if op.type != "cast":
            return False
        if is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        if input_name not in params:
            return False
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP32 or \
            output_var.dtype != core.VarDesc.VarType.FP16:
            return False
        return True

    @staticmethod
    def is_fp32_cast_op(block, op):
        if op.type != "cast":
            return False
        if not is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP16 or \
            output_var.dtype != core.VarDesc.VarType.FP32:
            return False
        return True

    @staticmethod
    def remove_cast_op(block, params, segment, offset):
        inserted_op_num = 0
        for op_idx in reversed(
                range(offset + segment._start_idx, offset + segment._end_idx)):
            op = block.ops[op_idx]
            if FP16Utils.is_fp16_cast_op(block, op, params):
                block._remove_op(op_idx, sync=False)
                inserted_op_num -= 1
        block._sync_with_cpp()
        return inserted_op_num

    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, nrings):
        # remove cast
        for idx, op in reversed(list(enumerate(block.ops))):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            param_name = output_name.strip("@GRAD")
            if param_name not in shard.global_params:
                raise ValueError("Input 'X' of check_finite_and_unscale must "
                                 "be grads, but {} is not a grad".format(
                                     output_name))
            if output_name in reduced_grads_to_param:
                continue
            if shard.has_param(param_name):
                continue
            block._remove_op(idx, sync=False)
            block._remove_var(output_name, sync=False)

        block._sync_with_cpp()
        update_loss_scaling_op_idx = -1
        inf_var_name = ''
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                for input_name in op.desc.input('X'):
                    param_name = input_name.strip("@GRAD")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must "
                            "be grads, but {} is not a grad".format(input_name))
                    if shard.has_param(param_name):
                        reversed_x.append(input_name)
                op.desc.set_input('X', reversed_x)
                op.desc.set_output('Out', reversed_x)
        if update_loss_scaling_op_idx == -1:
            return
        inf_var = block.var(inf_var_name)
        inf_var_fp32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
        inf_var_sharding = block.create_var(
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_fp32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_fp32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
                            [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='c_allreduce_max',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_fp32},
            attrs={'ring_id': 0,
                   OP_ROLE_KEY: OpRole.Optimize})
        comm_op_num = insert_sync_comm_ops(
            block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 3 + comm_op_num,
            type='cast',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_sharding},
            attrs={
                "in_dtype": inf_var_fp32.dtype,
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()
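For orientation: prune_fp16 above rewrites the AMP bookkeeping so each sharding worker only checks the gradients it owns, then combines the per-worker FoundInfinite flags with the inserted c_allreduce_max before update_loss_scaling runs. The snippet below is only a toy, framework-free illustration of that reduction semantics; plain Python ints stand in for the int32-cast flag and no Paddle objects are involved.

# Toy illustration of the FoundInfinite handling in prune_fp16: each worker
# reports whether its local gradient shard overflowed, and an all-reduce with
# max makes every worker agree to skip the update if any shard overflowed.
def allreduce_max(flags):
    # stand-in for the inserted c_allreduce_max op
    return max(flags)

local_found_inf = [0, 1, 0, 0]           # worker 1 saw inf/nan in its shard
global_found_inf = allreduce_max(local_found_inf)
assert global_found_inf == 1             # identical decision on every worker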
@@ -0,0 +1,90 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole


class GradientClipHelper(object):
    def __init__(self):
        pass

    def _is_gradient_clip_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/gradient_clip")

    def prune_gradient_clip(self, block, shard):
        deperated_vars = set()
        deperate_op_idx = set()
        for idx, op in enumerate(block.ops):
            if not self._is_gradient_clip_op(op):
                continue
            if op.type == "sum":
                continue
            deperate_op = False
            for input_name in op.desc.input_arg_names():
                if input_name in deperated_vars:
                    deperate_op = True
                param_name = input_name.strip("@GRAD")
                if shard.is_param(param_name) and \
                        not shard.has_param(param_name):
                    deperate_op = True

            if deperate_op:
                deperate_op_idx.add(idx)
                for output_name in op.desc.output_arg_names():
                    deperated_vars.add(output_name)

        if not deperated_vars:
            # got no gradient_clip op
            return

        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_gradient_clip_op(op):
                continue
            if idx in deperate_op_idx:
                block._remove_op(idx, sync=False)
                continue
            reversed_inputs = []
            if op.type == "sum":
                for input_name in op.desc.input_arg_names():
                    if input_name not in deperated_vars:
                        reversed_inputs.append(input_name)
                op.desc.set_input("X", reversed_inputs)
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_comm_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_allreduce_sum',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={OP_ROLE_KEY: OpRole.Optimize})

        for var_name in deperated_vars:
            block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
        return
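prune_gradient_clip keeps only the gradient-clip ops whose gradients live on this worker and, for the global-norm "sum" op, drops the pruned inputs and inserts c_sync_calc_stream / c_allreduce_sum / c_sync_comm_stream right after it, so the global norm is still accumulated across all workers. A framework-free sketch of that accumulation, with made-up numbers:

import math

# Each worker sums the squared norms of the gradients it owns; the inserted
# c_allreduce_sum adds the partial results so every worker derives the same
# global norm for clipping.
def allreduce_sum(partials):
    return sum(partials)

local_sq_sums = [4.0, 9.0, 3.0]          # per-worker sums of squared norms
global_norm = math.sqrt(allreduce_sum(local_sq_sums))
print(global_norm)                       # 4.0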
@@ -0,0 +1,131 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class ProgramDeps(object):
    def __init__(self, block, start_vars, end_vars):
        self._block = block
        # vars where to start to build the deps
        self._start_vars = start_vars
        # vars where to stop to build the deps
        self._end_vars = end_vars
        # var name -> op idxs which depend on this var
        self._var_to_use_op = {}
        # sub block deps which is a subset of this topo
        self._sub_block_deps = {}
        # var name -> op idxs which generate this var
        self._var_to_generate_op = {}
        self._should_removed_var = set()
        self._father_block_deps = None
        self._build_deps()

    def get_sub_block_deps(self, idx):
        if idx in self._sub_block_deps:
            return self._sub_block_deps[idx]
        else:
            return None

    def get_var_deps(self, var_name):
        if var_name in self._var_to_use_op:
            return self._var_to_use_op[var_name]
        else:
            return None

    def _build_deps(self, ):
        for var_name in self._start_vars:
            self._var_to_use_op[var_name] = []
            self._var_to_generate_op[var_name] = []

        for idx, op in enumerate(self._block.ops):
            if op.type in [
                    "c_allreduce_sum", "c_sync_comm_stream",
                    "c_calc_comm_stream"
            ]:
                continue
            input_vars = op.desc.input_arg_names()
            output_vars = op.desc.output_arg_names()
            deps_reduce = False
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    deps_reduce = True
            if not deps_reduce:
                continue
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    self._var_to_use_op[input_name].append(idx)
            for output_name in output_vars:
                if output_name not in self._var_to_use_op:
                    self._var_to_use_op[output_name] = []
                if output_name not in self._var_to_generate_op:
                    self._var_to_generate_op[output_name] = [idx]
                else:
                    self._var_to_generate_op[output_name].append(idx)
            if op.type == "conditional_block":
                # subblock
                assert (op.desc.has_attr("sub_block"))
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = ProgramDeps(
                    self._block.program.block(subblock_idx),
                    op.desc.input_arg_names(), op.desc.output_arg_names())
                self._sub_block_deps[subblock_idx] = subblock_deps
                subblock_deps._father_block_deps = self

    def crop_input_var_from_op(self, op_idx, var_name):
        if var_name in self._var_to_use_op:
            # update var -> dep_var_op
            if self._var_to_use_op[var_name] != []:
                if op_idx not in self._var_to_use_op[var_name]:
                    raise ValueError(
                        "op_idx: {} is not in self._var_to_use_op[{}], "
                        "self._var_to_use_op[{}] is {}".format(
                            op_idx, var_name, var_name, self._var_to_use_op[
                                var_name]))
                self._var_to_use_op[var_name].remove(op_idx)
            # update _should_removed_var
            if var_name in self._start_vars:
                self._should_removed_var.discard(var_name)
            elif self._var_to_use_op[
                    var_name] == []:  # no more deps of this var
                self._should_removed_var.add(var_name)
            elif self._var_to_generate_op[var_name][-1] >= self._var_to_use_op[
                    var_name][-1]:
                # there is a cycle in the graph
                self._should_removed_var.add(var_name)
            else:  # input_name should not be deleted
                self._should_removed_var.discard(var_name)

    def crop_output_var_from_op(self, op_idx, var_name):
        if var_name in self._var_to_generate_op:
            assert (op_idx in self._var_to_generate_op[var_name])
            self._var_to_generate_op[var_name].remove(op_idx)
        if self._block.has_var(var_name):
            if var_name not in self._var_to_generate_op or self._var_to_generate_op[
                    var_name] == []:
                self._block._remove_var(var_name, sync=False)

    def remove_op(self, op_idx):
        # update deps
        op = self._block.ops[op_idx]
        for input_name in op.desc.input_arg_names():
            self.crop_input_var_from_op(op_idx, input_name)
        for output_name in op.desc.output_arg_names():
            self.crop_output_var_from_op(op_idx, output_name)
        self._block._remove_op(op_idx, sync=False)

    def should_remove_op(self, op_idx):
        op = self._block.ops[op_idx]
        for output_name in op.desc.output_arg_names():
            if output_name not in self._should_removed_var:
                return False
        return True
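ProgramDeps is essentially a use/def index over the block: _var_to_use_op records which ops read a variable, _var_to_generate_op which ops write it, and crop_input_var_from_op marks a variable for removal once its last reader is cropped. A minimal, Paddle-free stand-in for that bookkeeping (variable names and op indices are invented for the example):

# var name -> indices of ops that still read it
var_to_use_op = {"x": [0, 2], "y": [1]}
should_removed_var = set()

def crop_input_var_from_op(op_idx, var_name):
    # mirrors the real method: drop the reader, and if none remain the
    # variable becomes a removal candidate
    var_to_use_op[var_name].remove(op_idx)
    if var_to_use_op[var_name] == []:
        should_removed_var.add(var_name)

crop_input_var_from_op(1, "y")    # op 1 was y's only reader
crop_input_var_from_op(2, "x")    # x is still read by op 0
print(should_removed_var)         # {'y'}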
@@ -0,0 +1,144 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils


class Shard(object):
    def __init__(self, ):
        self.global_params = set([])
        self.worker_idx = -1
        self.worker_num = -1
        self.global_param2device = {}

    def setup(self, params_grads, worker_idx, worker_num):
        # param names of all devices
        self.global_params = set([x[0].name for x in params_grads])
        # _param(str) -> device_id(int)
        self.worker_idx = worker_idx
        self.worker_num = worker_num
        # global_param2device contains fp32 params and fp16 params
        self.global_param2device = self._split_params(params_grads, worker_idx,
                                                      worker_num)

    def has_param(self, var_name):
        return var_name in self.global_param2device and \
            self._var_device_id(var_name) == self.worker_idx

    def has_opt_var(self, var_name):
        return self._var_device_id(var_name) == self.worker_idx

    def has_var(self, var_name):
        return self._var_device_id(var_name) == -1 or \
            self._var_device_id(var_name) == self.worker_idx

    def _split_params(self, params_grads, worker_idx, worker_num):
        param2device = {}
        total_param_mem = 0.0
        param2mem = []
        for param in [x[0] for x in params_grads]:
            mem = get_var_size(param)
            total_param_mem += mem
            param2mem.append((param.name, mem))
        device2params = {x: [] for x in range(worker_num)}
        device_idx = 0
        mem_accu = 0.0
        for param_name, mem in param2mem:
            if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num:
                device_idx += 1
            device2params[device_idx].append(param_name)
            param2device[param_name] = device_idx
            mem_accu += mem
        return param2device

    def _var_device_id(self, var_name):
        if var_name in self.global_param2device:
            return self.global_param2device[var_name]
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_param2device:
                return self.global_param2device[base_name]
        return -1

    def find_broadcast_params(self, block):
        broadcast_vars = set([])
        fp16_params = set([])
        fp16_to_fp32 = {}

        param_usage = {x: 0 for x in self.global_params}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.global_params:
                    param_usage[input_name] += 1

        for op in block.ops:
            if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            fp16_params.add(output_name)
            fp16_to_fp32[output_name] = input_name
            param_usage[input_name] -= 1
            self.global_param2device[output_name] = self.global_param2device[
                input_name]

        for param, usage in param_usage.items():
            if usage > 0:
                broadcast_vars.add(param)
        return broadcast_vars

    def device(self, var_name):
        return self._var_device_id(var_name)

    def is_param(self, var_name):
        return var_name in self.global_params

    def is_opti_var(self, var_name):
        if var_name in self.global_params:
            return True
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_params:
                return True
        return False


class ProgramSegment(object):
    def __init__(self, block):
        self._block = block
        self._allreduce_vars = []
        # sub program start idx
        self._start_idx = -1
        # sub program end idx
        self._end_idx = -1
        # param name to broadcast name
        self._param2broadcast = {}
        self._broadcast_vars = []
        # cast op pairs, fp16 name (str) -> fp32 name (str)
        self._cast_ops = {}
        # fill constant vars
        self._fill_constant_vars = []
        # parameter mems
        self._param_mem = 0.0
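Shard._split_params assigns parameters to workers greedily: it walks the parameters in order and advances to the next device once the accumulated memory passes that device's share of the total. The standalone sketch below reproduces just that rule with made-up sizes (get_var_size normally supplies them from the program):

def split_params(param2mem, worker_num):
    # param2mem: list of (param_name, size) pairs, sizes in any consistent unit
    total = sum(mem for _, mem in param2mem)
    param2device, device_idx, accu = {}, 0, 0.0
    for name, mem in param2mem:
        if accu > total * (device_idx + 1) / worker_num:
            device_idx += 1
        param2device[name] = device_idx
        accu += mem
    return param2device

print(split_params([("w0", 4.0), ("w1", 4.0), ("w2", 2.0), ("w3", 2.0)], 2))
# {'w0': 0, 'w1': 0, 'w2': 1, 'w3': 1}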
File diff suppressed because it is too large
@@ -0,0 +1,37 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_VAR_KEY


class WeightDecayHelper(object):
    def __init__(self):
        pass

    def _is_weight_decay_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/regularization")

    def prune_weight_decay(self, block, shard):
        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_weight_decay_op(op):
                continue
            if OP_ROLE_VAR_KEY not in op.attr_names:
                raise ValueError(
                    "The Weight Decay op should hold op_role_var attribute, "
                    "but the {} op does not hold op_role_var".format(op.type))
            op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
            if not shard.has_param(op_role_var[0]):
                block._remove_op(idx, sync=False)
        block._sync_with_cpp()
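prune_weight_decay drops a regularization op when the parameter named in its op_role_var is not sharded onto this worker. The rule reduces to a simple ownership filter, sketched here with invented parameter names and a plain dictionary standing in for Shard.global_param2device:

param2device = {"fc_0.w_0": 0, "fc_1.w_0": 1}   # hypothetical placement
worker_idx = 0

def has_param(name):
    # stand-in for Shard.has_param on this worker
    return param2device.get(name) == worker_idx

weight_decay_ops = [("scale", "fc_0.w_0"), ("scale", "fc_1.w_0")]
kept = [(op_type, param) for op_type, param in weight_decay_ops
        if has_param(param)]
print(kept)   # [('scale', 'fc_0.w_0')]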
File diff suppressed because it is too large
File diff suppressed because it is too large