@@ -30,11 +30,14 @@ namespace rnn {
 void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                    const std::vector<Link>& inlinks,
                    const size_t seq_len,
-                   bool infer_shape) {
+                   bool infer_shape_mode) {
   PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
   for (size_t i = 0; i < inlinks.size(); ++i) {
-    Tensor* input =
-        step_scopes[0]->GetVariable(inlinks[i].external)->GetMutable<Tensor>();
+    auto input_var = step_scopes[0]->GetVariable(inlinks[i].external);
+    PADDLE_ENFORCE(input_var != nullptr,
+                   "input link [%s] is not in scope.",
+                   inlinks[i].external);
+    Tensor* input = input_var->GetMutable<Tensor>();
     DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
@@ -43,7 +46,7 @@ void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
       Tensor* step_input = step_scopes[j]
                                ->CreateVariable(inlinks[i].internal)
                                ->GetMutable<Tensor>();
-      if (!infer_shape) {
+      if (!infer_shape_mode) {
         *step_input = input->Slice<float>(j, j + 1);
       }
       step_input->Resize(step_dims);
@@ -54,12 +57,14 @@ void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
 void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                    const std::vector<Link>& outlinks,
                    const size_t seq_len,
-                   bool infer_shape) {
+                   bool infer_shape_mode) {
   for (size_t i = 0; i < outlinks.size(); i++) {
+    PADDLE_ENFORCE(step_scopes[0]->HasVariable(outlinks[i].external),
+                   "output link [%s] is not in scope.",
+                   outlinks[i].external);
     Tensor* output =
         step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
-
-    if (infer_shape) {
+    if (infer_shape_mode) {
       DDim step_dims = step_scopes[0]
                            ->GetVariable(outlinks[i].internal)
                            ->GetMutable<Tensor>()
@@ -69,8 +74,6 @@ void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
       output->Resize(make_ddim(dims_vec));
     } else {
       output->mutable_data<float>(platform::CPUPlace());
-    }
-
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_output = step_scopes[j]
                                 ->GetVariable(outlinks[i].internal)
@@ -82,12 +85,13 @@ void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
       }
+    }
   }
 }
 
 void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
                   const std::vector<rnn::MemoryAttr>& memories,
                   const size_t step_id,
                   const int offset,
-                  bool infer_shape) {
+                  bool infer_shape_mode) {
   PADDLE_ENFORCE(step_id < scopes.size(),
                  "step [%d] is out of range of step scopes' size [%d]",
                  step_id,
@@ -107,7 +111,7 @@ void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
     auto mem = scope->GetVariable(attr.pre_var)->GetMutable<Tensor>();
     // maybe share variable is better?
     auto linked_mem = linked_scope->GetVariable(attr.var)->GetMutable<Tensor>();
-    if (infer_shape) {
+    if (infer_shape_mode) {
       mem->Resize(linked_mem->dims());
     } else {
       mem->ShareDataWith<float>(*linked_mem);
@@ -179,43 +183,39 @@ void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const {
                  ->GetMutable<Tensor>()
                  ->dims()[0];
   CreateScopes(scope);
-
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, true);
-
-  InitMemories(step_scopes[0], true);
-
-  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
-                 "stepnet [%s] is not in scope.",
-                 arg_->step_net);
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
+  InitMemories(step_scopes[0], true /*infer_shape_mode*/);
   Variable* net = scope->GetVariable(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (size_t i = 0; i < seq_len_; i++) {
     if (i > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->InferShape(step_scopes[i]);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
 }
 
 void RecurrentAlgorithm::Run(const std::shared_ptr<Scope>& scope,
                              const platform::DeviceContext& dev_ctx) const {
   auto step_scopes = GetStepScopes(scope);
-
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, false);
-
-  InitMemories(step_scopes[0], false);
-
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
+  InitMemories(step_scopes[0], false /*infer_shape_mode*/);
   Variable* net = scope->GetVariable(arg_->step_net);
   for (size_t step_id = 0; step_id < seq_len_; step_id++) {
     if (step_id > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, false);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/);
    }
     net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
   }
-
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, false);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
 }
 
 void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
@@ -227,7 +227,6 @@ void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
   if (seq_len_ > step_scopes->size()) {
     for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
       std::shared_ptr<Scope> step_scope = std::make_shared<Scope>(scope);
-
       // Now all variables in scope must be created outside of op.
       auto net_op = scope->GetVariable(arg_->step_net)->GetMutable<NetOp>();
       for (auto& input : net_op->inputs_) {
@@ -237,14 +236,13 @@ void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const {
       for (auto& output : net_op->outputs_) {
         step_scope->CreateVariable(output);
       }
-
       step_scopes->push_back(std::make_shared<Scope>(step_scope));
     }
   }
 }
 
 void RecurrentAlgorithm::InitMemories(std::shared_ptr<Scope> step_scope,
-                                      bool infer_shape) const {
+                                      bool infer_shape_mode) const {
   for (auto& attr : arg_->memories) {
     Tensor* pre_mem =
         step_scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
@@ -254,7 +252,7 @@ void RecurrentAlgorithm::InitMemories(std::shared_ptr<Scope> step_scope,
                    attr.boot_var);
     Tensor* boot_mem =
         step_scope->GetVariable(attr.boot_var)->GetMutable<Tensor>();
-    if (infer_shape) {
+    if (infer_shape_mode) {
       pre_mem->Resize(boot_mem->dims());
     } else {
       pre_mem->ShareDataWith<float>(*boot_mem);
@@ -320,23 +318,23 @@ void RecurrentGradientAlgorithm::Run(
     const std::shared_ptr<Scope>& scope,
     const platform::DeviceContext& dev_ctx) const {
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, false);
-  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
-                 "step net is not in scope.");
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
   Variable* net = scope->GetVariable(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, false);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx);
   }
   LinkBootMemoryGradients(step_scopes[0], false);
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, false);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
 }
 
 void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
-    std::shared_ptr<Scope> step_scope, bool infer_shape) const {
+    std::shared_ptr<Scope> step_scope, bool infer_shape_mode) const {
   for (auto& attr : arg_->memories) {
     Tensor* mem_grad =
         step_scope->CreateVariable(attr.var)->GetMutable<Tensor>();
@@ -346,7 +344,7 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
                    attr.boot_var);
     Tensor* boot_mem_grad =
         step_scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>();
-    if (infer_shape) {
+    if (infer_shape_mode) {
       boot_mem_grad->Resize(mem_grad->dims());
     } else {
       boot_mem_grad->ShareDataWith<float>(*mem_grad);
@@ -360,21 +358,20 @@ void RecurrentGradientAlgorithm::InferShape(
                  ->GetMutable<Tensor>()
                  ->dims()[0];
   auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, true);
-
-  PADDLE_ENFORCE(scope->HasVariable(arg_->step_net),
-                 "step net is not in scope.");
+  rnn::SegmentInputs(
+      step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
   Variable* net = scope->GetVariable(arg_->step_net);
   PADDLE_ENFORCE(net != nullptr, "failed to get step net");
-
   for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
     if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, true);
+      rnn::LinkMemories(
+          step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/);
     }
     net->GetMutable<NetOp>()->InferShape(step_scopes[step_id]);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true);
-  LinkBootMemoryGradients(step_scopes[0], true);
+  rnn::ConcatOutputs(
+      step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
+  LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
 }
 
 void RecurrentGradientOp::Init() {
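
Note for reviewers: the pattern this diff settles on is a single two-pass traversal — every helper (`SegmentInputs`, `ConcatOutputs`, `LinkMemories`, `InitMemories`, `LinkBootMemoryGradients`) takes an `infer_shape_mode` flag so `InferShape` and `Run` share one code path, with data movement (`Slice`, `ShareDataWith`, `CopyFrom`) skipped during the shape pass. The sketch below is a minimal, self-contained illustration of that flag pattern only; the toy `Tensor`, `Slice`, and `Resize` here merely mimic the names in the diff and are not PaddlePaddle's implementations.

```cpp
// Illustrative sketch only: a stripped-down Tensor and a SegmentInputs-like
// helper showing how one routine serves both InferShape and Run, keyed by
// the infer_shape_mode flag introduced in this diff. Not PaddlePaddle code.
#include <cassert>
#include <iostream>
#include <vector>

struct Tensor {
  std::vector<int> dims;    // e.g. {seq_len, width}
  std::vector<float> data;  // empty while only shapes are known

  void Resize(const std::vector<int>& d) { dims = d; }

  // Return step j as a {1, ...} slice; valid only once data exists.
  Tensor Slice(int j) const {
    int step = 1;
    for (size_t k = 1; k < dims.size(); ++k) step *= dims[k];
    Tensor t;
    t.dims = dims;
    t.dims[0] = 1;
    t.data.assign(data.begin() + j * step, data.begin() + (j + 1) * step);
    return t;
  }
};

// One body for both passes, as in the diff: shape bookkeeping always runs,
// data movement is skipped when infer_shape_mode is true.
void SegmentInput(const Tensor& input, std::vector<Tensor>& step_inputs,
                  size_t seq_len, bool infer_shape_mode) {
  assert(static_cast<size_t>(input.dims[0]) == seq_len);
  std::vector<int> step_dims(input.dims.begin() + 1, input.dims.end());
  step_dims.insert(step_dims.begin(), 1);
  step_inputs.resize(seq_len);
  for (size_t j = 0; j < seq_len; ++j) {
    if (!infer_shape_mode) {
      step_inputs[j] = input.Slice(static_cast<int>(j));
    }
    step_inputs[j].Resize(step_dims);
  }
}

int main() {
  Tensor input;
  input.dims = {3, 2};  // seq_len = 3, width = 2
  std::vector<Tensor> steps;

  SegmentInput(input, steps, 3, true);  // InferShape pass: dims only
  std::cout << "step dims: " << steps[0].dims[0] << "x" << steps[0].dims[1]
            << ", data elems: " << steps[0].data.size() << "\n";

  input.data = {1, 2, 3, 4, 5, 6};       // data is available at Run time
  SegmentInput(input, steps, 3, false);  // Run pass: slice data too
  std::cout << "step 1 data: " << steps[1].data[0] << ", " << steps[1].data[1]
            << "\n";
  return 0;
}
```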