@@ -37,6 +37,20 @@ constexpr char kInitStateGrads[] = "initial_states" GRAD_SUFFIX;
 
 using StepScopeVar = std::vector<framework::Scope *>;
 
+static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
+                            framework::Scope *parent_scope,
+                            StepScopeVar *step_scopes) {
+  if (step_scopes->empty()) return;
+
+  dev_ctx.Wait();
+
+  for (auto *sub_scope : *step_scopes) {
+    parent_scope->DeleteScope(sub_scope);
+  }
+
+  step_scopes->clear();
+}
+
 // StepScopes manages scopes inside RNN.
 // StepScopes::CurScope() get the current scope
 // StepScopes::ExScope() get the ex-scope, or scope in previous time step.
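Note: the new helper waits on the device context before deleting anything. Step scopes own the tensors that the step kernels read and write, and work already queued on the device (a CUDA stream, for instance) may still be using them, so the Wait() has to come first. A minimal usage sketch, assuming only the names in this patch plus an illustrative parent_scope on CPUPlace:

    // Illustrative sketch only, not part of the patch.
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto &dev_ctx = *pool.Get(platform::CPUPlace());
    framework::Scope parent_scope;
    StepScopeVar step_scopes;
    step_scopes.emplace_back(&parent_scope.NewScope());     // one scope per time step
    ClearStepScopes(dev_ctx, &parent_scope, &step_scopes);  // wait, delete, clear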
@@ -53,7 +67,8 @@ using StepScopeVar = std::vector<framework::Scope *>;
 // access scopes from begin to end.
 class StepScopes {
  public:
-  StepScopes(const framework::Scope &parent, StepScopeVar *scopes,
+  StepScopes(const platform::DeviceContext &dev_ctx,
+             const framework::Scope &parent, StepScopeVar *scopes,
              bool is_train, size_t seq_len, bool is_backward = false)
       : counter_(is_backward ? seq_len - 1 : 0UL),
         scopes_(scopes),
@@ -63,7 +78,7 @@ class StepScopes {
     PADDLE_ENFORCE(is_train || !is_backward,
                    "Cannot backward when is not training");
     if (!is_backward_) {
-      PADDLE_ENFORCE(scopes->empty());
+      ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
       scopes->reserve(static_cast<size_t>(num_step_scopes));
       for (size_t i = 0; i < num_step_scopes; ++i) {
         scopes->emplace_back(&parent.NewScope());
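Note: this hunk changes behavior, not just plumbing. The forward-direction constructor used to assert that the scope vector was empty; it now clears whatever an earlier run left behind before reserving fresh scopes, which is what keeps repeated forward runs (for example, inference, where the backward op never consumes the scopes) from accumulating them. A hedged sketch of the two construction modes implied by the signature; the variable names here are illustrative:

    // Illustrative sketch only, not part of the patch.
    StepScopes forward(dev_ctx, parent, &scopes, /*is_train=*/true, seq_len);
    StepScopes backward(dev_ctx, parent, &scopes, /*is_train=*/true, seq_len,
                        /*is_backward=*/true);  // counter_ starts at seq_len - 1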
@@ -244,14 +259,15 @@ class RecurrentOp : public RecurrentBase {
                const platform::Place &place) const override {
     bool has_state = Attr<bool>(kHasStates);
     auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
-    VLOG(3) << "Static RNN input sequence length = " << seq_len;
-    StepScopes scopes = CreateStepScopes(scope, seq_len);
-    auto reverse = Attr<bool>(kReverse);
+
+    // get device context from pool
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(place);
+
+    VLOG(3) << "Static RNN input sequence length = " << seq_len;
+    StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
+    auto reverse = Attr<bool>(kReverse);
 
     framework::Executor executor(place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
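Note: in both RecurrentOp::RunImpl and RecurrentGradOp::RunImpl (below), the DeviceContextPool lookup is hoisted above the CreateStepScopes call, because CreateStepScopes now needs dev_ctx to hand to the StepScopes constructor, which in turn passes it to ClearStepScopes. The lookup itself is unchanged; only its position moves.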
@@ -316,11 +332,12 @@ class RecurrentOp : public RecurrentBase {
   }
 
  private:
-  StepScopes CreateStepScopes(const framework::Scope &scope,
+  StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx,
+                              const framework::Scope &scope,
                               size_t seq_len) const {
     auto *var = scope.FindVar(Output(kStepScopes));
     PADDLE_ENFORCE(var != nullptr);
-    return StepScopes(scope, var->GetMutable<StepScopeVar>(),
+    return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
                       Attr<bool>(kIsTrain), seq_len);
   }
 };
@@ -338,17 +355,18 @@ class RecurrentGradOp : public RecurrentBase {
                const platform::Place &place) const override {
     bool has_state = Attr<bool>(kHasStates);
     const size_t seq_len = static_cast<size_t>(GetSequenceLength(scope));
-    StepScopes scopes = CreateStepScopes(scope, seq_len);
+
+    // get device context from pool
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(place);
+
+    StepScopes scopes = CreateStepScopes(dev_ctx, scope, seq_len);
     auto reverse = Attr<bool>(kReverse);
 
     framework::Executor executor(place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
     auto *program = block->Program();
 
-    // get device context from pool
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-
     for (size_t step_id = 0; step_id < seq_len; ++step_id) {
       size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
       VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
@@ -501,22 +519,20 @@ class RecurrentGradOp : public RecurrentBase {
       scopes.Next();
     }
     // Delete the scope of StepScopes
-    dev_ctx.Wait();
     auto *var = scope.FindVar(Input(kStepScopes));
     PADDLE_ENFORCE(var != nullptr);
-    auto step_scopes = var->GetMutable<StepScopeVar>();
-    for (auto *sub_scope : *step_scopes) {
-      const_cast<framework::Scope &>(scope).DeleteScope(sub_scope);
-    }
-    step_scopes->clear();
+    auto *step_scopes = var->GetMutable<StepScopeVar>();
+    ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope),
+                    step_scopes);
   }
 
  private:
-  StepScopes CreateStepScopes(const framework::Scope &scope,
+  StepScopes CreateStepScopes(const platform::DeviceContext &dev_ctx,
+                              const framework::Scope &scope,
                               size_t seq_len) const {
     auto *var = scope.FindVar(Input(kStepScopes));
     PADDLE_ENFORCE(var != nullptr);
-    return StepScopes(scope, var->GetMutable<StepScopeVar>(),
+    return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
                       Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
   }
 };
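Note: the backward cleanup no longer open-codes the loop (explicit dev_ctx.Wait(), per-scope DeleteScope, then clear()); it delegates to ClearStepScopes, so the wait-before-delete ordering is enforced in one place for both call sites, the StepScopes constructor and this backward path.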