// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/controlflow/while_op_helper.h"
|
|
#include <string>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
#include "paddle/fluid/framework/program_desc.h"
|
|
|
|
namespace paddle {
namespace operators {

// OpVariant is a wrapper class of OpDesc and OperatorBase, so that
// both can be accessed through the same API.
class OpVariant {
  struct InputsVisitor
      : public boost::static_visitor<const framework::VariableNameMap *> {
    template <typename OpType>
    const framework::VariableNameMap *operator()(const OpType *op) const {
      return &(op->Inputs());
    }
  };

  struct OutputsVisitor
      : public boost::static_visitor<const framework::VariableNameMap *> {
    template <typename OpType>
    const framework::VariableNameMap *operator()(const OpType *op) const {
      return &(op->Outputs());
    }
  };

  struct AttributeMapVisitor
      : public boost::static_visitor<const framework::AttributeMap *> {
    const framework::AttributeMap *operator()(
        const framework::OpDesc *op) const {
      return &(op->GetAttrMap());
    }

    const framework::AttributeMap *operator()(
        const framework::OperatorBase *op) const {
      return &(op->Attrs());
    }
  };

  struct RawPointerVisitor : public boost::static_visitor<const void *> {
    template <typename OpType>
    const void *operator()(const OpType *op) const {
      return op;
    }
  };

 public:
  OpVariant(const framework::OperatorBase *op) : op_(op) {}  // NOLINT

  OpVariant(const framework::OpDesc *op) : op_(op) {}  // NOLINT

  const framework::VariableNameMap &Inputs() const {
    return *boost::apply_visitor(InputsVisitor(), op_);
  }

  const framework::VariableNameMap &Outputs() const {
    return *boost::apply_visitor(OutputsVisitor(), op_);
  }

  const framework::AttributeMap &Attrs() const {
    return *boost::apply_visitor(AttributeMapVisitor(), op_);
  }

  template <typename AttrType>
  const AttrType &Attr(const std::string &name) const {
    auto &attrs = Attrs();
    auto it = attrs.find(name);
    PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
    return boost::get<AttrType>(it->second);
  }

  bool operator==(const OpVariant &other) const {
    return RawPointer() == other.RawPointer();
  }

  const void *RawPointer() const {
    return boost::apply_visitor(RawPointerVisitor(), op_);
  }

  int which() const { return static_cast<int>(op_.which()); }

  struct Hasher {
    size_t operator()(const OpVariant &op) const {
      return reinterpret_cast<size_t>(op.RawPointer());
    }
  };

 private:
  const boost::variant<const framework::OperatorBase *,
                       const framework::OpDesc *>
      op_;
};

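// A minimal usage sketch of OpVariant (illustrative only; `op_desc` and
// `op_base` are hypothetical pointers, not defined in this file):
//
//   OpVariant v1(op_desc);  // wraps a framework::OpDesc *
//   OpVariant v2(op_base);  // wraps a framework::OperatorBase *
//   const auto &ins = v1.Inputs();  // same call works for either wrapper
//   auto *block = v2.Attr<framework::BlockDesc *>(kStepBlock);
//
// OpVariant::Hasher hashes by the address of the wrapped operator, so
// OpVariant can serve as the key type of an std::unordered_set, as done
// below in PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl.
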
static std::string GetDebugString(const std::vector<std::string> &names) {
  if (names.empty()) return "";
  std::string ret = names[0];
  for (size_t i = 1; i < names.size(); ++i) {
    ret += (" " + names[i]);
  }
  return ret;
}

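// For example, GetDebugString({"x", "y", "z"}) returns "x y z".
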
// Set the skip variables of while_op and while_grad_op.
// These variables should be skipped when eager deletion is enabled,
// because:
// 1. while_grad_op needs some variables defined in while_op.
// 2. while_grad_op needs variables from the previous time step.
static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
  auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
  VLOG(2) << "Prepare to skip " << attr.size()
          << " var(s): " << GetDebugString(attr);
  attrs[kSkipEagerDeletionVars] = std::move(attr);
}

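// Note: SetSkipVars writes the kSkipEagerDeletionVars attribute back into
// the wrapped op in place; Attrs() returns a const reference, hence the
// const_cast above.
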
// Check whether the forward while_op and the while_grad_op match,
// since the program may contain many while_ops.
static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
                                           const OpVariant &grad_op) {
  return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
         fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
}

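// A while_grad_op takes the forward op's kOutputs as one of its inputs, so
// a pair is matched by comparing the forward op's kX input and kOutputs
// output name lists against the grad op's corresponding input lists.
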
// Test whether a variable is skippable in the forward while_op.
// A variable is skippable in while_op when it is used in while_grad
// but does not come from grad_block.
static bool IsSkippableVar(const std::string &name,
                           framework::BlockDesc *grad_block) {
  return name != framework::kEmptyVarName && !grad_block->HasVar(name);
}

static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
                                            const OpVariant &bwd_op) {
  auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);

  // Find all skippable variables in the forward while_op
  std::unordered_set<std::string> forward_skip_vars;
  for (auto *op_desc : grad_block->AllOps()) {
    for (auto &in_arg_name : op_desc->InputArgumentNames()) {
      if (IsSkippableVar(in_arg_name, grad_block)) {
        forward_skip_vars.insert(in_arg_name);
      }
    }

    for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
      if (IsSkippableVar(out_arg_name, grad_block)) {
        forward_skip_vars.insert(out_arg_name);
      }
    }
  }

  SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
                                               forward_skip_vars.end()));

  // Find all skippable variables in while_grad_op.
  // The skipped variables are those that would be used across time steps.
  auto &fwd_input = fwd_op.Inputs().at(kX);
  auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
  PADDLE_ENFORCE_EQ(
      fwd_input.size(), in_grads.size(),
      "Backward input gradient number does not match forward input number.");

  std::unordered_set<std::string> backward_skip_vars;
  for (size_t i = 0; i < in_grads.size(); ++i) {
    if (in_grads[i] == framework::kEmptyVarName) {
      continue;
    }
    backward_skip_vars.insert(in_grads[i]);
    backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
  }

  SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
                                               backward_skip_vars.end()));
}

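// In short: the forward skip set contains every name that grad_block
// references but does not define itself (those values must survive the
// forward pass), while the backward skip set contains the input gradients
// plus GradVarName(x) for each forward input x (those may be needed across
// time steps).
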
// Find all while_ops and while_grad_ops in the graph or program.
// A while_grad_op and its while_op may be located in different blocks,
// so we should traverse all blocks of the program to find them.
static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
                                       std::vector<OpVariant> *while_grad_ops) {
  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());

  if (while_ops->empty()) return;

  const auto *program =
      while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
  for (size_t i = 1; i < program->Size(); ++i) {
    auto &block = program->Block(i);
    for (size_t j = 0; j < block.OpSize(); ++j) {
      auto *op = block.Op(j);
      if (op->Type() == "while") {
        while_ops->emplace_back(op);
      } else if (op->Type() == "while_grad") {
        while_grad_ops->emplace_back(op);
      }
    }
  }

  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
                    "There are extra while_grad ops in the graph or program");
}

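// Note: the traversal above starts at block 1, presumably because the seed
// ops handed in through *while_ops and *while_grad_ops by the callers below
// already come from the root block (block 0).
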
static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
    std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
  FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);

  VLOG(2) << "Found while op num: " << while_ops->size()
          << ", while grad op num: " << while_grad_ops->size();

  if (while_grad_ops->empty()) {
    return;
  }

  std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
      while_ops->begin(), while_ops->end());

  for (auto &bwd_op : *while_grad_ops) {
    const OpVariant *matched_fwd_op = nullptr;
    for (auto &fwd_op : while_op_set) {
      if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
        PADDLE_ENFORCE(matched_fwd_op == nullptr,
                       "Found multiple matched while ops");
        matched_fwd_op = &fwd_op;
      }
    }
    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
                            "Cannot find matched forward while op.");
    ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
    while_op_set.erase(*matched_fwd_op);
  }
}

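// Each while_grad_op is matched against the set of remaining forward
// while_ops; the matched forward op is then erased from the set, so no
// forward op can be paired with two different grad ops.
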
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
    int block_id,
    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
  // If block_id is not 0, return directly.
  // This is because all while_ops and while_grad_ops in the whole program
  // are processed when block_id is 0 (i.e., when Executor::Run() is called
  // or the ParallelExecutor is constructed).

  // Moreover, all while_ops and while_grad_ops must be processed when
  // block_id is zero. Otherwise, a while_op may run first and erase
  // variables used by a while_grad_op at a moment when the while_grad_ops
  // may not be constructed yet.
  if (block_id != 0) return;

  std::vector<OpVariant> fwd_ops, bwd_ops;
  for (auto &op : all_ops) {
    if (op->Type() == "while") {
      fwd_ops.emplace_back(op.get());
    } else if (op->Type() == "while_grad") {
      bwd_ops.emplace_back(op.get());
    }
  }
  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}

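// The overload below takes pre-collected raw operator pointers instead of
// scanning a block's op list; it appears intended as a convenience entry
// point for callers that have already gathered the while and while_grad ops
// themselves.
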
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
    const std::vector<framework::OperatorBase *> &while_ops,
    const std::vector<framework::OperatorBase *> &while_grad_ops) {
  std::vector<OpVariant> fwd_ops, bwd_ops;
  fwd_ops.reserve(while_ops.size());
  for (auto *op : while_ops) {
    fwd_ops.emplace_back(op);
  }

  bwd_ops.reserve(while_grad_ops.size());
  for (auto *op : while_grad_ops) {
    bwd_ops.emplace_back(op);
  }

  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}

}  // namespace operators
}  // namespace paddle