|
|
|
@ -100,16 +100,12 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
|
|
|
|
|
// Find all while_ops and while_grad_ops in the graph or program
|
|
|
|
|
// The while_grad_op and while_op may located in different blocks
|
|
|
|
|
// So we should traverse all blocks in the program and find them out.
|
|
|
|
|
static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
|
|
|
|
|
static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program,
|
|
|
|
|
std::vector<OpVariant> *while_ops,
|
|
|
|
|
std::vector<OpVariant> *while_grad_ops) {
|
|
|
|
|
PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
|
|
|
|
|
|
|
|
|
|
if (while_ops->empty()) return;
|
|
|
|
|
|
|
|
|
|
const auto *program =
|
|
|
|
|
while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
|
|
|
|
|
for (size_t i = 1; i < program->Size(); ++i) {
|
|
|
|
|
auto &block = program->Block(i);
|
|
|
|
|
for (size_t i = 1; i < program.Size(); ++i) {
|
|
|
|
|
auto &block = program.Block(i);
|
|
|
|
|
for (size_t j = 0; j < block.OpSize(); ++j) {
|
|
|
|
|
auto *op = block.Op(j);
|
|
|
|
|
if (op->Type() == "while") {
|
|
|
|
@ -125,8 +121,9 @@ static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
|
|
|
|
|
std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
|
|
|
|
|
FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
|
|
|
|
|
const framework::ProgramDesc &program, std::vector<OpVariant> *while_ops,
|
|
|
|
|
std::vector<OpVariant> *while_grad_ops) {
|
|
|
|
|
FindAllWhileAndWhileGradOp(program, while_ops, while_grad_ops);
|
|
|
|
|
|
|
|
|
|
VLOG(2) << "Found while op num: " << while_ops->size()
|
|
|
|
|
<< ", while grad op num: " << while_grad_ops->size();
|
|
|
|
@ -155,7 +152,7 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
|
|
|
|
|
int block_id,
|
|
|
|
|
const framework::ProgramDesc &program, int block_id,
|
|
|
|
|
const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
|
|
|
|
|
// If block_id is not 0, returns
|
|
|
|
|
// This is because all while_ops and while_grad_ops in the whole program
|
|
|
|
@ -176,10 +173,12 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
|
|
|
|
|
bwd_ops.emplace_back(op.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
|
|
|
|
|
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops,
|
|
|
|
|
&bwd_ops);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
|
|
|
|
|
const framework::ProgramDesc &program,
|
|
|
|
|
const std::vector<framework::OperatorBase *> &while_ops,
|
|
|
|
|
const std::vector<framework::OperatorBase *> &while_grad_ops) {
|
|
|
|
|
std::vector<OpVariant> fwd_ops, bwd_ops;
|
|
|
|
@ -193,7 +192,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
|
|
|
|
|
bwd_ops.emplace_back(op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
|
|
|
|
|
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops,
|
|
|
|
|
&bwd_ops);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|