|
|
@ -92,8 +92,8 @@ void LinkMemories(const std::vector<Scope*>& scopes,
|
|
|
|
auto* scope = scopes[step_id];
|
|
|
|
auto* scope = scopes[step_id];
|
|
|
|
auto* linked_scope = scopes[step_id + offset];
|
|
|
|
auto* linked_scope = scopes[step_id + offset];
|
|
|
|
for (auto& attr : memories) {
|
|
|
|
for (auto& attr : memories) {
|
|
|
|
auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
|
|
|
|
auto* mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
|
|
|
|
auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
|
|
|
|
auto* linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
|
|
|
|
mem->Resize(linked_mem->dims());
|
|
|
|
mem->Resize(linked_mem->dims());
|
|
|
|
mem->ShareDataWith<float>(*linked_mem);
|
|
|
|
mem->ShareDataWith<float>(*linked_mem);
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -106,11 +106,11 @@ void InitArgument(const ArgumentName& name, Argument* arg,
|
|
|
|
arg->inlinks = op.Inputs(name.inlinks);
|
|
|
|
arg->inlinks = op.Inputs(name.inlinks);
|
|
|
|
arg->outlinks = op.Outputs(name.outlinks);
|
|
|
|
arg->outlinks = op.Outputs(name.outlinks);
|
|
|
|
|
|
|
|
|
|
|
|
auto boot_memories =
|
|
|
|
auto& boot_memories =
|
|
|
|
is_grad ? op.Outputs(name.boot_memories) : op.Inputs(name.boot_memories);
|
|
|
|
is_grad ? op.Outputs(name.boot_memories) : op.Inputs(name.boot_memories);
|
|
|
|
// attributes
|
|
|
|
// attributes
|
|
|
|
auto memories = op.Attr<std::vector<std::string>>(name.memories);
|
|
|
|
auto& memories = op.Attr<std::vector<std::string>>(name.memories);
|
|
|
|
auto pre_memories = op.Attr<std::vector<std::string>>(name.pre_memories);
|
|
|
|
auto& pre_memories = op.Attr<std::vector<std::string>>(name.pre_memories);
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(memories.size() == boot_memories.size(),
|
|
|
|
PADDLE_ENFORCE(memories.size() == boot_memories.size(),
|
|
|
|
"the size of memories, boot_memories don't match:%d,%d",
|
|
|
|
"the size of memories, boot_memories don't match:%d,%d",
|
|
|
|