@@ -30,36 +30,39 @@ using LoDTensor = framework::LoDTensor;
 
 void RecurrentAlgorithm::Run(const Scope& scope,
                              const platform::DeviceContext& dev_ctx) const {
-  auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
-                     false /*infer_shape_mode*/);
-  InitMemories(step_scopes[0], false /*infer_shape_mode*/);
+  auto* input0 = scope.FindVar(arg_->inlinks[0]);
+  PADDLE_ENFORCE_NOT_NULL(input0);
+  size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
+  PADDLE_ENFORCE_GT(seq_len, 0);
 
-  for (size_t step_id = 0; step_id < seq_len_; step_id++) {
-    // create output alias variables
+  CreateScopes(scope, seq_len);
+  auto& step_scopes = GetStepScopes(scope);
+  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
+  InitMemories(step_scopes[0]);
+
+  for (size_t step_id = 0; step_id < seq_len; step_id++) {
     if (step_id > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
-                        false /*infer_shape_mode*/);
+      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
     }
     (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
-                     false /*infer_shape_mode*/);
+  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
 }
 
-void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
+void RecurrentAlgorithm::CreateScopes(const Scope& scope,
+                                      size_t seq_len) const {
   // TODO(superjom) Only two scopes are needed for inference, this case will be
   // supported later.
-  auto step_scopes_var = scope.FindVar(arg_->step_scopes);
+  auto* step_scopes_var = scope.FindVar(arg_->step_scopes);
   PADDLE_ENFORCE(step_scopes_var != nullptr, "");
-  auto step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
+  auto* step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
 
   // Now all variables in scope must be created outside of op.
   PADDLE_ENFORCE_NOT_NULL(stepnet_);
   PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
 
-  if (seq_len_ > step_scopes->size()) {
-    for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
+  if (seq_len > step_scopes->size()) {
+    for (size_t i = step_scopes->size(); i < seq_len; ++i) {
       auto& step_scope = scope.NewScope();
 
       // create step net's temp inputs
@@ -82,8 +85,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
   }
 }
 
-void RecurrentAlgorithm::InitMemories(Scope* step_scope,
-                                      bool infer_shape_mode) const {
+void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
   for (auto& attr : arg_->memories) {
     auto* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<LoDTensor>();
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
@@ -91,12 +93,9 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
                    attr.boot_var);
     auto* boot_mem =
         step_scope->FindVar(attr.boot_var)->GetMutable<LoDTensor>();
-    if (infer_shape_mode) {
-      pre_mem->Resize(boot_mem->dims());
-      PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
-    } else {
-      pre_mem->ShareDataWith<float>(*boot_mem);
-    }
+    pre_mem->Resize(boot_mem->dims());
+    PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
+    pre_mem->ShareDataWith<float>(*boot_mem);
   }
 }
 
@@ -146,23 +145,23 @@ class RecurrentAlgorithmProtoAndCheckerMaker
 
 void RecurrentGradientAlgorithm::Run(
     const Scope& scope, const platform::DeviceContext& dev_ctx) const {
-  auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
-                     false /*infer_shape_mode*/);
-  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
-    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
-                        false /*infer_shape_mode*/);
+  auto* input0 = scope.FindVar(arg_->inlinks[0]);
+  PADDLE_ENFORCE_NOT_NULL(input0);
+  size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
+  auto& step_scopes = GetStepScopes(scope);
+  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
+  for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
+    if (step_id != seq_len - 1) {
+      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
     }
     (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
-  LinkBootMemoryGradients(step_scopes[0], false);
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
-                     false /*infer_shape_mode*/);
+  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
+  LinkBootMemoryGradients(step_scopes[0]);
 }
 
 void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
-    Scope* step_scope, bool infer_shape_mode) const {
+    Scope* step_scope) const {
   for (auto& attr : arg_->memories) {
     PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
                    "memory variable [%s] does not exists", attr.var);
@@ -171,11 +170,8 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
     auto* mem_grad = step_scope->NewVar(attr.var)->GetMutable<LoDTensor>();
     auto* boot_mem_grad =
         step_scope->NewVar(attr.boot_var)->GetMutable<LoDTensor>();
-    if (infer_shape_mode) {
-      boot_mem_grad->Resize(mem_grad->dims());
-    } else {
-      boot_mem_grad->ShareDataWith<float>(*mem_grad);
-    }
+    boot_mem_grad->Resize(mem_grad->dims());
+    boot_mem_grad->ShareDataWith<float>(*mem_grad);
   }
 }
 
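For readers skimming the diff, the following is a minimal, self-contained sketch of the pattern the rewritten Run() follows: derive the step count from the leading dimension of the first input at run time, then grow the per-step scope list lazily to that length. It is not PaddlePaddle code; FakeTensor and FakeScope are made-up stand-ins for LoDTensor and Scope, and the loop body only prints where the step net would run.

// sketch.cc -- illustrative only, not part of the patch above
#include <cassert>
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

struct FakeTensor {
  std::vector<size_t> dims;  // dims[0] plays the role of the sequence length
};

struct FakeScope {};  // stands in for framework::Scope

int main() {
  // Input whose leading dimension is the sequence length (e.g. 5 time steps).
  FakeTensor input0{{5, 32}};

  // Analogue of: size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
  size_t seq_len = input0.dims[0];
  assert(seq_len > 0);  // mirrors PADDLE_ENFORCE_GT(seq_len, 0)

  // Analogue of CreateScopes(scope, seq_len): append only the scopes that are
  // still missing, so repeated runs reuse previously created step scopes.
  std::vector<std::unique_ptr<FakeScope>> step_scopes;
  if (seq_len > step_scopes.size()) {
    for (size_t i = step_scopes.size(); i < seq_len; ++i) {
      step_scopes.emplace_back(new FakeScope());
    }
  }

  // One "step net" invocation per time step, as in the rewritten loop.
  for (size_t step_id = 0; step_id < seq_len; ++step_id) {
    std::cout << "run step net in scope " << step_id << "\n";
  }
  return 0;
}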