[feature] prune program by feed and fetch_list automatically (#22474)

* prune train program by fetch_list, test=develop

* add unittest for prune, test=develop

* fix pruned feed, test=develop

* support ParallelExecutor and feed prune, test=develop

* add comments, test=develop

* update unittest, test=develop

* update unittests, test=develop

* remove debug code, test=develop

* support cond in clone, test=develop

* support cond in prune, test=develop

* support multiple minimize, test=develop

* support cache, test=develop

* fix _copy_param_info_from, test=develop

* support python2 str, test=develop

* remove debug code, test=develop

* fix bug of caching CompiledProgram, test=develop

* fix multi_device issue, test=develop

* tmp

* support tuple in fetch_list and overriding use_prune, test=develop

* dont use nonlocal in python2, test=develop

* remove nonlocal, test=develop

* code clean, test=develop

* code clean, test=develop

* feed list, test=develop

* test adam, test=develop

* follow comments, test=develop

* reduce duplicate code, test=develop

* update comments, test=develop
Leo Chen 5 years ago committed by GitHub
parent 7c55a94de5
commit a62599a888
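For orientation before the diffs: the core change in prune.cc walks a block's ops in reverse, keeping an op only if it is a target or produces a variable that an already-kept op needs, and (new in this patch) never keeping an optimize op merely because of its outputs. A simplified, illustrative Python sketch of that backward walk (not the C++ implementation; the op representation is hypothetical):

def prune_ops(ops, target_ids, feed_var_names):
    """Reverse walk over ops; return the ones that should be kept.

    Each op is a dict: {'id', 'inputs', 'outputs', 'is_optimize'}.
    """
    dependent_vars = set()
    keep = []
    for op in reversed(ops):
        needed = (op['id'] in target_ids or
                  (any(v in dependent_vars for v in op['outputs']) and
                   not op['is_optimize']))
        if needed:
            # Inputs of a kept op become dependencies, except fed variables,
            # which are provided by the user and cut the dependency chain.
            dependent_vars.update(v for v in op['inputs']
                                  if v not in feed_var_names)
        keep.append(needed)
    keep.reverse()
    return [op for op, k in zip(ops, keep) if k]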

@@ -113,7 +113,6 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
}
int GetOpRole(const proto::OpDesc& op_desc) {
// The op role >= 0, so -1 is used to indicate "NotFound".
for (auto& attr : op_desc.attrs()) {
if (attr.name() == OpProtoAndCheckerMaker::OpRoleAttrName()) {
PADDLE_ENFORCE_EQ(
@@ -124,7 +123,10 @@ int GetOpRole(const proto::OpDesc& op_desc) {
return attr.i();
}
}
return -1;
// If the attr op_role is not found, the op may be an operator created in a
// C++ test, like prune_test.cc. In that case, the op_role should be the
// default value, which is kNotSpecified.
return static_cast<int>(OpRole::kNotSpecified);
}
void AppendOpInputVarNames(const proto::OpDesc& op_desc,
@@ -145,6 +147,16 @@ void AppendOpOutputVarNames(const proto::OpDesc& op_desc,
}
}
int FindMapByValue(const std::map<int, int>& m, int val) {
// The content in map should be >= 0, so -1 is used to indicate "NotFound".
for (auto& pair : m) {
if (pair.second == val) {
return pair.first;
}
}
return -1;
}
// block_id is the idx of the current block in the input desc
// parent_block_id is the idx of the parent of the current block
// in the output desc, -1 means the current block is global block
@@ -153,30 +165,41 @@ void AppendOpOutputVarNames(const proto::OpDesc& op_desc,
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
int block_id, int parent_block_id,
std::unordered_set<std::string>* dependent_vars,
const std::set<std::string> feed_var_names) {
const std::set<std::string> feed_var_names,
std::map<int, int>* pruned_origin_block_id_map) {
auto& block = input.blocks(block_id);
auto& ops = block.ops();
bool expect_feed = true;
for (auto& op_desc : ops) {
PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
"All FeedOps are at the beginning of the ProgramDesc");
PADDLE_ENFORCE_EQ(
op_desc.type() != kFeedOpType || expect_feed, true,
platform::errors::PreconditionNotMet(
"All FeedOps are at the beginning of the ProgramDesc"));
expect_feed = (op_desc.type() == kFeedOpType);
}
bool expect_fetch = true;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
"All FetchOps must at the end of the ProgramDesc");
PADDLE_ENFORCE_EQ(op_desc.type() != kFetchOpType || expect_fetch, true,
platform::errors::PreconditionNotMet(
"All FetchOps must at the end of the ProgramDesc"));
expect_fetch = (op_desc.type() == kFetchOpType);
}
std::vector<bool> should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
if (IsTarget(op_desc) || HasDependentOutputVar(op_desc, *dependent_vars)) {
// insert its input to the dependency graph
if (IsTarget(op_desc) ||
(HasDependentOutputVar(op_desc, *dependent_vars) &&
(GetOpRole(op_desc) & static_cast<int>(OpRole::kOptimize)) == 0)) {
// NOTE(zhiqiu): since an optimize op takes the trainable parameters as
// both inputs and outputs, it may introduce a wrong dependency graph.
// In train mode, the optimize ops should already be in targets, so it is
// neither needed nor correct to mark an optimize op by its outputs.
// In eval / infer mode, there is no optimize op in the program.
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
if (feed_var_names.count(argu) == 0) {
@@ -203,6 +226,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
output_block->set_idx(output_block_id);
output_block->set_parent_idx(parent_block_id);
(*pruned_origin_block_id_map)[output_block_id] = block_id;
auto* op_field = output_block->mutable_ops();
op_field->Clear();
for (size_t i = 0; i < should_run.size(); ++i) {
@@ -244,7 +269,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
&sub_block_dependent_vars, feed_var_names);
&sub_block_dependent_vars, feed_var_names,
pruned_origin_block_id_map);
}
}
}
@@ -284,22 +310,33 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
}
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output) {
std::map<int, int> Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output) {
std::unordered_set<std::string> dependent_vars;
output->clear_blocks();
prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names);
}
int FindMapByValue(const std::map<int, int>& m, int val) {
// The content in map should be >= 0, so -1 is used to indicate "NotFound".
for (auto& pair : m) {
if (pair.second == val) {
return pair.first;
std::map<int, int> pruned_origin_block_id_map;
prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names,
&pruned_origin_block_id_map);
// update subblock idx
for (int i = 0; i < output->blocks_size(); i++) {
auto* pruned = output->mutable_blocks(i);
auto* ops = pruned->mutable_ops();
for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) {
auto& op_desc = *op_iter;
if (HasSubBlock(op_desc)) {
int origin_sub_idx = GetSubBlockIndex(op_desc);
auto sub_idx =
FindMapByValue(pruned_origin_block_id_map, origin_sub_idx);
PADDLE_ENFORCE_NE(sub_idx, -1,
platform::errors::NotFound(
"The origin sub block id should be found in "
"pruned_progin_block_id_map"));
SetSubBlockIndex(&op_desc, sub_idx);
}
}
}
return -1;
return pruned_origin_block_id_map;
}
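The map returned above records which original block each pruned block came from, so sub_block attributes can be re-pointed. An illustrative Python analogue of FindMapByValue and the remapping step (the data is made up):

def find_key_by_value(pruned_to_origin, origin_idx):
    """Reverse lookup: which pruned block came from origin_idx? -1 if none."""
    for pruned_idx, origin in pruned_to_origin.items():
        if origin == origin_idx:
            return pruned_idx
    return -1

# Example: original blocks 0 and 2 survived pruning as blocks 0 and 1, so an
# op whose sub_block attribute pointed at original block 2 now points at 1.
pruned_to_origin = {0: 0, 1: 2}
assert find_key_by_value(pruned_to_origin, 2) == 1
assert find_key_by_value(pruned_to_origin, 1) == -1  # block 1 was pruned away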
void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
@@ -348,8 +385,8 @@ void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
var_names.insert(op_output_vars.begin(), op_output_vars.end());
for (const auto& name : var_names) {
if (var_map.count(name)) {
// NOTE(zhiqiu): For operator in a conditional block, the related vars may
// not exist in current block, but in its futher block.
// NOTE(zhiqiu): For an operator in a conditional block, the related vars
// may not exist in the current block, but in its parent block.
*pruned_vars->Add() = var_map[name];
}
}
@@ -389,6 +426,7 @@ std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
proto::ProgramDesc pruned_desc;
pruned_desc.clear_blocks();
// Step 2. Prune backward for each block.
for (size_t i = 0; i < origin_clone.Size(); i++) {
auto pruned = proto::BlockDesc();

@@ -26,9 +26,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
void Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output);
std::map<int, int> Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output);
std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
const framework::ProgramDesc& origin);

@@ -1154,8 +1154,10 @@ All parameter, weight, gradient are variables in Paddle.
prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
}
proto::ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
return new ProgramDesc(pruned_desc);
auto pruned_origin_block_id_map =
Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
return std::make_tuple(ProgramDesc(pruned_desc),
pruned_origin_block_id_map);
});
m.def("prune_backward",
[](const framework::ProgramDesc &program) {

File diff suppressed because it is too large.

@@ -2172,6 +2172,15 @@ class Operator(object):
return attr_map
def _is_optimize_op(self):
op_maker = core.op_proto_and_checker_maker
OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role = self.desc.attr(op_maker.kOpRoleAttrName())
if op_role & int(OPTIMIZE):
return True
else:
return False
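Since op_role is a bit mask, _is_optimize_op matches any op whose role includes Optimize, even when combined with other role bits. A small illustrative helper built on it (assumes a program has already been constructed and an optimizer's minimize() has run):

import paddle.fluid as fluid

def optimizer_op_types(program=None):
    """Return the types of ops whose role includes Optimize, e.g. ['sgd']."""
    program = program or fluid.default_main_program()
    return [op.type for op in program.global_block().ops
            if op._is_optimize_op()]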
class Block(object):
"""
@@ -2706,8 +2715,8 @@ class Block(object):
assert isinstance(p, Parameter)
v = self.vars.get(p.name, None)
if v is None:
raise ValueError("_copy_param_info_from should be invoked with "
"same topology")
# if the Parameter is pruned, v may be None
continue
assert isinstance(v, Variable)
new_p = None
if in_dygraph_mode():
@@ -4056,52 +4065,13 @@ class Program(object):
directly. This API is in flux and not stable.
Args:
targets(list|Variable|Operator): A list of variables or operators
targets(list|Variable|Operator|str): A list of variables, operators, or variable names
that the pruned program must retain
Returns:
Program: A new, pruned program.
"""
#NOTE(zhiqiu): we sync the original program first, since its program may diff with
# its desc due to modifying desc in c++ space. E.g. save op will add kLookupTablePath in desc.
self._sync_with_cpp()
if not isinstance(targets, list):
targets = [targets]
targets_idx = []
for t in targets:
if not isinstance(t, Operator):
if isinstance(t, Variable):
# After transpiler processing, the op that output this
# variable maybe has been changed, so t.op is not reliable
# and we need to find the current op that generate this
# variable here.
t.op = None
global_block = self.global_block()
for idx, op in enumerate(global_block.ops):
if t.name in op.output_arg_names:
t.op = op
break
t = t.op
if t is None:
raise ValueError(
"The target variable must have an "
"associated operator that generates it.")
else:
raise ValueError("All targets of prune() can only be "
"Variable or Operator.")
targets_idx.append([t.block.idx, t.idx])
res = Program()
res.desc = core.prune(self.desc, set(), targets_idx)
res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
res._sync_with_cpp()
return res
return self._prune_with_input([], targets)
def _prune_with_input(self, feeded_var_names, targets):
"""
@@ -4115,7 +4085,7 @@ class Program(object):
Args:
feeded_var_names(list|str): A list of variable names from where
pruning start. If it is set as [], this API works just like _prune()
targets(list|Variable|Operator): A list of variables or operators
targets(list|Variable|Operator|str): A list of variables, operators, or variable names
that the pruned program must retain
Returns:
@@ -4140,33 +4110,47 @@ class Program(object):
for t in targets:
if not isinstance(t, Operator):
if isinstance(t, Variable):
# After transpiler processing, the op that output this
# variable maybe has been changed, so t.op is not reliable
# and we need to find the current op that generate this
# variable here.
t.op = None
global_block = self.global_block()
for idx, op in enumerate(global_block.ops):
if t.name in op.output_arg_names:
t.op = op
break
t = t.op
if t is None:
raise ValueError(
"The target variable must have an "
"associated operator that generates it.")
name = t.name
elif isinstance(t, six.string_types):
name = str(t)
else:
raise ValueError("All targets of prune() can only be "
"Variable or Operator.")
# After transpiler processing, the op that outputs this
# variable may have been changed, so t.op is not reliable
# and we need to find the current op that generates this
# variable here.
target_op = None
global_block = self.global_block()
for idx, op in enumerate(global_block.ops):
if name in op.output_arg_names:
# NOTE(zhiqiu): Find the op that generates the target name.
# Skip optimize ops unless they are themselves in targets,
# since optimize ops also output parameters.
if op._is_optimize_op() and op not in targets:
continue
else:
target_op = op
break
t = target_op
if t is None:
raise ValueError("The target variable must have an "
"associated operator that generates it.")
targets_idx.append([t.block.idx, t.idx])
res = Program()
res.desc = core.prune(self.desc, set(feeded_var_names), targets_idx)
res.desc, pruned_origin_block_id_map = core.prune(self.desc,
set(feeded_var_names),
targets_idx)
res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
res._sync_with_cpp()
res._copy_param_info_from(self)
res._copy_data_info_from(self, pruned_origin_block_id_map)
res._copy_dist_param_info_from(self)
return res
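A hedged usage sketch of the updated _prune_with_input (a private API): targets may now be given as variable names, and the pruned program keeps parameter and data info copied from the original. The variable names below are illustrative:

import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog, fluid.Program()):
    x = fluid.data(name='x', shape=[None, 4], dtype='float32')
    hidden = fluid.layers.fc(input=x, size=8)
    out = fluid.layers.fc(input=hidden, size=1)

# Keep only what is needed to compute `hidden` when `x` is fed; the target
# can be passed as a Variable or simply by its name (a str).
pruned = prog._prune_with_input(feeded_var_names=['x'],
                                targets=[hidden.name])
print([op.type for op in pruned.global_block().ops])  # ops for `out` are gone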
def _inference_optimize(self, prune_read_op=True):

@@ -811,6 +811,9 @@ class Optimizer(object):
tuple: tuple (optimize_ops, params_grads), A list of operators appended
by minimize and a list of (param, grad) variable pairs, param is
``Parameter``, grad is the gradient value corresponding to the parameter.
The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
indicate program pruning. If so, the program will be pruned by ``feed`` and
``fetch_list`` before it runs; see details in ``Executor``.
Examples:
Please refer to the example of current Optimizer.
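An illustrative end-to-end run that follows the note above (assuming the 1.x fluid API at the time of this commit; the exact forms accepted in fetch_list are defined in the executor diff suppressed below, and the use_prune flag follows the commit notes, so both are assumptions here):

import numpy as np
import paddle.fluid as fluid

x = fluid.data(name='x', shape=[None, 13], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
optimize_ops, params_grads = fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

feed = {'x': np.random.rand(4, 13).astype('float32'),
        'y': np.random.rand(4, 1).astype('float32')}

# Training step: including the tuple returned by minimize() in fetch_list
# marks the optimize ops as targets, so pruning keeps them.
exe.run(fluid.default_main_program(), feed=feed,
        fetch_list=[loss, (optimize_ops, params_grads)])

# Evaluation-style step: only `loss` is fetched, so with pruning enabled the
# optimize ops (and anything else not needed for `loss`) are pruned away.
exe.run(fluid.default_main_program(), feed=feed,
        fetch_list=[loss], use_prune=True)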

File diff suppressed because it is too large.