You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
132 lines
5.5 KiB
132 lines
5.5 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
class ProgramDeps(object):
    """Track which ops read and write each variable of a block, for pruning.

    Starting from ``start_vars``, the ops of ``block`` are scanned in order.
    An op is recorded only if at least one of its inputs is already tracked;
    its outputs then become tracked as well. The resulting tables let callers
    delete ops (``remove_op``) while keeping the per-variable bookkeeping in
    sync, and ask whether an op's outputs are all removable
    (``should_remove_op``).

    Every ``conditional_block`` op that is recorded gets a nested
    ``ProgramDeps`` for its sub-block, reachable via ``get_sub_block_deps``.

    Args:
        block: the block whose ops are analysed. This class reads its
            ``ops`` and ``program`` attributes and calls ``has_var``,
            ``_remove_var`` and ``_remove_op`` on it.
        start_vars: names of the variables the dependency walk starts from.
        end_vars: names of the variables where the walk should stop. Stored
            as ``_end_vars`` but not read by any method of this class.
    """

    # Ops of these types are skipped entirely while building the deps: they
    # are never recorded as readers or generators of any variable.
    _SKIPPED_OP_TYPES = frozenset(
        ["c_allreduce_sum", "c_sync_comm_stream", "c_calc_comm_stream"])

    def __init__(self, block, start_vars, end_vars):
        self._block = block
        # vars where to start to build the deps
        self._start_vars = start_vars
        # vars where to stop to build the deps (not read by this class)
        self._end_vars = end_vars
        # var name -> idxs of the tracked ops which read this var
        self._var_to_use_op = {}
        # sub-block idx -> ProgramDeps built for that sub-block
        self._sub_block_deps = {}
        # var name -> idxs of the tracked ops which generate this var
        self._var_to_generate_op = {}
        # names of vars crop_input_var_from_op has marked as removable
        self._should_removed_var = set()
        # ProgramDeps of the enclosing block; assigned by the parent when this
        # instance describes a sub-block, otherwise stays None
        self._father_block_deps = None
        self._build_deps()

    def get_sub_block_deps(self, idx):
        """Return the ``ProgramDeps`` built for sub-block ``idx``, or None."""
        return self._sub_block_deps.get(idx)

    def get_var_deps(self, var_name):
        """Return the idxs of tracked ops reading ``var_name``, or None.

        The returned list is the internal one (not a copy), so it reflects
        later calls to ``crop_input_var_from_op``.
        """
        return self._var_to_use_op.get(var_name)

    def _build_deps(self):
        """Populate the use/generate tables by scanning the block's ops."""
        for var_name in self._start_vars:
            self._var_to_use_op[var_name] = []
            self._var_to_generate_op[var_name] = []

        for idx, op in enumerate(self._block.ops):
            if op.type in self._SKIPPED_OP_TYPES:
                continue
            input_vars = op.desc.input_arg_names()
            output_vars = op.desc.output_arg_names()
            # Only ops reading at least one already-tracked var join the deps.
            if not any(name in self._var_to_use_op for name in input_vars):
                continue
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    self._var_to_use_op[input_name].append(idx)
            for output_name in output_vars:
                self._var_to_use_op.setdefault(output_name, [])
                self._var_to_generate_op.setdefault(output_name,
                                                    []).append(idx)
            if op.type == "conditional_block":
                # Build a nested ProgramDeps for the sub-block, seeded with
                # this op's inputs and bounded by its outputs.
                assert (op.desc.has_attr("sub_block"))
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = ProgramDeps(
                    self._block.program.block(subblock_idx), input_vars,
                    output_vars)
                self._sub_block_deps[subblock_idx] = subblock_deps
                subblock_deps._father_block_deps = self

    def crop_input_var_from_op(self, op_idx, var_name):
        """Drop op ``op_idx`` from the readers of ``var_name``, re-evaluate it.

        Vars that are not tracked are ignored. Otherwise, after removing
        ``op_idx`` from the reader list:

        * a start var is never marked as removable;
        * a var with no remaining reader is marked as removable;
        * a var whose last generating op comes at or after its last remaining
          reader is marked as removable;
        * any other var (including one with readers left but no generating
          op left) is unmarked, because it is still needed.

        Raises:
            ValueError: if ``var_name`` still has readers but ``op_idx`` is
                not one of them.
        """
        if var_name not in self._var_to_use_op:
            return
        use_ops = self._var_to_use_op[var_name]

        # update var -> dep_var_op
        if use_ops:
            if op_idx not in use_ops:
                raise ValueError(
                    "op_idx: {} is not in self._var_to_use_op[{}], "
                    "self._var_to_use_op[{}] is {}".format(
                        op_idx, var_name, var_name, use_ops))
            use_ops.remove(op_idx)

        # update _should_removed_var
        # Idx lists are appended in increasing op order, so [-1] is the
        # latest op. The generator list may have been emptied by
        # crop_output_var_from_op, hence the emptiness check below (indexing
        # an empty list here used to raise IndexError).
        generate_ops = self._var_to_generate_op.get(var_name, [])
        if var_name in self._start_vars:
            self._should_removed_var.discard(var_name)
        elif not use_ops:
            # no more deps of this var
            self._should_removed_var.add(var_name)
        elif generate_ops and generate_ops[-1] >= use_ops[-1]:
            # there is a cycle in the graph
            self._should_removed_var.add(var_name)
        else:
            # still read by some op: the var should not be deleted
            self._should_removed_var.discard(var_name)

    def crop_output_var_from_op(self, op_idx, var_name):
        """Drop op ``op_idx`` from the generators of ``var_name``.

        If the block holds ``var_name`` and no tracked op generates it any
        more (or it was never tracked), the variable is also removed from the
        block.
        """
        generate_ops = self._var_to_generate_op.get(var_name)
        if generate_ops is not None:
            assert (op_idx in generate_ops)
            generate_ops.remove(op_idx)
        if self._block.has_var(var_name) and not generate_ops:
            self._block._remove_var(var_name, sync=False)

    def remove_op(self, op_idx):
        """Remove op ``op_idx`` from the block, updating the deps first.

        NOTE(review): the idx tables are not re-indexed after the op is
        removed, so recorded idxs of later ops become stale. Presumably
        callers remove ops from the end of the block backwards; confirm.
        """
        op = self._block.ops[op_idx]
        for input_name in op.desc.input_arg_names():
            self.crop_input_var_from_op(op_idx, input_name)
        for output_name in op.desc.output_arg_names():
            self.crop_output_var_from_op(op_idx, output_name)
        self._block._remove_op(op_idx, sync=False)

    def should_remove_op(self, op_idx):
        """Return True if every output of op ``op_idx`` is marked removable.

        An op with no outputs yields True.
        """
        op = self._block.ops[op_idx]
        return all(name in self._should_removed_var
                   for name in op.desc.output_arg_names())
|