@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/recurrent_network_op.h"
+#include "paddle/operators/recurrent_op.h"
 
 #include <glog/logging.h>
 #include <cstring>
@@ -108,8 +108,13 @@ void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
   std::shared_ptr<Scope> scope = scopes[step_id];
   std::shared_ptr<Scope> linked_scope = scopes[step_id + offset];
   for (auto& attr : memories) {
+    PADDLE_ENFORCE(scope->HasVariable(attr.pre_var),
+                   "the pre-memory [%s] is not in scope.",
+                   attr.pre_var);
+    PADDLE_ENFORCE(linked_scope->HasVariable(attr.var),
+                   "the memory [%s] is not in linked scope.",
+                   attr.var);
     auto mem = scope->GetVariable(attr.pre_var)->GetMutable<Tensor>();
-    // maybe share variable is better?
     auto linked_mem = linked_scope->GetVariable(attr.var)->GetMutable<Tensor>();
     if (infer_shape_mode) {
       mem->Resize(linked_mem->dims());
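The hunk above makes LinkMemories fail fast when a step scope is missing one of the variables it is about to wire together. Below is a minimal, self-contained sketch of the linking idea itself; ToyScope, ToyTensor, and the "h@pre" naming are illustrative stand-ins, not Paddle's real Scope/Tensor API. In shape-inference mode only the dims propagate, while at run time the step's pre-memory shares the neighboring step's data buffer.

```cpp
// Toy stand-ins for Paddle's Scope/Tensor; ToyScope, ToyTensor and the
// "h@pre" naming are illustrative assumptions, not the real API.
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// A tensor reduced to a shape plus a shared data buffer.
struct ToyTensor {
  std::vector<int> dims;
  std::shared_ptr<std::vector<float>> data;
};

// One scope per RNN step: a name -> tensor map.
using ToyScope = std::map<std::string, ToyTensor>;

// Wire step_id's pre-memory to the memory held by the scope at
// step_id + offset (offset == -1 links to the previous step).
void LinkToyMemories(std::vector<ToyScope>& scopes, size_t step_id, int offset,
                     const std::string& pre_var, const std::string& var,
                     bool infer_shape_mode) {
  // The same existence checks the hunk adds via PADDLE_ENFORCE.
  assert(static_cast<int>(step_id) + offset >= 0);
  assert(step_id + offset < scopes.size());
  assert(scopes[step_id].count(pre_var) && "pre-memory is not in scope");
  assert(scopes[step_id + offset].count(var) && "memory is not in linked scope");

  ToyTensor& mem = scopes[step_id][pre_var];
  ToyTensor& linked = scopes[step_id + offset][var];
  if (infer_shape_mode) {
    mem.dims = linked.dims;  // analogous to mem->Resize(linked_mem->dims())
  } else {
    mem.data = linked.data;  // analogous to sharing the underlying buffer
  }
}

int main() {
  std::vector<ToyScope> scopes(2);
  scopes[0]["h"] = {{4}, std::make_shared<std::vector<float>>(4, 1.0f)};
  scopes[1]["h@pre"] = {{}, nullptr};
  LinkToyMemories(scopes, /*step_id=*/1, /*offset=*/-1, "h@pre", "h",
                  /*infer_shape_mode=*/false);
  assert(scopes[1]["h@pre"].data == scopes[0]["h"].data);  // buffer shared
  return 0;
}
```

Sharing the buffer rather than copying it is what lets each step read its predecessor's state without an extra copy per step; only shape inference touches dims alone.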
@@ -295,12 +300,12 @@ public:
     const auto& name = RecurrentOp::kArgName;
     // inputs and outputs stored in proto
     AddInputs(name.inlinks,
-              "the input that need to be segmented for each step.");
+              "the inputs that need to be segmented for each step.");
     AddInputs(name.boot_memories, "variables to initialize memories.");
     AddInput(name.step_net, "network shared by all steps.");
 
     AddOutputs(name.outlinks,
-               "the output that need to concated for all steps.");
+               "the outputs that need to be concatenated for all steps.");
     AddOutput(name.step_scopes, "step scopes");
 
     // Attributes stored in AttributeMap
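This last hunk only tightens documentation strings registered through the op's proto maker (the comment in the hunk notes they are stored in the op's proto). A rough sketch of that builder pattern follows; MiniProtoMaker is a made-up stand-in for Paddle's OpProtoAndCheckerMaker, and its chaining signatures are assumptions, not the real API.

```cpp
// MiniProtoMaker is a made-up stand-in for OpProtoAndCheckerMaker, kept
// just small enough to show what the AddInput/AddOutput calls build up.
#include <iostream>
#include <string>
#include <vector>

class MiniProtoMaker {
 public:
  MiniProtoMaker& AddInput(const std::string& name, const std::string& doc) {
    inputs_.push_back(name + ": " + doc);
    return *this;  // returning *this lets the calls chain
  }
  MiniProtoMaker& AddOutput(const std::string& name, const std::string& doc) {
    outputs_.push_back(name + ": " + doc);
    return *this;
  }
  void Print() const {
    for (const auto& in : inputs_) std::cout << "input  " << in << "\n";
    for (const auto& out : outputs_) std::cout << "output " << out << "\n";
  }

 private:
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
};

int main() {
  MiniProtoMaker maker;
  maker.AddInput("inlinks", "the inputs that need to be segmented for each step.")
      .AddInput("boot_memories", "variables to initialize memories.")
      .AddInput("step_net", "network shared by all steps.")
      .AddOutput("outlinks", "the outputs that need to be concatenated for all steps.")
      .AddOutput("step_scopes", "step scopes");
  maker.Print();
  return 0;
}
```

In the real operator these registrations populate a protobuf describing the op's interface, which is why the PR's wording fixes matter: the strings end up as user-facing documentation.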