@@ -54,20 +54,6 @@ static void ClearStepScopes(const platform::DeviceContext &dev_ctx,
   step_scopes->clear();
 }
 
-// StepScopes manages scopes inside RNN.
-//    StepScopes::CurScope() get the current scope
-//    StepScopes::ExScope() get the ex-scope, or scope in previous time step.
-//    StepScopes::Next() move to next time step.
-//
-// if is_train = False, then
-//   there are two scopes for the RNN and just support forward.
-// else
-//   the len(scopes) == seq_len
-//
-// if is_backward = True, then
-//   reversely access scopes
-// else
-//   access scopes from begin to end.
 StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
                        const framework::Scope &parent, StepScopeVar *scopes,
                        bool is_train, size_t seq_len, bool is_backward)
@@ -76,8 +62,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
       is_train_(is_train),
       is_backward_(is_backward) {
   size_t num_step_scopes = is_train ? seq_len : 2;
-  PADDLE_ENFORCE(is_train || !is_backward,
-                 "Cannot backward when is not training");
+  PADDLE_ENFORCE_EQ(is_train || !is_backward, true,
+                    "Cannot backward when is not training");
   if (!is_backward_) {
     ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
     scopes->reserve(static_cast<size_t>(num_step_scopes));
@@ -94,12 +80,22 @@ framework::Scope &StepScopes::ExScope() {
   return scope;
 }
 
-void StepScopes::Next() {
-  if (is_backward_) {
-    --counter_;
-  } else {
-    ++counter_;
-  }
-}
+void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
+                              framework::Scope *parent_scope) {
+  PADDLE_ENFORCE_EQ(is_backward_, true,
+                    "Cannot get backward next scope when is forward");
+  if (counter_ + 2 == scopes_->size()) {
+    parent_scope->DeleteScope((*scopes_)[counter_ + 1]);
+    scopes_->pop_back();
+    VLOG(3) << "Deleted scope at " << counter_ + 1;
+  }
+  --counter_;
+}
+
+void StepScopes::ForwardNext() {
+  PADDLE_ENFORCE_EQ(is_backward_, false,
+                    "Cannot get forward next scope when is backward");
+  ++counter_;
+}
 
 framework::Scope &StepScopes::GetScope(size_t scope_id) const {
@@ -125,11 +121,11 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
   // Dim format SEQ_LEN, BATCH_SIZE, ...
   int64_t seq_len = -1;
   auto &all_inputs = Inputs(kInputs);
-  PADDLE_ENFORCE(!all_inputs.empty());
+  PADDLE_ENFORCE_EQ(!all_inputs.empty(), true);
   for (auto &iname : all_inputs) {
     auto *var = scope.FindVar(iname);
-    PADDLE_ENFORCE(var != nullptr);
-    PADDLE_ENFORCE(var->IsType<framework::LoDTensor>());
+    PADDLE_ENFORCE_NOT_NULL(var);
+    PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true);
     auto &dim = var->Get<framework::LoDTensor>().dims();
     if (seq_len == -1) {
       seq_len = dim[0];
@@ -254,7 +250,7 @@ void RecurrentOp::RunImpl(const framework::Scope &scope,
       });
     }
 
-    scopes.Next();
+    scopes.ForwardNext();
   }
 }
 
@@ -262,7 +258,7 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
                                          const framework::Scope &scope,
                                          size_t seq_len) const {
   auto *var = scope.FindVar(Output(kStepScopes));
-  PADDLE_ENFORCE(var != nullptr);
+  PADDLE_ENFORCE_NOT_NULL(var);
   return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
                     Attr<bool>(kIsTrain), seq_len);
 }
@@ -459,11 +455,11 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
         VLOG(5) << "Link initialize state gradient finished ";
       }
     }
-    scopes.Next();
+    scopes.BackwardNext(dev_ctx, const_cast<framework::Scope *>(&scope));
   }
   // Delete the scope of StepScopes
   auto *var = scope.FindVar(Input(kStepScopes));
-  PADDLE_ENFORCE(var != nullptr);
+  PADDLE_ENFORCE_NOT_NULL(var);
   auto *step_scopes = var->GetMutable<StepScopeVar>();
   ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes);
 }
@@ -472,7 +468,7 @@ StepScopes RecurrentGradOp::CreateStepScopes(
     const platform::DeviceContext &dev_ctx, const framework::Scope &scope,
     size_t seq_len) const {
   auto *var = scope.FindVar(Input(kStepScopes));
-  PADDLE_ENFORCE(var != nullptr);
+  PADDLE_ENFORCE_NOT_NULL(var);
   return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
                     Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
 }
@@ -491,6 +487,7 @@ std::unordered_set<std::string> RecurrentGradOp::LocalVarNames(
     const framework::Scope &scope) const {
   return this->List2Set(scope.LocalVarNames());
 }
+
 std::vector<std::string> RecurrentGradOp::GradVarLists(
     const std::vector<std::string> &var_names) {
   std::vector<std::string> retv;
@@ -627,25 +624,25 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
                         0, "The Attr(%s) should be empty.", RecurrentBase::kStates);
     }
 
-    PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kInputs),
-                   "The input(%s) should not be empty.",
-                   RecurrentBase::kInputs);
-    PADDLE_ENFORCE(ctx->HasInputs(RecurrentBase::kOutputs),
-                   "The input(%s) should not be empty.",
-                   RecurrentBase::kOutputs);
+    PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kInputs), true,
+                      "The input(%s) should not be empty.",
+                      RecurrentBase::kInputs);
+    PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kOutputs), true,
+                      "The input(%s) should not be empty.",
+                      RecurrentBase::kOutputs);
 
     // In some case the kInitialStates is empty.
     if (ctx->HasInputs(RecurrentBase::kInitialStates)) {
-      PADDLE_ENFORCE(ctx->HasOutputs(
-                         framework::GradVarName(RecurrentBase::kInitialStates)),
-                     "The output of(%s) should not be empty.",
-                     framework::GradVarName(RecurrentBase::kInitialStates));
+      PADDLE_ENFORCE_EQ(ctx->HasOutputs(framework::GradVarName(
+                            RecurrentBase::kInitialStates)),
+                        true, "The output of(%s) should not be empty.",
+                        framework::GradVarName(RecurrentBase::kInitialStates));
       ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInitialStates),
                          ctx->GetInputsDim(RecurrentBase::kInitialStates));
     }
 
-    PADDLE_ENFORCE(
-        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)),
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true,
         "The output of(%s) should not be empty.",
         framework::GradVarName(RecurrentBase::kInputs));
     ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs),
@@ -653,9 +650,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
 
     // In some case the kParameters is empty.
     if (ctx->HasInputs(RecurrentBase::kParameters)) {
-      PADDLE_ENFORCE(
+      PADDLE_ENFORCE_EQ(
           ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)),
-          "The output of(%s) should not be empty.",
+          true, "The output of(%s) should not be empty.",
          framework::GradVarName(RecurrentBase::kParameters));
       ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters),
                          ctx->GetInputsDim(RecurrentBase::kParameters));
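
Below is a minimal, self-contained sketch of the scope-recycling contract this patch introduces. It is not Paddle code: ToyStepScopes, its members, and the int stand-ins for framework::Scope* are hypothetical. It only models why BackwardNext frees the scope of step counter_ + 1 once step counter_'s gradient is done, which is what the `counter_ + 2 == scopes_->size()` check in the diff expresses.

#include <cassert>
#include <cstdio>
#include <vector>

// Toy stand-in for StepScopes (hypothetical); ints model framework::Scope*.
struct ToyStepScopes {
  std::vector<int> scopes;
  long counter;  // the real class uses an unsigned counter that is simply
                 // discarded after the last backward step

  explicit ToyStepScopes(size_t seq_len)
      : counter(static_cast<long>(seq_len) - 1) {
    for (size_t i = 0; i < seq_len; ++i) scopes.push_back(static_cast<int>(i));
  }

  int CurScope() const { return scopes[static_cast<size_t>(counter)]; }

  // Mirrors StepScopes::BackwardNext: once the gradient of step `counter`
  // is computed, no remaining step can read the scope of step `counter + 1`,
  // so it is freed; `counter + 2 == scopes.size()` is exactly that condition.
  void BackwardNext() {
    if (static_cast<size_t>(counter) + 2 == scopes.size()) {
      std::printf("freeing scope of step %zu\n", scopes.size() - 1);
      scopes.pop_back();
    }
    --counter;
  }
};

int main() {
  const size_t seq_len = 4;
  ToyStepScopes s(seq_len);
  for (size_t step = 0; step < seq_len; ++step) {
    std::printf("backward step uses scope %d\n", s.CurScope());
    s.BackwardNext();
  }
  // Only the scope of step 0 is left; in the patch, RecurrentGradOp::RunImpl
  // then calls ClearStepScopes to release it.
  assert(s.scopes.size() == 1);
  return 0;
}

Running this frees scopes 3, 2, and 1 as the backward walk passes them, so peak scope count shrinks with each step instead of holding all seq_len step scopes until the end, which is the memory behavior the diff is after.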