|
|
|
@ -36,15 +36,13 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
|
|
|
|
|
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
InitMemories(step_scopes[0], true /*infer_shape_mode*/);
|
|
|
|
|
Variable* net = scope.FindVar(arg_->step_net);
|
|
|
|
|
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < seq_len_; i++) {
|
|
|
|
|
if (i > 0) {
|
|
|
|
|
rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
|
|
|
|
|
(*stepnet_)->InferShape(*step_scopes[i]);
|
|
|
|
|
}
|
|
|
|
|
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
@ -56,7 +54,6 @@ void RecurrentAlgorithm::Run(const Scope& scope,
|
|
|
|
|
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
|
|
|
|
|
false /*infer_shape_mode*/);
|
|
|
|
|
InitMemories(step_scopes[0], false /*infer_shape_mode*/);
|
|
|
|
|
Variable* net = scope.FindVar(arg_->step_net);
|
|
|
|
|
|
|
|
|
|
for (size_t step_id = 0; step_id < seq_len_; step_id++) {
|
|
|
|
|
// create output alias variables
|
|
|
|
@ -64,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
|
|
|
|
|
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
|
|
|
|
|
false /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
|
|
|
|
|
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
|
|
|
|
|
}
|
|
|
|
|
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
|
|
|
|
|
false /*infer_shape_mode*/);
|
|
|
|
@ -78,18 +75,16 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
|
|
|
|
|
auto step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
|
|
|
|
|
|
|
|
|
|
// Now all variables in scope must be created outside of op.
|
|
|
|
|
auto net_var = scope.FindVar(arg_->step_net);
|
|
|
|
|
PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope",
|
|
|
|
|
arg_->step_net);
|
|
|
|
|
auto net_op = net_var->GetMutable<NetOp>();
|
|
|
|
|
PADDLE_ENFORCE(!net_op->Outputs().empty(), "net_op has no outputs");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(stepnet_);
|
|
|
|
|
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
|
|
|
|
|
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");
|
|
|
|
|
|
|
|
|
|
if (seq_len_ > step_scopes->size()) {
|
|
|
|
|
for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
|
|
|
|
|
auto& step_scope = scope.NewScope();
|
|
|
|
|
|
|
|
|
|
// create step net's temp inputs
|
|
|
|
|
for (auto& input : net_op->Inputs()) {
|
|
|
|
|
for (auto& input : (*stepnet_)->Inputs()) {
|
|
|
|
|
// the weight are located in parent scope
|
|
|
|
|
for (auto& var_name : input.second) {
|
|
|
|
|
if (!step_scope.FindVar(var_name)) {
|
|
|
|
@ -98,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// create stepnet's outputs
|
|
|
|
|
for (const auto& output : net_op->Outputs()) {
|
|
|
|
|
for (const auto& output : (*stepnet_)->Outputs()) {
|
|
|
|
|
for (auto& var_name : output.second) {
|
|
|
|
|
step_scope.NewVar(var_name);
|
|
|
|
|
}
|
|
|
|
@ -140,9 +135,8 @@ RecurrentOp::RecurrentOp(const std::string& type,
|
|
|
|
|
const framework::OperatorBase::VarNameMap& outputs,
|
|
|
|
|
const framework::AttributeMap& attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs) {
|
|
|
|
|
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
|
|
|
|
|
rnn::InitArgument(kArgName, arg.get(), *this);
|
|
|
|
|
alg_.Init(std::move(arg));
|
|
|
|
|
rnn::InitArgument(kArgName, &arg_, *this);
|
|
|
|
|
alg_.Init(&arg_, &stepnet_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class RecurrentAlgorithmProtoAndCheckerMaker
|
|
|
|
@ -158,7 +152,6 @@ class RecurrentAlgorithmProtoAndCheckerMaker
|
|
|
|
|
.AsDuplicable();
|
|
|
|
|
AddInput(name.boot_memories, "variables to initialize memories.")
|
|
|
|
|
.AsDuplicable();
|
|
|
|
|
AddInput(name.step_net, "network shared by all steps.");
|
|
|
|
|
|
|
|
|
|
AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
|
|
|
|
|
.AsDuplicable();
|
|
|
|
@ -180,14 +173,12 @@ void RecurrentGradientAlgorithm::Run(
|
|
|
|
|
auto step_scopes = GetStepScopes(scope);
|
|
|
|
|
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
|
|
|
|
|
false /*infer_shape_mode*/);
|
|
|
|
|
Variable* net = scope.FindVar(arg_->step_net);
|
|
|
|
|
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
|
|
|
|
|
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
|
|
|
|
|
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
|
|
|
|
|
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
|
|
|
|
|
false /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
|
|
|
|
|
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
|
|
|
|
|
}
|
|
|
|
|
LinkBootMemoryGradients(step_scopes[0], false);
|
|
|
|
|
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
|
|
|
|
@ -219,14 +210,12 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
|
|
|
|
|
auto step_scopes = GetStepScopes(scope);
|
|
|
|
|
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
Variable* net = scope.FindVar(arg_->step_net);
|
|
|
|
|
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
|
|
|
|
|
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
|
|
|
|
|
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
|
|
|
|
|
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
|
}
|
|
|
|
|
net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
|
|
|
|
|
(*stepnet_)->InferShape(*step_scopes[step_id]);
|
|
|
|
|
}
|
|
|
|
|
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
|
|
|
|
|
true /*infer_shape_mode*/);
|
|
|
|
@ -238,9 +227,8 @@ RecurrentGradientOp::RecurrentGradientOp(
|
|
|
|
|
const framework::OperatorBase::VarNameMap& outputs,
|
|
|
|
|
const framework::AttributeMap& attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs) {
|
|
|
|
|
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
|
|
|
|
|
rnn::InitArgument(kArgName, arg.get(), *this);
|
|
|
|
|
alg_.Init(std::move(arg));
|
|
|
|
|
rnn::InitArgument(kArgName, &arg_, *this);
|
|
|
|
|
alg_.Init(&arg_, &stepnet_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|