|
|
|
@ -28,6 +28,29 @@ using Variable = framework::Variable;
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
|
|
|
|
|
void RecurrentAlgorithm::InferShape(const Scope& scope) const {
|
|
|
|
|
auto* input0 = scope.FindVar(arg_->inlinks[0]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(input0);
|
|
|
|
|
seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0];
|
|
|
|
|
PADDLE_ENFORCE_GT(seq_len_, 0);
|
|
|
|
|
|
|
|
|
|
CreateScopes(scope);
|
|
|
|
|
auto& step_scopes = GetStepScopes(scope);
|
|
|
|
|
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
InitMemories(step_scopes[0], true /*infer_shape_mode*/);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < seq_len_; i++) {
|
|
|
|
|
if (i > 0) {
|
|
|
|
|
rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
(*stepnet_)->InferShape(*step_scopes[i]);
|
|
|
|
|
}
|
|
|
|
|
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RecurrentAlgorithm::Run(const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) const {
|
|
|
|
|
auto step_scopes = GetStepScopes(scope);
|
|
|
|
@ -179,6 +202,24 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
|
|
|
|
|
seq_len_ =
|
|
|
|
|
scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0];
|
|
|
|
|
auto step_scopes = GetStepScopes(scope);
|
|
|
|
|
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
|
|
|
|
|
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
|
|
|
|
|
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
(*stepnet_)->InferShape(*step_scopes[step_id]);
|
|
|
|
|
}
|
|
|
|
|
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RecurrentGradientOp::RecurrentGradientOp(
|
|
|
|
|
const std::string& type, const framework::VariableNameMap& inputs,
|
|
|
|
|
const framework::VariableNameMap& outputs,
|
|
|
|
|