|
|
|
|
@ -65,7 +65,8 @@ StepScopes::StepScopes(const platform::DeviceContext &dev_ctx,
|
|
|
|
|
is_backward_(is_backward) {
|
|
|
|
|
size_t num_step_scopes = is_train ? seq_len : 2;
|
|
|
|
|
PADDLE_ENFORCE_EQ(is_train || !is_backward, true,
|
|
|
|
|
"Cannot backward when is not training");
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Cannot backward when is not training"));
|
|
|
|
|
if (!is_backward_) {
|
|
|
|
|
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&parent), scopes);
|
|
|
|
|
scopes->reserve(static_cast<size_t>(num_step_scopes));
|
|
|
|
|
@ -85,7 +86,8 @@ framework::Scope &StepScopes::ExScope() {
|
|
|
|
|
void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
|
|
|
|
|
framework::Scope *parent_scope) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(is_backward_, true,
|
|
|
|
|
"Cannot get backward next scope when is forward");
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Cannot get backward next scope when is forward"));
|
|
|
|
|
if (counter_ + 2 == scopes_->size()) {
|
|
|
|
|
parent_scope->DeleteScope((*scopes_)[counter_ + 1]);
|
|
|
|
|
scopes_->pop_back();
|
|
|
|
|
@ -96,7 +98,8 @@ void StepScopes::BackwardNext(const platform::DeviceContext &dev_ctx,
|
|
|
|
|
|
|
|
|
|
void StepScopes::ForwardNext() {
|
|
|
|
|
PADDLE_ENFORCE_EQ(is_backward_, false,
|
|
|
|
|
"Cannot get forward next scope when is backward");
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Cannot get forward next scope when is backward"));
|
|
|
|
|
++counter_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -104,7 +107,10 @@ framework::Scope &StepScopes::GetScope(size_t scope_id) const {
|
|
|
|
|
if (!is_train_) {
|
|
|
|
|
scope_id %= 2;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_LT(scope_id, scopes_->size());
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
scope_id, scopes_->size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input scope_id is greater than scopes size in RecurrentOp"));
|
|
|
|
|
return *(*scopes_)[scope_id];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -123,18 +129,33 @@ int64_t RecurrentBase::GetSequenceLength(const framework::Scope &scope) const {
|
|
|
|
|
// Dim format SEQ_LEN, BATCH_SIZE, ...
|
|
|
|
|
int64_t seq_len = -1;
|
|
|
|
|
auto &all_inputs = Inputs(kInputs);
|
|
|
|
|
PADDLE_ENFORCE_EQ(all_inputs.empty(), false);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
all_inputs.empty(), false,
|
|
|
|
|
platform::errors::InvalidArgument("RecurrentOp gets empty input"));
|
|
|
|
|
for (auto &iname : all_inputs) {
|
|
|
|
|
auto *var = scope.FindVar(iname);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"RecurrentOp finds var %s is NULL", iname));
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"RecurrentOp only accepts LoDTensor as input but "
|
|
|
|
|
"input var %s is not LoDTensor",
|
|
|
|
|
iname));
|
|
|
|
|
auto &dim = var->Get<framework::LoDTensor>().dims();
|
|
|
|
|
if (seq_len == -1) {
|
|
|
|
|
seq_len = dim[0];
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(seq_len, dim[0]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(seq_len, dim[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Sequence length of input %s in RecurrentOp is NOT "
|
|
|
|
|
"equal to sequence length of previous input",
|
|
|
|
|
iname));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_GE(seq_len, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"RecurrentOp gets invalid sequence length."));
|
|
|
|
|
return seq_len;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -260,7 +281,8 @@ StepScopes RecurrentOp::CreateStepScopes(const platform::DeviceContext &dev_ctx,
|
|
|
|
|
const framework::Scope &scope,
|
|
|
|
|
size_t seq_len) const {
|
|
|
|
|
auto *var = scope.FindVar(Output(kStepScopes));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::InvalidArgument(
|
|
|
|
|
"RecurrentOp gets empty StepScopes var"));
|
|
|
|
|
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
|
|
|
|
|
Attr<bool>(kIsTrain), seq_len);
|
|
|
|
|
}
|
|
|
|
|
@ -328,7 +350,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
|
|
|
|
|
auto cur_state_grads =
|
|
|
|
|
GradVarLists(Attr<std::vector<std::string>>(kStates));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(ex_state_grads.size(), cur_state_grads.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"lengths of ex_states and cur_states are not "
|
|
|
|
|
"equal in RecurrentGradOp"));
|
|
|
|
|
for (size_t i = 0; i < ex_state_grads.size(); ++i) {
|
|
|
|
|
auto &cur_grad = cur_state_grads[i];
|
|
|
|
|
auto &ex_grad = ex_state_grads[i];
|
|
|
|
|
@ -380,7 +405,10 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
|
|
|
|
|
{
|
|
|
|
|
auto &pg_names = Outputs(kParamGrads);
|
|
|
|
|
auto &p_names = Inputs(kParameters);
|
|
|
|
|
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Sizes of Parameters and ParamGrads are not equal "
|
|
|
|
|
"in RecurrentGradOp"));
|
|
|
|
|
|
|
|
|
|
for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
|
|
|
|
|
auto inside_grad_name = framework::GradVarName(p_names[param_id]);
|
|
|
|
|
@ -461,7 +489,9 @@ void RecurrentGradOp::RunImpl(const framework::Scope &scope,
|
|
|
|
|
}
|
|
|
|
|
// Delete the scope of StepScopes
|
|
|
|
|
auto *var = scope.FindVar(Input(kStepScopes));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"StepScopes var is empty in RecurrentGradOp"));
|
|
|
|
|
auto *step_scopes = var->GetMutable<StepScopeVar>();
|
|
|
|
|
ClearStepScopes(dev_ctx, const_cast<framework::Scope *>(&scope), step_scopes);
|
|
|
|
|
}
|
|
|
|
|
@ -470,7 +500,9 @@ StepScopes RecurrentGradOp::CreateStepScopes(
|
|
|
|
|
const platform::DeviceContext &dev_ctx, const framework::Scope &scope,
|
|
|
|
|
size_t seq_len) const {
|
|
|
|
|
auto *var = scope.FindVar(Input(kStepScopes));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"StepScopes var is empty in RecurrentGradOp"));
|
|
|
|
|
return StepScopes(dev_ctx, scope, var->GetMutable<StepScopeVar>(),
|
|
|
|
|
Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
|
|
|
|
|
}
|
|
|
|
|
@ -619,20 +651,24 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
ctx->Attrs()
|
|
|
|
|
.Get<std::vector<std::string>>(RecurrentBase::kExStates)
|
|
|
|
|
.size(),
|
|
|
|
|
0, "The Attr(%s) should be empty.", RecurrentBase::kExStates);
|
|
|
|
|
0, platform::errors::InvalidArgument("The Attr(%s) should be empty.",
|
|
|
|
|
RecurrentBase::kExStates));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->Attrs()
|
|
|
|
|
.Get<std::vector<std::string>>(RecurrentBase::kStates)
|
|
|
|
|
.size(),
|
|
|
|
|
0, "The Attr(%s) should be empty.", RecurrentBase::kStates);
|
|
|
|
|
0, platform::errors::InvalidArgument("The Attr(%s) should be empty.",
|
|
|
|
|
RecurrentBase::kStates));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kInputs), true,
|
|
|
|
|
"The input(%s) should not be empty.",
|
|
|
|
|
RecurrentBase::kInputs);
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInputs(RecurrentBase::kOutputs), true,
|
|
|
|
|
"The input(%s) should not be empty.",
|
|
|
|
|
RecurrentBase::kOutputs);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInputs(RecurrentBase::kInputs), true,
|
|
|
|
|
platform::errors::InvalidArgument("The input(%s) should not be empty.",
|
|
|
|
|
RecurrentBase::kInputs));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInputs(RecurrentBase::kOutputs), true,
|
|
|
|
|
platform::errors::InvalidArgument("The input(%s) should not be empty.",
|
|
|
|
|
RecurrentBase::kOutputs));
|
|
|
|
|
|
|
|
|
|
// In some case the kInitialStates is empty.
|
|
|
|
|
if (ctx->HasInputs(RecurrentBase::kInitialStates) &&
|
|
|
|
|
@ -644,8 +680,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutputs(framework::GradVarName(RecurrentBase::kInputs)), true,
|
|
|
|
|
"The output of(%s) should not be empty.",
|
|
|
|
|
framework::GradVarName(RecurrentBase::kInputs));
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The output of(%s) should not be empty.",
|
|
|
|
|
framework::GradVarName(RecurrentBase::kInputs)));
|
|
|
|
|
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kInputs),
|
|
|
|
|
ctx->GetInputsDim(RecurrentBase::kInputs));
|
|
|
|
|
|
|
|
|
|
@ -653,8 +690,9 @@ class RecurrentGradOpShapeInference : public framework::InferShapeBase {
|
|
|
|
|
if (ctx->HasInputs(RecurrentBase::kParameters)) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutputs(framework::GradVarName(RecurrentBase::kParameters)),
|
|
|
|
|
true, "The output of(%s) should not be empty.",
|
|
|
|
|
framework::GradVarName(RecurrentBase::kParameters));
|
|
|
|
|
true, platform::errors::InvalidArgument(
|
|
|
|
|
"The output of(%s) should not be empty.",
|
|
|
|
|
framework::GradVarName(RecurrentBase::kParameters)));
|
|
|
|
|
ctx->SetOutputsDim(framework::GradVarName(RecurrentBase::kParameters),
|
|
|
|
|
ctx->GetInputsDim(RecurrentBase::kParameters));
|
|
|
|
|
}
|
|
|
|
|
|