parent 4877bd5944
commit 81244fbfab
@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,154 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from paddle.fluid import core


class FP16Utils(object):
    def __init__(self):
        pass

    @staticmethod
    def is_fp16_cast_op(block, op, params):
        if op.type != "cast":
            return False
        if is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        if input_name not in params:
            return False
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP32 or \
                output_var.dtype != core.VarDesc.VarType.FP16:
            return False
        return True

    @staticmethod
    def is_fp32_cast_op(block, op):
        if op.type != "cast":
            return False
        if not is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP16 or \
                output_var.dtype != core.VarDesc.VarType.FP32:
            return False
        return True

    @staticmethod
    def remove_cast_op(block, params, segment, offset):
        inserted_op_num = 0
        for op_idx in reversed(
                range(offset + segment._start_idx, offset + segment._end_idx)):
            op = block.ops[op_idx]
            if FP16Utils.is_fp16_cast_op(block, op, params):
                block._remove_op(op_idx, sync=False)
                inserted_op_num -= 1
        block._sync_with_cpp()
        return inserted_op_num

    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, nrings):
        # remove cast ops (fp16 grad -> fp32 grad) for params not owned by this shard
        for idx, op in reversed(list(enumerate(block.ops))):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            param_name = output_name.strip("@GRAD")
            if param_name not in shard.global_params:
                raise ValueError("Input 'X' of check_finite_and_unscale must "
                                 "be grads, but {} is not a grad".format(
                                     output_name))
            if output_name in reduced_grads_to_param:
                continue
            if shard.has_param(param_name):
                continue
            block._remove_op(idx, sync=False)
            block._remove_var(output_name, sync=False)

        block._sync_with_cpp()
        # keep only this shard's grads as inputs of check_finite_and_unscale /
        # update_loss_scaling, and rename FoundInfinite so the global flag can
        # be rebuilt below
        update_loss_scaling_op_idx = -1
        inf_var_name = ''
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                for input_name in op.desc.input('X'):
                    param_name = input_name.strip("@GRAD")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must "
                            "be grads, but {} is not a grad".format(input_name))
                    if shard.has_param(param_name):
                        reversed_x.append(input_name)
                op.desc.set_input('X', reversed_x)
                op.desc.set_output('Out', reversed_x)
        if update_loss_scaling_op_idx == -1:
            return
        # allreduce the local FoundInfinite flag across sharding workers:
        # cast to int32 -> c_allreduce_max -> cast back
        inf_var = block.var(inf_var_name)
        inf_var_fp32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
        inf_var_sharding = block.create_var(
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_fp32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_fp32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
                            [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='c_allreduce_max',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_fp32},
            attrs={'ring_id': 0,
                   OP_ROLE_KEY: OpRole.Optimize})
        comm_op_num = insert_sync_comm_ops(
            block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 3 + comm_op_num,
            type='cast',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_sharding},
            attrs={
                "in_dtype": inf_var_fp32.dtype,
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()
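Illustrative sketch only, not part of the diff: prune_fp16 above restricts each worker's check_finite_and_unscale inputs to the gradients of its own shard, then combines the per-shard FoundInfinite flags via a max-allreduce (cast to int32, c_allreduce_max, cast back). The toy code below models just that reduction in plain Python; `shard_found_inf` and `allreduce_max` are made-up stand-ins, not Paddle APIs.

def allreduce_max(values):
    # stand-in for c_allreduce_max over the int32-cast flags
    return max(values)

def global_found_infinite(shard_found_inf):
    # cast bool -> int, reduce with max, cast back to bool (the "@sharding" flag)
    as_int = [int(flag) for flag in shard_found_inf]
    return bool(allreduce_max(as_int))

# Example: only worker 2 saw an inf/nan gradient, yet every worker must skip
# the parameter update and adjust the loss scale, so the global flag is True.
print(global_found_infinite([False, False, True, False]))  # True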
@@ -0,0 +1,90 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole


class GradientClipHelper(object):
    def __init__(self):
        pass

    def _is_gradient_clip_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/gradient_clip")

    def prune_gradient_clip(self, block, shard):
        deperated_vars = set()
        deperate_op_idx = set()
        for idx, op in enumerate(block.ops):
            if not self._is_gradient_clip_op(op):
                continue
            if op.type == "sum":
                continue
            deperate_op = False
            for input_name in op.desc.input_arg_names():
                if input_name in deperated_vars:
                    deperate_op = True
                param_name = input_name.strip("@GRAD")
                if shard.is_param(param_name) and \
                        not shard.has_param(param_name):
                    deperate_op = True

            if deperate_op:
                deperate_op_idx.add(idx)
                for output_name in op.desc.output_arg_names():
                    deperated_vars.add(output_name)

        if not deperated_vars:
            # got no gradient_clip op
            return

        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_gradient_clip_op(op):
                continue
            if idx in deperate_op_idx:
                block._remove_op(idx, sync=False)
                continue
            reversed_inputs = []
            if op.type == "sum":
                for input_name in op.desc.input_arg_names():
                    if input_name not in deperated_vars:
                        reversed_inputs.append(input_name)
                op.desc.set_input("X", reversed_inputs)
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_comm_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_allreduce_sum',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={OP_ROLE_KEY: OpRole.Optimize})

        for var_name in deperated_vars:
            block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
        return
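Illustrative sketch only, not part of the diff: after pruning, each worker's "sum" op accumulates only the squared norms of the gradients it owns, and the inserted c_allreduce_sum over that partial result recovers the global gradient norm used for clipping. The plain-Python model below shows the arithmetic; `allreduce_sum` and the worker partial sums are made-up stand-ins, not Paddle APIs.

import math

def allreduce_sum(partials):
    # stand-in for c_allreduce_sum across sharding workers
    return sum(partials)

def global_grad_norm(per_worker_sq_sums):
    return math.sqrt(allreduce_sum(per_worker_sq_sums))

# Two workers with disjoint gradients:
# worker 0 grads [3.0, 4.0] -> squared sum 25.0; worker 1 grads [12.0] -> 144.0
print(global_grad_norm([25.0, 144.0]))  # 13.0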
@@ -0,0 +1,131 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class ProgramDeps(object):
    def __init__(self, block, start_vars, end_vars):
        self._block = block
        # vars where to start to build the deps
        self._start_vars = start_vars
        # vars where to stop to build the deps
        self._end_vars = end_vars
        # var name -> op idxs which depend on this var
        self._var_to_use_op = {}
        # sub block deps which is a subset of this topo
        self._sub_block_deps = {}
        # var name -> op idxs which generate this var
        self._var_to_generate_op = {}
        self._should_removed_var = set()
        self._father_block_deps = None
        self._build_deps()

    def get_sub_block_deps(self, idx):
        if idx in self._sub_block_deps:
            return self._sub_block_deps[idx]
        else:
            return None

    def get_var_deps(self, var_name):
        if var_name in self._var_to_use_op:
            return self._var_to_use_op[var_name]
        else:
            return None

    def _build_deps(self, ):
        for var_name in self._start_vars:
            self._var_to_use_op[var_name] = []
            self._var_to_generate_op[var_name] = []

        for idx, op in enumerate(self._block.ops):
            if op.type in [
                    "c_allreduce_sum", "c_sync_comm_stream",
                    "c_calc_comm_stream"
            ]:
                continue
            input_vars = op.desc.input_arg_names()
            output_vars = op.desc.output_arg_names()
            deps_reduce = False
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    deps_reduce = True
            if not deps_reduce:
                continue
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    self._var_to_use_op[input_name].append(idx)
            for output_name in output_vars:
                if output_name not in self._var_to_use_op:
                    self._var_to_use_op[output_name] = []
                if output_name not in self._var_to_generate_op:
                    self._var_to_generate_op[output_name] = [idx]
                else:
                    self._var_to_generate_op[output_name].append(idx)
            if op.type == "conditional_block":
                # subblock
                assert (op.desc.has_attr("sub_block"))
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = ProgramDeps(
                    self._block.program.block(subblock_idx),
                    op.desc.input_arg_names(), op.desc.output_arg_names())
                self._sub_block_deps[subblock_idx] = subblock_deps
                subblock_deps._father_block_deps = self

    def crop_input_var_from_op(self, op_idx, var_name):
        if var_name in self._var_to_use_op:
            # update var -> dep_var_op
            if self._var_to_use_op[var_name] != []:
                if op_idx not in self._var_to_use_op[var_name]:
                    raise ValueError(
                        "op_idx: {} is not in self._var_to_use_op[{}], "
                        "self._var_to_use_op[{}] is {}".format(
                            op_idx, var_name, var_name, self._var_to_use_op[
                                var_name]))
                self._var_to_use_op[var_name].remove(op_idx)
            # update _should_removed_var
            if var_name in self._start_vars:
                self._should_removed_var.discard(var_name)
            elif self._var_to_use_op[
                    var_name] == []:  # no more deps of this var
                self._should_removed_var.add(var_name)
            elif self._var_to_generate_op[var_name][-1] >= self._var_to_use_op[
                    var_name][-1]:
                # there is a cycle in the graph
                self._should_removed_var.add(var_name)
            else:  # input_name should not be deleted
                self._should_removed_var.discard(var_name)

    def crop_output_var_from_op(self, op_idx, var_name):
        if var_name in self._var_to_generate_op:
            assert (op_idx in self._var_to_generate_op[var_name])
            self._var_to_generate_op[var_name].remove(op_idx)
        if self._block.has_var(var_name):
            if var_name not in self._var_to_generate_op or self._var_to_generate_op[
                    var_name] == []:
                self._block._remove_var(var_name, sync=False)

    def remove_op(self, op_idx):
        # update deps
        op = self._block.ops[op_idx]
        for input_name in op.desc.input_arg_names():
            self.crop_input_var_from_op(op_idx, input_name)
        for output_name in op.desc.output_arg_names():
            self.crop_output_var_from_op(op_idx, output_name)
        self._block._remove_op(op_idx, sync=False)

    def should_remove_op(self, op_idx):
        op = self._block.ops[op_idx]
        for output_name in op.desc.output_arg_names():
            if output_name not in self._should_removed_var:
                return False
        return True
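Illustrative sketch only, not part of the diff: ProgramDeps tracks, for each variable, which ops consume it; an op becomes removable once all of its outputs are in the should-remove set, and removing it can leave its inputs with no remaining consumers, cascading the pruning further. The toy code below models that idea on plain (inputs, outputs) tuples; the real class walks block.ops and op descs instead.

def prune(ops, should_remove_vars):
    # ops: list of (inputs, outputs) tuples; returns indices of ops to drop
    var_to_use = {}
    for idx, (ins, outs) in enumerate(ops):
        for v in ins:
            var_to_use.setdefault(v, []).append(idx)
    removed = []
    for idx in reversed(range(len(ops))):
        ins, outs = ops[idx]
        if outs and all(v in should_remove_vars for v in outs):
            removed.append(idx)
            for v in ins:
                var_to_use[v].remove(idx)
                if not var_to_use[v]:
                    should_remove_vars.add(v)  # no consumers left -> cascade
    return removed

ops = [(["x"], ["y"]), (["y"], ["z"])]
print(prune(ops, {"z"}))  # [1, 0]: dropping z's producer leaves y unused too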
@@ -0,0 +1,144 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils


class Shard(object):
    def __init__(self, ):
        self.global_params = set([])
        self.worker_idx = -1
        self.worker_num = -1
        self.global_param2device = {}

    def setup(self, params_grads, worker_idx, worker_num):
        # param names of all devices
        self.global_params = set([x[0].name for x in params_grads])
        # _param(str) -> device_id(int)
        self.worker_idx = worker_idx
        self.worker_num = worker_num
        # global_param2device contains fp32 params and fp16 params
        self.global_param2device = self._split_params(params_grads, worker_idx,
                                                      worker_num)

    def has_param(self, var_name):
        return var_name in self.global_param2device and \
            self._var_device_id(var_name) == self.worker_idx

    def has_opt_var(self, var_name):
        return self._var_device_id(var_name) == self.worker_idx

    def has_var(self, var_name):
        return self._var_device_id(var_name) == -1 or \
            self._var_device_id(var_name) == self.worker_idx

    def _split_params(self, params_grads, worker_idx, worker_num):
        param2device = {}
        total_param_mem = 0.0
        param2mem = []
        for param in [x[0] for x in params_grads]:
            mem = get_var_size(param)
            total_param_mem += mem
            param2mem.append((param.name, mem))
        device2params = {x: [] for x in range(worker_num)}
        device_idx = 0
        mem_accu = 0.0
        for param_name, mem in param2mem:
            if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num:
                device_idx += 1
            device2params[device_idx].append(param_name)
            param2device[param_name] = device_idx
            mem_accu += mem
        return param2device

    def _var_device_id(self, var_name):
        if var_name in self.global_param2device:
            return self.global_param2device[var_name]
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_param2device:
                return self.global_param2device[base_name]
        return -1

    def find_broadcast_params(self, block):
        broadcast_vars = set([])
        fp16_params = set([])
        fp16_to_fp32 = {}

        param_usage = {x: 0 for x in self.global_params}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.global_params:
                    param_usage[input_name] += 1

        for op in block.ops:
            if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            fp16_params.add(output_name)
            fp16_to_fp32[output_name] = input_name
            param_usage[input_name] -= 1
            self.global_param2device[output_name] = self.global_param2device[
                input_name]

        for param, usage in param_usage.items():
            if usage > 0:
                broadcast_vars.add(param)
        return broadcast_vars

    def device(self, var_name):
        return self._var_device_id(var_name)

    def is_param(self, var_name):
        return var_name in self.global_params

    def is_opti_var(self, var_name):
        if var_name in self.global_params:
            return True
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_params:
                return True
        return False


class ProgramSegment(object):
    def __init__(self, block):
        self._block = block
        self._allreduce_vars = []
        # sub program start idx
        self._start_idx = -1
        # sub program end idx
        self._end_idx = -1
        # param name to broadcast name
        self._param2broadcast = {}
        self._broadcast_vars = []
        # cast op pairs, fp16 name (str) -> fp32 name (str)
        self._cast_ops = {}
        # fill constant vars
        self._fill_constant_vars = []
        # parameter mems
        self._param_mem = 0.0
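Illustrative sketch only, not part of the diff: _split_params above assigns parameters to workers greedily in program order, advancing to the next device once the accumulated size passes that device's proportional share of the total. The toy version below mirrors that loop over plain (name, size) pairs; the names and sizes are made up for the example.

def split_params(param_sizes, worker_num):
    total = sum(size for _, size in param_sizes)
    param2device, device_idx, accumulated = {}, 0, 0.0
    for name, size in param_sizes:
        # move to the next device once this one holds its proportional share
        if accumulated > total * (device_idx + 1) / worker_num:
            device_idx += 1
        param2device[name] = device_idx
        accumulated += size
    return param2device

print(split_params([("w0", 6.0), ("w1", 3.0), ("w2", 4.0), ("w3", 3.0)], 2))
# {'w0': 0, 'w1': 0, 'w2': 1, 'w3': 1}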
(file diff suppressed because it is too large)
@@ -0,0 +1,37 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_VAR_KEY


class WeightDecayHelper(object):
    def __init__(self):
        pass

    def _is_weight_decay_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/regularization")

    def prune_weight_decay(self, block, shard):
        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_weight_decay_op(op):
                continue
            if OP_ROLE_VAR_KEY not in op.attr_names:
                raise ValueError(
                    "The weight decay op should hold op_role_var attribute, "
                    "but the {} op does not hold op_role_var".format(op.type))
            op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
            if not shard.has_param(op_role_var[0]):
                block._remove_op(idx, sync=False)
        block._sync_with_cpp()
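Illustrative sketch only, not part of the diff: prune_weight_decay drops a regularization op whenever the parameter named in its op_role_var attribute is not owned by this worker's shard. The toy model below uses plain dicts; the "namescope" and "op_role_var" keys are stand-ins for the real op attributes.

def prune_weight_decay(ops, owned_params):
    kept = []
    for op in ops:
        if op["namescope"].startswith("/regularization") and \
                op["op_role_var"][0] not in owned_params:
            continue  # another worker updates this parameter, so skip its decay op
        kept.append(op)
    return kept

ops = [
    {"namescope": "/regularization", "op_role_var": ["w0", "w0@GRAD"]},
    {"namescope": "/regularization", "op_role_var": ["w1", "w1@GRAD"]},
]
print(len(prune_weight_decay(ops, owned_params={"w0"})))  # 1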
(file diff suppressed because it is too large)
(file diff suppressed because it is too large)