|
|
|
@ -272,6 +272,9 @@ class RecurrentOp : public RecurrentBase {
|
|
|
|
|
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
|
|
|
|
|
|
|
|
|
|
auto *program = block->Program();
|
|
|
|
|
auto ctx = executor.Prepare(
|
|
|
|
|
*program, block->ID(), std::vector<std::string>() /*skip_ref_cnt_vars*/,
|
|
|
|
|
true /*force_disable_gc*/);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < seq_len; ++i) {
|
|
|
|
|
size_t seq_offset = reverse ? seq_len - i - 1 : i;
|
|
|
|
@ -305,10 +308,9 @@ class RecurrentOp : public RecurrentBase {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Every inputs are linked now, execute!
|
|
|
|
|
executor.Run(*program, &cur_scope, block->ID(),
|
|
|
|
|
false /*create_local_scope*/, true /*create_vars*/,
|
|
|
|
|
std::vector<std::string>() /*skip_ref_cnt_vars*/,
|
|
|
|
|
true /*force_disable_gc*/);
|
|
|
|
|
executor.RunPreparedContext(ctx.get(), &cur_scope,
|
|
|
|
|
false /*create_local_scope*/,
|
|
|
|
|
true /*create_vars*/, true /* keep_kids */);
|
|
|
|
|
|
|
|
|
|
// Copy inside::output -> outside::output
|
|
|
|
|
// outside::output[seq_offset: seq_offset + 1] = inside::output
|
|
|
|
@ -366,6 +368,9 @@ class RecurrentGradOp : public RecurrentBase {
|
|
|
|
|
framework::Executor executor(place);
|
|
|
|
|
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
|
|
|
|
|
auto *program = block->Program();
|
|
|
|
|
auto ctx = executor.Prepare(
|
|
|
|
|
*program, block->ID(), std::vector<std::string>() /*skip_ref_cnt_vars*/,
|
|
|
|
|
true /*force_disable_gc*/);
|
|
|
|
|
|
|
|
|
|
for (size_t step_id = 0; step_id < seq_len; ++step_id) {
|
|
|
|
|
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
|
|
|
|
@ -423,10 +428,9 @@ class RecurrentGradOp : public RecurrentBase {
|
|
|
|
|
|
|
|
|
|
VLOG(5) << "Recurrent memory linking finished ";
|
|
|
|
|
// Run step block with cur_scope
|
|
|
|
|
executor.Run(*program, &cur_scope, block->ID(),
|
|
|
|
|
false /*create_local_scope*/, true /*create_vars*/,
|
|
|
|
|
std::vector<std::string>() /*skip_ref_cnt_vars*/,
|
|
|
|
|
true /*force_disable_gc*/);
|
|
|
|
|
executor.RunPreparedContext(ctx.get(), &cur_scope,
|
|
|
|
|
false /*create_local_scope*/,
|
|
|
|
|
true /*create_vars*/, true /* keep_kids */);
|
|
|
|
|
|
|
|
|
|
VLOG(5) << "executor.Run finished ";
|
|
|
|
|
|
|
|
|
|