@@ -23,6 +23,7 @@ constexpr char kInitialStates[] = "initial_states";
 constexpr char kParameters[] = "parameters";
 constexpr char kOutputs[] = "outputs";
 constexpr char kStepScopes[] = "step_scopes";
+constexpr char kHasStates[] = "has_states";
 constexpr char kExStates[] = "ex_states";
 constexpr char kStates[] = "states";
 constexpr char kStepBlock[] = "sub_block";
@@ -241,11 +242,16 @@ class RecurrentOp : public RecurrentBase {
  private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
+    bool has_state = Attr<bool>(kHasStates);
     auto seq_len = static_cast<size_t>(this->GetSequenceLength(scope));
     VLOG(3) << "Static RNN input sequence length = " << seq_len;
     StepScopes scopes = CreateStepScopes(scope, seq_len);
     auto reverse = Attr<bool>(kReverse);

+    // get device context from pool
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(place);
+
     framework::Executor executor(place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);

@@ -269,6 +275,7 @@ class RecurrentOp : public RecurrentBase {
             inside->Resize(framework::make_ddim(dims));
           });

+      if (has_state) {
         if (i == 0) {
           // Link initial states --> ex_states
           LinkTensor(scope, Inputs(kInitialStates), &cur_scope,
@@ -279,6 +286,7 @@ class RecurrentOp : public RecurrentBase {
           LinkTensor(ex_scope, Attr<std::vector<std::string>>(kStates),
                      &cur_scope, Attr<std::vector<std::string>>(kExStates));
         }
+      }

       // Every inputs are linked now, execute!
       executor.Run(*program, &cur_scope, block->ID(),
@@ -286,11 +294,6 @@ class RecurrentOp : public RecurrentBase {
                    std::vector<std::string>() /*skip_ref_cnt_vars*/,
                    true /*force_disable_gc*/);

-      // get device context from pool
-      platform::DeviceContextPool &pool =
-          platform::DeviceContextPool::Instance();
-      auto &dev_ctx = *pool.Get(place);
-
       // Copy inside::output -> outside::output
       //    outside::output[seq_offset: seq_offset + 1] = inside::output
       this->LinkTensorWithCallback(
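
The "outside::output[seq_offset: seq_offset + 1] = inside::output" comment above describes a one-timestep slice copy. As a minimal sketch (not part of the diff), assuming `outside`, `inside`, `place`, `dev_ctx`, and `seq_offset` are in scope as in the surrounding code, the link boils down to:

    // Slice returns a view of one time step, so TensorCopy writes the step
    // output straight into the full sequence tensor.
    auto dst = outside->Slice(seq_offset, seq_offset + 1);
    framework::TensorCopy(inside, place, dev_ctx, &dst);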
@@ -333,13 +336,13 @@ class RecurrentGradOp : public RecurrentBase {
  private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
-    auto seq_len = static_cast<size_t>(GetSequenceLength(scope));
+    bool has_state = Attr<bool>(kHasStates);
+    const size_t seq_len = static_cast<size_t>(GetSequenceLength(scope));
     StepScopes scopes = CreateStepScopes(scope, seq_len);
     auto reverse = Attr<bool>(kReverse);

     framework::Executor executor(place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);

     auto *program = block->Program();

     // get device context from pool
@@ -350,6 +353,7 @@ class RecurrentGradOp : public RecurrentBase {
       size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
       VLOG(3) << "Recurrent backward operate at the time step " << seq_offset;
       auto &cur_scope = scopes.CurScope();
+
       // Link outside::output_grads --> inside::output_grads
       //   inside::output_grad = outside::output_grad[seq_offset:seq_offset+1]
       LinkTensorWithCallback(
@@ -370,6 +374,7 @@ class RecurrentGradOp : public RecurrentBase {
         VLOG(10) << " RNN output gradients = [" << sout.str() << "]";
       }

+      if (has_state) {
         // Link states
         //   if cur_scope::cur_state_grad in out_grads:
         //     cur_scope::cur_state_grad += ex_scope::ex_state_grad
@@ -396,6 +401,7 @@ class RecurrentGradOp : public RecurrentBase {
             framework::TensorCopy(ex_tensor, place, dev_ctx, cur_grad_tensor);
           }
         }
+      }

       VLOG(5) << "Recurrent memory linking finished ";
       // Run step block with cur_scope
@@ -442,8 +448,8 @@ class RecurrentGradOp : public RecurrentBase {
           }

           auto new_inside_name = cur_scope.Rename(inside_grad_name);
-          // sum gradient

+          // sum gradient
           auto sum_op = framework::OpRegistry::CreateOp(
               "sum", {{"X", {pg_names[param_id], new_inside_name}}},
               {{"Out", {pg_names[param_id]}}},
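
The rename-then-sum pattern above accumulates the parameter gradient across time steps: the step's fresh gradient is renamed to a temporary, and a `sum` op adds it back into the running gradient variable. A minimal sketch of the follow-up, which lies outside this hunk and is therefore an assumption about the surrounding code:

    // Out = X[0] + X[1], i.e. pg_names[param_id] += the step gradient,
    // then the temporary name is swapped back for the next iteration.
    sum_op->Run(cur_scope, place);
    cur_scope.Rename(new_inside_name, inside_grad_name);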
@@ -475,11 +481,13 @@ class RecurrentGradOp : public RecurrentBase {
           true /*is_backward*/);
       VLOG(5) << "Link outside gradient finished ";

+      if (has_state) {
         if (step_id + 1 == seq_len) {  // at_end
           // copy initialize states gradient from inside to outside
           LinkTensorWithCallback(
-              cur_scope, GradVarLists(Attr<std::vector<std::string>>(kExStates)),
-              scope, Outputs(kInitStateGrads),
+              cur_scope,
+              GradVarLists(Attr<std::vector<std::string>>(kExStates)), scope,
+              Outputs(kInitStateGrads),
               [&](const framework::LoDTensor &inside,
                   framework::LoDTensor *outside) {
                 outside->Resize(inside.dims());
@@ -489,8 +497,17 @@ class RecurrentGradOp : public RecurrentBase {
               true /*is_backward*/);
           VLOG(5) << "Link initialize state gradient finished ";
         }
+      }
       scopes.Next();
     }
+    // Delete the scope of StepScopes
+    dev_ctx.Wait();
+    auto *var = scope.FindVar(Input(kStepScopes));
+    PADDLE_ENFORCE(var != nullptr);
+    auto step_scopes = var->GetMutable<StepScopeVar>();
+    for (auto *sub_scope : *step_scopes) {
+      const_cast<framework::Scope &>(scope).DeleteScope(sub_scope);
+    }
   }

  private:
@@ -541,6 +558,7 @@ class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
         .AsDuplicable();
     AddOutput(kStepScopes,
               "StepScopes contain all local variables in each time step.");
+    AddAttr<bool>(kHasStates, "Whether has states.").SetDefault(false);
     AddAttr<std::vector<std::string>>(kExStates,
                                       string::Sprintf(
                                           R"DOC(The ex-state variable names.
@@ -624,20 +642,44 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
 class RecurrentGradOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *ctx) const override {
-    std::vector<std::string> input{kInputs, kInitialStates};
     std::vector<std::string> output{kOutputs};
-    for (auto &s : input) {
-      // NOTE(zcd): In some case, some of kInputs doesn't have gradient.
-      PADDLE_ENFORCE(ctx->HasInputs(s));
-    }
-    for (auto &s : output) {
-      PADDLE_ENFORCE(ctx->HasInputs(s));
-    }
-    for (auto &s : input) {
-      ctx->SetOutputsDim(framework::GradVarName(s), ctx->GetInputsDim(s));
-    }
+
+    // In some case the kInitialStates is empty.
+    // If the kInitialStates is empty, all the states should be empty.
+    if (!ctx->HasInputs(kInitialStates)) {
+      PADDLE_ENFORCE_EQ(
+          ctx->Attrs().Get<std::vector<std::string>>(kExStates).size(), 0,
+          "The Attr(%s) should be empty.", kExStates);
+      PADDLE_ENFORCE_EQ(
+          ctx->Attrs().Get<std::vector<std::string>>(kStates).size(), 0,
+          "The Attr(%s) should be empty.", kStates);
+    }
+
+    PADDLE_ENFORCE(ctx->HasInputs(kInputs),
+                   "The input(%s) should not be empty.", kInputs);
+    PADDLE_ENFORCE(ctx->HasInputs(kOutputs),
+                   "The input(%s) should not be empty.", kOutputs);
+
+    // In some case the kInitialStates is empty.
+    if (ctx->HasInputs(kInitialStates)) {
+      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kInitialStates)),
+                     "The output of(%s) should not be empty.",
+                     framework::GradVarName(kInitialStates));
+      ctx->SetOutputsDim(framework::GradVarName(kInitialStates),
+                         ctx->GetInputsDim(kInitialStates));
+    }
+
+    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kInputs)),
+                   "The output of(%s) should not be empty.",
+                   framework::GradVarName(kInputs));
+    ctx->SetOutputsDim(framework::GradVarName(kInputs),
+                       ctx->GetInputsDim(kInputs));
+
+    // In some case the kParameters is empty.
     if (ctx->HasInputs(kParameters)) {
-      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)));
+      PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)),
+                     "The output of(%s) should not be empty.",
+                     framework::GradVarName(kParameters));
       ctx->SetOutputsDim(framework::GradVarName(kParameters),
                          ctx->GetInputsDim(kParameters));
     }
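
For orientation: the classes touched by this diff are wired into the operator registry at the bottom of recurrent_op.cc. A sketch of that (unchanged) registration, included here only as assumed context:

    REGISTER_OPERATOR(recurrent, paddle::operators::RecurrentOp,
                      paddle::operators::RecurrentOpProtoMaker,
                      paddle::operators::RecurrentGradOpDescMaker);
    REGISTER_OPERATOR(recurrent_grad, paddle::operators::RecurrentGradOp,
                      paddle::operators::RecurrentGradOpShapeInference);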