Paddle/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py


# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ProgramDeps(object):
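    """Build def-use dependencies for the ops and vars of ``block``.

    Starting from ``start_vars``, it records for every reachable variable
    which op indices consume it (``_var_to_use_op``) and which op indices
    produce it (``_var_to_generate_op``), so that callers can crop op/var
    dependencies and prune ops whose outputs are no longer needed.
    """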

    def __init__(self, block, start_vars, end_vars):
        self._block = block
        # vars from which to start building the deps
        self._start_vars = start_vars
        # vars at which to stop building the deps
        self._end_vars = end_vars
        # var name -> op idxs which depend on this var
        self._var_to_use_op = {}
        # sub-block deps, each a subset of this topology
        self._sub_block_deps = {}
        # var name -> op idxs which generate this var
        self._var_to_generate_op = {}
        self._should_removed_var = set()
        self._father_block_deps = None
        self._build_deps()

    def get_sub_block_deps(self, idx):
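        """Return the ProgramDeps built for sub-block ``idx``, or None if unknown."""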
        if idx in self._sub_block_deps:
            return self._sub_block_deps[idx]
        else:
            return None

    def get_var_deps(self, var_name):
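        """Return the op idxs that use ``var_name``, or None if the var is untracked."""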
        if var_name in self._var_to_use_op:
            return self._var_to_use_op[var_name]
        else:
            return None

    def _build_deps(self):
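        """Walk the block's ops once and fill the use/generate maps.

        Communication/stream ops are skipped, and a nested ProgramDeps is
        built for every conditional_block sub-block that is reached.
        """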
        for var_name in self._start_vars:
            self._var_to_use_op[var_name] = []
            self._var_to_generate_op[var_name] = []

        for idx, op in enumerate(self._block.ops):
            if op.type in [
                    "c_allreduce_sum", "c_sync_comm_stream",
                    "c_calc_comm_stream"
            ]:
                continue
            input_vars = op.desc.input_arg_names()
            output_vars = op.desc.output_arg_names()
            deps_reduce = False
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    deps_reduce = True
            if not deps_reduce:
                continue
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    self._var_to_use_op[input_name].append(idx)
            for output_name in output_vars:
                if output_name not in self._var_to_use_op:
                    self._var_to_use_op[output_name] = []
                if output_name not in self._var_to_generate_op:
                    self._var_to_generate_op[output_name] = [idx]
                else:
                    self._var_to_generate_op[output_name].append(idx)
            if op.type == "conditional_block":
                # build deps for the sub-block as well
                assert op.desc.has_attr("sub_block")
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = ProgramDeps(
                    self._block.program.block(subblock_idx),
                    op.desc.input_arg_names(), op.desc.output_arg_names())
                self._sub_block_deps[subblock_idx] = subblock_deps
                subblock_deps._father_block_deps = self

    def crop_input_var_from_op(self, op_idx, var_name):
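        """Drop the use of ``var_name`` by op ``op_idx`` and refresh the removable-var set."""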
        if var_name in self._var_to_use_op:
            # update var -> use-op mapping
            if self._var_to_use_op[var_name] != []:
                if op_idx not in self._var_to_use_op[var_name]:
                    raise ValueError(
                        "op_idx: {} is not in self._var_to_use_op[{}], "
                        "self._var_to_use_op[{}] is {}".format(
                            op_idx, var_name, var_name,
                            self._var_to_use_op[var_name]))
                self._var_to_use_op[var_name].remove(op_idx)
            # update _should_removed_var
            if var_name in self._start_vars:
                self._should_removed_var.discard(var_name)
            elif self._var_to_use_op[var_name] == []:
                # no remaining uses of this var
                self._should_removed_var.add(var_name)
            elif self._var_to_generate_op[var_name][-1] >= self._var_to_use_op[
                    var_name][-1]:
                # there is a cycle in the graph
                self._should_removed_var.add(var_name)
            else:
                # var_name still has later uses and should not be deleted
                self._should_removed_var.discard(var_name)

    def crop_output_var_from_op(self, op_idx, var_name):
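        """Drop op ``op_idx`` as a generator of ``var_name``.

        The var itself is removed from the block once no generator is left.
        """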
        if var_name in self._var_to_generate_op:
            assert op_idx in self._var_to_generate_op[var_name]
            self._var_to_generate_op[var_name].remove(op_idx)
        if self._block.has_var(var_name):
            if (var_name not in self._var_to_generate_op
                    or self._var_to_generate_op[var_name] == []):
                self._block._remove_var(var_name, sync=False)

    def remove_op(self, op_idx):
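        """Crop all input/output deps of op ``op_idx`` and remove it from the block."""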
        # update deps before removing the op itself
        op = self._block.ops[op_idx]
        for input_name in op.desc.input_arg_names():
            self.crop_input_var_from_op(op_idx, input_name)
        for output_name in op.desc.output_arg_names():
            self.crop_output_var_from_op(op_idx, output_name)
        self._block._remove_op(op_idx, sync=False)

    def should_remove_op(self, op_idx):
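        """Return True only if every output of op ``op_idx`` is marked as removable."""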
        op = self._block.ops[op_idx]
        for output_name in op.desc.output_arg_names():
            if output_name not in self._should_removed_var:
                return False
        return True
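
A minimal usage sketch (an assumption based on the class interface above, not part of this file; `block`, `grad_names`, and `unneeded_var_names` are hypothetical): build the deps starting from the gradients of interest, seed the removable set with variables this rank no longer needs, then walk the ops in reverse and drop every op whose outputs have all become removable.

    deps = ProgramDeps(block, start_vars=grad_names, end_vars=unneeded_var_names)
    # seed the removable set with the vars we already know are dead
    for var_name in deps._end_vars:
        deps._should_removed_var.add(var_name)
    # prune from back to front so earlier indices stay valid after removals
    for idx, op in reversed(list(enumerate(block.ops))):
        if deps.should_remove_op(idx):
            deps.remove_op(idx)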