add sharding strategy in fleet(#27900)

* add sharding
revert-27871-prv-conv-grad-opt
mapingshuo 4 years ago committed by GitHub
parent 4877bd5944
commit 81244fbfab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,6 +24,10 @@ enum Mode {
message RecomputeConfig { repeated string checkpoints = 1; }
message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
}
message AMPConfig {
optional float init_loss_scaling = 1 [ default = 32768.0 ];
optional int32 incr_every_n_steps = 2 [ default = 1000 ];
@ -130,6 +134,7 @@ message DistributedStrategy {
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ];
optional bool fp16_allreduce = 25 [ default = false ];
optional bool sharding = 26 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
@ -141,6 +146,7 @@ message DistributedStrategy {
optional LarsConfig lars_configs = 108;
optional LambConfig lamb_configs = 109;
optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
optional ShardingConfig sharding_configs = 111;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
}

@ -611,6 +611,55 @@ class DistributedStrategy(object):
"checkpoint_configs")
assign_configs_value(self.strategy.recompute_configs, configs)
@property
def sharding(self):
"""
Indicating whether we are using sharding Optimizer for memory
optimization
Default value: False
Examples:
.. code-block:: python
import paddle.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.sharding = True
"""
return self.strategy.sharding
@sharding.setter
@is_strict_auto
def sharding(self, flag):
if isinstance(flag, bool):
self.strategy.sharding = flag
else:
print("WARNING: sharding should have value of bool type")
@property
def sharding_configs(self):
"""
Set sharding configurations.
**Note**:
fuse_broadcast_MB(float): size of a fused group of broadcasted parameters.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 32}
"""
return get_msg_dict(self.strategy.sharding_configs)
@sharding_configs.setter
@is_strict_auto
def sharding_configs(self, configs):
check_configs_key(self.strategy.sharding_configs, configs,
"sharding_configs")
assign_configs_value(self.strategy.sharding_configs, configs)
@property
def pipeline(self):
"""

@ -24,3 +24,4 @@ from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
from .fp16_allreduce_optimizer import FP16AllReduceOptimizer
from .sharding_optimizer import ShardingOptimizer

@ -99,6 +99,12 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward
})
def _wait(self, current_endpoint, endpoints):
assert (self.wait_port)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
wait_server_ready(other_endpoints)
def _broadcast_params(self):
block = self.startup_program.global_block()
ring_id = -1

@ -30,6 +30,10 @@ class DGCOptimizer(MetaOptimizerBase):
super(DGCOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
def _init_dgc_opt(self):
if self.dgc_opt is not None:
return
opt = self.inner_opt
if not self.role_maker._is_collective:
@ -86,13 +90,16 @@ class DGCOptimizer(MetaOptimizerBase):
parameter_list=None,
no_grad_set=None,
callbacks=None):
self._init_dgc_opt()
return self.dgc_opt.backward(loss, startup_program, parameter_list,
no_grad_set, callbacks)
def apply_gradients(self, params_grads):
self._init_dgc_opt()
return self.dgc_opt.apply_gradients(params_grads=params_grads)
def apply_optimize(self, loss, startup_program, params_grads):
self._init_dgc_opt()
return self.dgc_opt.apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads)
@ -101,6 +108,7 @@ class DGCOptimizer(MetaOptimizerBase):
startup_program=None,
parameter_list=None,
no_grad_set=None):
self._init_dgc_opt()
optimize_ops, params_grads = \
self.dgc_opt.minimize(loss, startup_program,
parameter_list, no_grad_set)

@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,154 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.fluid import core
class FP16Utils(object):
def __init__(self):
pass
@staticmethod
def is_fp16_cast_op(block, op, params):
if op.type != "cast":
return False
if is_optimizer_op(op):
return False
assert (len(op.desc.input_arg_names()) == 1)
assert (len(op.desc.output_arg_names()) == 1)
input_name, output_name = op.desc.input_arg_names()[
0], op.desc.output_arg_names()[0]
if input_name not in params:
return False
input_var = block.var(input_name)
output_var = block.var(output_name)
if input_var.dtype != core.VarDesc.VarType.FP32 or \
output_var.dtype != core.VarDesc.VarType.FP16:
return False
return True
@staticmethod
def is_fp32_cast_op(block, op):
if op.type != "cast":
return False
if not is_optimizer_op(op):
return False
assert (len(op.desc.input_arg_names()) == 1)
assert (len(op.desc.output_arg_names()) == 1)
input_name, output_name = op.desc.input_arg_names()[
0], op.desc.output_arg_names()[0]
input_var = block.var(input_name)
output_var = block.var(output_name)
if input_var.dtype != core.VarDesc.VarType.FP16 or \
output_var.dtype != core.VarDesc.VarType.FP32:
return False
return True
@staticmethod
def remove_cast_op(block, params, segment, offset):
inserted_op_num = 0
for op_idx in reversed(
range(offset + segment._start_idx, offset + segment._end_idx)):
op = block.ops[op_idx]
if FP16Utils.is_fp16_cast_op(block, op, params):
block._remove_op(op_idx, sync=False)
inserted_op_num -= 1
block._sync_with_cpp()
return inserted_op_num
@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, nrings):
# remove cast
for idx, op in reversed(list(enumerate(block.ops))):
if not FP16Utils.is_fp32_cast_op(block, op):
continue
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(
input_name))
if output_name in reduced_grads_to_param:
continue
if shard.has_param(param_name):
continue
block._remove_op(idx, sync=False)
block._remove_var(output_name, sync=False)
block._sync_with_cpp()
update_loss_scaling_op_idx = -1
inf_var_name = ''
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@sharding")
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = []
for input_name in op.desc.input('X'):
param_name = input_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError(
"Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(input_name))
if shard.has_param(param_name):
reversed_x.append(input_name)
op.desc.set_input('X', reversed_x)
op.desc.set_output('Out', reversed_x)
if update_loss_scaling_op_idx == -1:
return
inf_var = block.var(inf_var_name)
inf_var_fp32 = block.create_var(
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_sharding = block.create_var(
name=inf_var_name + "@sharding",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_fp32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_fp32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
[inf_var_fp32])
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
type='c_allreduce_max',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_fp32},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
comm_op_num = insert_sync_comm_ops(
block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num,
type='cast',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_sharding},
attrs={
"in_dtype": inf_var_fp32.dtype,
"out_dtype": inf_var_sharding.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()

@ -0,0 +1,90 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
class GradientClipHelper(object):
def __init__(self):
pass
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def prune_gradient_clip(self, block, shard):
deperated_vars = set()
deperate_op_idx = set()
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
if op.type == "sum":
continue
deperate_op = False
for input_name in op.desc.input_arg_names():
if input_name in deperated_vars:
deperate_op = True
param_name = input_name.strip("@GRAD")
if shard.is_param(param_name) and \
not shard.has_param(param_name):
deperate_op = True
if deperate_op:
deperate_op_idx.add(idx)
for output_name in op.desc.output_arg_names():
deperated_vars.add(output_name)
if not deperated_vars:
# got no gradient_clip op
return
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op):
continue
if idx in deperate_op_idx:
block._remove_op(idx, sync=False)
continue
reversed_inputs = []
if op.type == "sum":
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)
op.desc.set_input("X", reversed_inputs)
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_sync_comm_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={OP_ROLE_KEY: OpRole.Optimize})
for var_name in deperated_vars:
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
return

@ -0,0 +1,131 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ProgramDeps(object):
def __init__(self, block, start_vars, end_vars):
self._block = block
# vars where to start to build the deps
self._start_vars = start_vars
# vars where to stop to build the deps
self._end_vars = end_vars
# var name -> op idxs which depends on this var
self._var_to_use_op = {}
# sub block deps which is a subset of this topo
self._sub_block_deps = {}
# var name -> op idxs which generate var
self._var_to_generate_op = {}
self._should_removed_var = set()
self._father_block_deps = None
self._build_deps()
def get_sub_block_deps(self, idx):
if idx in self._sub_block_deps:
return self._sub_block_deps[idx]
else:
return None
def get_var_deps(self, var_name):
if var_name in self._var_to_use_op:
return self._var_to_use_op[var_name]
else:
return None
def _build_deps(self, ):
for var_name in self._start_vars:
self._var_to_use_op[var_name] = []
self._var_to_generate_op[var_name] = []
for idx, op in enumerate(self._block.ops):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream"
]:
continue
input_vars = op.desc.input_arg_names()
output_vars = op.desc.output_arg_names()
deps_reduce = False
for input_name in input_vars:
if input_name in self._var_to_use_op:
deps_reduce = True
if not deps_reduce:
continue
for input_name in input_vars:
if input_name in self._var_to_use_op:
self._var_to_use_op[input_name].append(idx)
for output_name in output_vars:
if output_name not in self._var_to_use_op:
self._var_to_use_op[output_name] = []
if output_name not in self._var_to_generate_op:
self._var_to_generate_op[output_name] = [idx]
else:
self._var_to_generate_op[output_name].append(idx)
if op.type == "conditional_block":
# subblock
assert (op.desc.has_attr("sub_block"))
subblock_idx = op.desc.attr("sub_block").id
subblock_deps = ProgramDeps(
self._block.program.block(subblock_idx),
op.desc.input_arg_names(), op.desc.output_arg_names())
self._sub_block_deps[subblock_idx] = subblock_deps
subblock_deps._father_block_deps = self
def crop_input_var_from_op(self, op_idx, var_name):
if var_name in self._var_to_use_op:
# update var -> dep_var_op
if self._var_to_use_op[var_name] != []:
if op_idx not in self._var_to_use_op[var_name]:
raise ValueError(
"op_idx: {} is not in self._var_to_use_op[{}], "
"self._var_to_use_op[{}] is {}".format(
op_idx, var_name, var_name, self._var_to_use_op[
var_name]))
self._var_to_use_op[var_name].remove(op_idx)
# update _should_removed_var
if var_name in self._start_vars:
self._should_removed_var.discard(var_name)
elif self._var_to_use_op[
var_name] == []: # no more deps of this var
self._should_removed_var.add(var_name)
elif self._var_to_generate_op[var_name][-1] >= self._var_to_use_op[
var_name][-1]:
# there are circle in the graph
self._should_removed_var.add(var_name)
else: # input_name should not be deleted
self._should_removed_var.discard(var_name)
def crop_output_var_from_op(self, op_idx, var_name):
if var_name in self._var_to_generate_op:
assert (op_idx in self._var_to_generate_op[var_name])
self._var_to_generate_op[var_name].remove(op_idx)
if self._block.has_var(var_name):
if var_name not in self._var_to_generate_op or self._var_to_generate_op[
var_name] == []:
self._block._remove_var(var_name, sync=False)
def remove_op(self, op_idx):
# update deps
op = self._block.ops[op_idx]
for input_name in op.desc.input_arg_names():
self.crop_input_var_from_op(op_idx, input_name)
for output_name in op.desc.output_arg_names():
self.crop_output_var_from_op(op_idx, output_name)
self._block._remove_op(op_idx, sync=False)
def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
return True

@ -0,0 +1,144 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
class Shard(object):
def __init__(self, ):
self.global_params = set([])
self.worker_idx = -1
self.worker_num = -1
self.global_param2device = {}
def setup(self, params_grads, worker_idx, worker_num):
# param names of all devices
self.global_params = set([x[0].name for x in params_grads])
# _param(str) -> device_id(int)
self.worker_idx = worker_idx
self.worker_num = worker_num
# global_param2device contains fp32 params and fp16 params
self.global_param2device = self._split_params(params_grads, worker_idx,
worker_num)
def has_param(self, var_name):
return var_name in self.global_param2device and \
self._var_device_id(var_name) == self.worker_idx
def has_opt_var(self, var_name):
return self._var_device_id(var_name) == self.worker_idx
def has_var(self, var_name):
return self._var_device_id(var_name) == -1 or \
self._var_device_id(var_name) == self.worker_idx
def _split_params(self, params_grads, worker_idx, worker_num):
param2device = {}
total_param_mem = 0.0
param2mem = []
for param in [x[0] for x in params_grads]:
mem = get_var_size(param)
total_param_mem += mem
param2mem.append((param.name, mem))
device2params = {x: [] for x in range(worker_num)}
device_idx = 0
mem_accu = 0.0
for param_name, mem in param2mem:
if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num:
device_idx += 1
device2params[device_idx].append(param_name)
param2device[param_name] = device_idx
mem_accu += mem
return param2device
def _var_device_id(self, var_name):
if var_name in self.global_param2device:
return self.global_param2device[var_name]
for suffix in [
"_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
"_beta2_pow_acc_0", "_velocity_0"
]:
base_name = re.sub(suffix, '', var_name)
if base_name in self.global_param2device:
return self.global_param2device[base_name]
return -1
def find_broadcast_params(self, block):
broadcast_vars = set([])
fp16_params = set([])
fp16_to_fp32 = {}
param_usage = {x: 0 for x in self.global_params}
for op in block.ops:
if is_optimizer_op(op):
continue
for input_name in op.desc.input_arg_names():
if input_name in self.global_params:
param_usage[input_name] += 1
for op in block.ops:
if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
continue
input_name = op.input_arg_names[0]
output_name = op.output_arg_names[0]
broadcast_vars.add(output_name)
fp16_params.add(output_name)
fp16_to_fp32[output_name] = input_name
param_usage[input_name] -= 1
self.global_param2device[output_name] = self.global_param2device[
input_name]
for param, usage in param_usage.items():
if usage > 0:
broadcast_vars.add(param)
return broadcast_vars
def device(self, var_name):
return self._var_device_id(var_name)
def is_param(self, var_name):
return var_name in self.global_params
def is_opti_var(self, var_name):
if var_name in self.global_params:
return True
for suffix in [
"_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
"_beta2_pow_acc_0", "_velocity_0"
]:
base_name = re.sub(suffix, '', var_name)
if base_name in self.global_params:
return True
return False
class ProgramSegment(object):
def __init__(self, block):
self._block = block
self._allreduce_vars = []
# sub program start idx
self._start_idx = -1
# sub program end idx
self._end_idx = -1
# param name to broadcast name
self._param2broadcast = {}
self._broadcast_vars = []
# cast op pairs, fp16 name (str) -> fp32 name (str)
self._cast_ops = {}
# fill constant vars
self._fill_constant_vars = []
# parameter mems
self._param_mem = 0.0

@ -0,0 +1,37 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_VAR_KEY
class WeightDecayHelper(object):
def __init__(self):
pass
def _is_weight_decay_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/regularization")
def prune_weight_decay(self, block, shard):
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_weight_decay_op(op):
continue
if OP_ROLE_VAR_KEY not in op.attr_names:
raise ValueError(
"The Weight Dacay op should hold op_role_var attribute"
"but the {} op does not hold op_role_var".format(op.type))
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if not shard.has_param(op_role_var[0]):
block._remove_op(idx, sync=False)
block._sync_with_cpp()

@ -669,7 +669,7 @@ def append_gradient_clip_ops(param_grads):
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('gradient_clip_@CLIP'):
[p, g]), framework.name_scope('gradient_clip'):
clip_attr = getattr(p, 'gradient_clip_attr', None)
if clip_attr is None:
return param_grads
@ -685,7 +685,7 @@ def append_gradient_clip_ops(param_grads):
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('graident_clip_@CLIP'):
[p, g]), framework.name_scope('gradient_clip'):
param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_name_dict[param.name] = new_grad.name
res.append([param, new_grad])

@ -2100,10 +2100,16 @@ class Operator(object):
% (out_proto.name, len(out_args)))
out_arg_names = []
for arg in out_args:
out_arg_names.append(cpt.to_text(arg.name))
if isinstance(arg, six.string_types):
out_arg_names.append(arg)
else:
out_arg_names.append(cpt.to_text(arg.name))
# TODO(minqiyang): could we remove variable's op in static mode?
if not in_dygraph_mode():
arg.op = self
if isinstance(arg, six.string_types):
block.var(arg).op = self
else:
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
if op_attrs is not None:
@ -2837,8 +2843,9 @@ class Block(object):
self._sync_with_cpp()
return var
def _remove_var(self, name):
self._sync_with_cpp()
def _remove_var(self, name, sync=True):
if sync == True:
self._sync_with_cpp()
self.desc._remove_var(cpt.to_bytes(name))
del self.vars[name]
@ -2936,7 +2943,23 @@ class Block(object):
self.ops.insert(index, op)
return op
def _remove_op(self, index):
def _insert_op_without_sync(self, index, *args, **kwargs):
"""
Insert an Operator according to the giving arguments,
without sync_with_cpp to meke the compilation faster.
Args:
index(int): the place that the operator to insert.
Returns:
Operator: the insert Operator.
"""
op_desc = self.desc._insert_op(index)
op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.insert(index, op)
return op
def _remove_op(self, index, sync=True):
"""
Remove the specific position operator.
@ -2946,7 +2969,8 @@ class Block(object):
Returns:
None
"""
self._sync_with_cpp()
if sync == True:
self._sync_with_cpp()
self.desc._remove_op(index, index + 1)
del self.ops[index]

@ -41,6 +41,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_sharding_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer)
@ -461,6 +462,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_fleet_recompute_meta_optimizer MODULES test_fleet_recompute_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS ${dist_ENVS})
py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_sharding_meta_optimizer MODULES test_fleet_sharding_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS})

@ -55,14 +55,22 @@ class TestFleetMetaOptimizer(unittest.TestCase):
strategy,
train_prog,
startup_prog,
name='momentum'):
name='momentum',
regularization=None,
grad_clip=None):
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
if name == 'momentum':
optimizer = paddle.fluid.optimizer.Momentum(
learning_rate=0.01, momentum=0.9)
learning_rate=0.01,
momentum=0.9,
regularization=regularization,
grad_clip=grad_clip)
elif name == 'adam':
optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
optimizer = paddle.fluid.optimizer.Adam(
learning_rate=0.01,
regularization=regularization,
grad_clip=grad_clip)
optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy)
optimizer.minimize(loss)
@ -121,5 +129,8 @@ class TestFleetMetaOptimizer(unittest.TestCase):
elif name == "gradient_merge":
strategy.gradient_merge = True
strategy.gradient_merge_configs = {"k_steps": 2, "avg": True}
elif name == "sharding":
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}
else:
raise NotImplementedError()

@ -32,9 +32,6 @@ class TestFleetGradientMergeMetaOptimizer(TestFleetMetaOptimizer):
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
vars = [x.name for x in train_prog.list_vars()]
with open("main_program", 'w') as f:
f.write(str(train_prog))
self.assertIn('@GradientMerge', ''.join(vars))
def test_recom_gm_optimizer(self):

@ -148,6 +148,7 @@ packages=['paddle',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.meta_optimizers.sharding',
'paddle.distributed.fleet.runtime',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.data_generator',

Loading…
Cancel
Save