feature/dynamic recurrent op forward and backward (#4799)

Yan Chunwei 7 years ago committed by GitHub
parent 5380a5471b
commit 07ea9adec0

@@ -189,7 +189,7 @@ OpDesc {
inputs = {0} // the index of x in vars of BlockDesc above
outputs = {5, 3} // indices of act and hidden_out in vars of BlockDesc above
attrs {
"memories" : {1} // the index of h
"states" : {1} // the index of h
"step_net" : <above step net>
}
};
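The rename above ("memories" -> "states") is part of a consistent renaming that runs through every file in this commit. As a quick reference, the mapping can be written out as a plain Python dict; this is only a summary drawn from the hunks below, not an object that exists in the codebase, and the test fixtures ("h@mem" -> "h@state", DynamicRecurrentOpTestHelper -> RNNAlgorithmTestHelper) follow the same pattern.

# Editorial summary of the identifier renames in this commit (old -> new).
# Not part of the codebase; collected from the hunks below.
RENAMES = {
    "inlinks": "inputs",
    "outlinks": "outputs",
    "memories": "states",
    "pre_memories": "ex_states",
    "boot_memories": "initial_states",
    "rnn::MemoryAttr": "rnn::StateAttr",
    "DynamicRecurrentOp::kArgName": "RNNAlgorithm::kArgNames[0]",
    "SetStepNet / set_stepnet": "SetStepUnit / set_step_unit",
}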

@@ -21,6 +21,7 @@
#include "paddle/framework/block_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
@@ -220,8 +221,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet contains itself),
// or
// this will result in infinite loop.
// or this will result in an infinite loop.
const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp);
auto rnn_grad_op =
@@ -231,6 +231,18 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// create stepnet's gradient op
rnn_grad_op->set_stepnet(
BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
} else if (forwardOp.Type() == "dynamic_recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet contains itself),
// or this will result in an infinite loop.
const auto& rnnop =
*static_cast<const operators::DynamicRecurrentOp*>(&forwardOp);
auto rnn_grad_op =
static_cast<operators::DynamicRecurrentGradientOp*>(grad_op.get());
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.rnn.GetStepUnit());
// create stepnet's gradient op
rnn_grad_op->rnn.SetStepUnit(
BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
}
if (net->ops_.empty()) {  // Currently no aux op is added to the network
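The dynamic_recurrent branch added here mirrors the existing recurrent handling: BackwardRecursive recurses into the forward op's step unit and installs the result on the gradient op via rnn.SetStepUnit. Below is a minimal sketch of how this path is reached from Python, assuming the bindings touched in this commit (DynamicRecurrentOp, set_step_unit, core.DynamicRecurrentOp.backward) plus the usual import of paddle.v2.framework.core as core; it mirrors the RecurrentGradientOpTest added further down.

# Sketch only; argument names follow the unit test added in this commit.
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator, DynamicRecurrentOp

forward_op = DynamicRecurrentOp(
    inputs=["x"], initial_states=["h_boot"], step_net="step_unit",
    outputs=["h@state"], step_scopes="step_scopes",
    ex_states=["h@pre"], states=["h@state"])

# The step unit is the same mul/mul/sum/sigmoid net the tests use.
step_unit = core.Net.create()
for op in [Operator("mul", X="x", Y="W", Out="Wx"),
           Operator("mul", X="h@pre", Y="U", Out="Uh"),
           Operator("sum", X=["Wx", "Uh"], Out="sum"),
           Operator("sigmoid", X="sum", Y="h@state")]:
    step_unit.append_op(op)
step_unit.complete_add_op(True)
forward_op.set_step_unit(step_unit)

# Building the gradient op goes through the new "dynamic_recurrent" branch
# above, which creates the backward step unit from the forward one.
backward_op = core.DynamicRecurrentOp.backward(forward_op, set())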

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -43,16 +43,16 @@ LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims,
return tensor;
}
class DynamicRecurrentOpTestHelper : public ::testing::Test {
class RNNAlgorithmTestHelper : public ::testing::Test {
protected:
const rnn::ArgumentName argname = DynamicRecurrentOp::kArgName;
const rnn::ArgumentName argname = RNNAlgorithm::kArgNames[0];
virtual void SetUp() override {
CreateGlobalVariables();
auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
dop = dynamic_cast<DynamicRecurrentOp*>(op.get());
dop = &(dynamic_cast<DynamicRecurrentOp*>(op.get())->rnn);
InitCacheManually();
InitStepNet();
}
@@ -63,20 +63,20 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
op_desc.set_type("dynamic_recurrent");
OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs());
OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs());
OpDescNewVar(argname.initial_states, {"boot_mem"}, op_desc.add_inputs());
OpDescNewVar(argname.step_scopes, {"step_scopes"}, op_desc.add_outputs());
OpDescNewVar(argname.outlinks, {"out0"}, op_desc.add_outputs());
// set pre-memories
// set pre-states
auto pre_memories = op_desc.mutable_attrs()->Add();
pre_memories->set_name(argname.pre_memories);
pre_memories->set_name(argname.ex_states);
pre_memories->set_type(paddle::framework::AttrType::STRINGS);
auto pre_memories_item = pre_memories->add_strings();
*pre_memories_item = "mem@pre";
// set memories
// set states
auto memories = op_desc.mutable_attrs()->Add();
memories->set_name(argname.memories);
memories->set_name(argname.states);
memories->set_type(paddle::framework::AttrType::STRINGS);
auto memories_item = memories->add_strings();
*memories_item = "mem";
@@ -113,32 +113,33 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
}
void InitCacheManually() {
dop->cache_.Init(DynamicRecurrentOp::kArgName, *dop, scope, &dop->arg_);
dop->cache_.Init(RNNAlgorithm::kArgNames[0], *op, scope, &device_context,
&dop->arg_);
}
void InitStepNet() {
std::unique_ptr<framework::OperatorBase> stepnet{new NetOp};
dynamic_cast<NetOp*>(stepnet.get())
->AppendOp(std::unique_ptr<TestOp>(new TestOp(
"test", {{"inlinks", {"in0"}}, {"boot_memories", {"boot_mem"}}},
{{"outlinks", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
dop->SetStepNet(std::move(stepnet));
"test", {{"inputs", {"in0"}}, {"initial_states", {"boot_mem"}}},
{{"outputs", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
dop->SetStepUnit(std::move(stepnet));
}
protected:
DynamicRecurrentOp* dop;
RNNAlgorithm* dop;
std::unique_ptr<framework::OperatorBase> op;
paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope;
};
TEST_F(DynamicRecurrentOpTestHelper, CreateCache) {
TEST_F(RNNAlgorithmTestHelper, CreateCache) {
const rnn::Argument& arg = dop->arg_;
ASSERT_EQ(arg.inlinks.size(), 1UL);
ASSERT_EQ(arg.outlinks.size(), 1UL);
}
TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) {
TEST_F(RNNAlgorithmTestHelper, SplitInputs) {
dop->SplitInputs();
auto& in0_ta = dop->step_inputs_["in0"];
ASSERT_EQ(in0_ta.size(), 4UL);
@@ -153,14 +154,14 @@ TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) {
EXPECT_EQ(batch3.dims()[0], 1);
}
TEST_F(DynamicRecurrentOpTestHelper, CreateScopes) {
TEST_F(RNNAlgorithmTestHelper, CreateScopes) {
dop->SplitInputs();
dop->CreateScopes();
ASSERT_EQ(dop->cache_.num_steps, 4UL);
ASSERT_EQ(dop->cache_.scopes->size(), 4UL);
}
TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) {
TEST_F(RNNAlgorithmTestHelper, WriteStepInputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
@@ -173,7 +174,7 @@ TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) {
}
}
TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) {
TEST_F(RNNAlgorithmTestHelper, WriteStepOutputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
@@ -187,11 +188,12 @@ TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) {
}
}
TEST_F(DynamicRecurrentOpTestHelper, ConcatOutputs) {
TEST_F(RNNAlgorithmTestHelper, ConcatOutputs) {
// Let's leave this test to python unittest.
}
TEST_F(DynamicRecurrentOpTestHelper, InitStates) {
TEST_F(RNNAlgorithmTestHelper, InitStates) {
dop->SetComputeMode(RNNAlgorithm::ComputeMode::kForward);
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
@@ -208,12 +210,6 @@ TEST_F(DynamicRecurrentOpTestHelper, InitStates) {
auto* boot_state = scope.FindVar("boot_mem");
ASSERT_TRUE(boot_state != nullptr);
if (step == 0) {
// check pre_state is a reference of boot_state
ASSERT_EQ(boot_state->Get<LoDTensor>().data<float>(),
pre_state->Get<LoDTensor>().data<float>());
}
}
}

@@ -42,7 +42,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
for (size_t step_id = 0; step_id < seq_len; step_id++) {
if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
rnn::LinkMemories(step_scopes, arg_->states, step_id, -1);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
@@ -59,7 +59,8 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope,
// Now all variables in scope must be created outside of op.
PADDLE_ENFORCE_NOT_NULL(stepnet_);
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(),
"step_unit_ op has no outputs");
if (seq_len > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len; ++i) {
@@ -86,7 +87,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope,
}
void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
for (auto& attr : arg_->memories) {
for (auto& attr : arg_->states) {
auto* pre_mem = step_scope->Var(attr.pre_var)->GetMutable<LoDTensor>();
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"memory [%s]'s boot variable [%s] does not exist", attr.var,
@@ -100,12 +101,12 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
}
const rnn::ArgumentName RecurrentOp::kArgName{
"step_net", "step_scopes", "inlinks", "outlinks",
"memories", "pre_memories", "boot_memories"};
"step_net", "step_scopes", "inputs", "outputs",
"states", "ex_states", "initial_states"};
const rnn::ArgumentName RecurrentGradientOp::kArgName{
"step_net", "step_scopes@GRAD", "outlinks@GRAD", "inlinks@GRAD",
"memories", "pre_memories", "boot_memories@GRAD"};
"step_net", "step_scopes@GRAD", "outputs@GRAD", "inputs@GRAD",
"states", "ex_states", "initial_states@GRAD"};
RecurrentOp::RecurrentOp(const std::string& type,
const framework::VariableNameMap& inputs,
@@ -127,7 +128,7 @@ class RecurrentAlgorithmProtoAndCheckerMaker
AddInput(name.inlinks,
"the inputs that need to be segmented for each step.")
.AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.")
AddInput(name.initial_states, "variables to initialize states.")
.AsDuplicable();
AddOutput(name.outlinks, "the outputs that need to be concatenated for all steps.")
@@ -135,9 +136,8 @@ class RecurrentAlgorithmProtoAndCheckerMaker
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
AddAttr<std::vector<std::string>>(name.pre_memories,
"names of pre-memories");
AddAttr<std::vector<std::string>>(name.memories, "names of memories");
AddAttr<std::vector<std::string>>(name.ex_states, "names of pre-states");
AddAttr<std::vector<std::string>>(name.states, "names of states");
AddComment("This is a recurrent group operator.");
}
@@ -152,7 +152,7 @@ void RecurrentGradientAlgorithm::Run(
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
rnn::LinkMemories(step_scopes, arg_->states, step_id, 1);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
@@ -162,7 +162,7 @@ void RecurrentGradientAlgorithm::Run(
void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
Scope* step_scope) const {
for (auto& attr : arg_->memories) {
for (auto& attr : arg_->states) {
PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
"memory variable [%s] does not exist", attr.var);
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,

@@ -36,7 +36,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
LoDTensor* input = input_var->GetMutable<LoDTensor>();
f::DDim dims = input->dims();
PADDLE_ENFORCE_EQ(static_cast<size_t>(dims[0]), seq_len,
"all the inlinks be the same length");
"all the inputs must be the same length");
f::DDim step_dims = slice_ddim(dims, 1, dims.size());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input =
@@ -78,7 +78,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
}
void LinkMemories(const std::vector<Scope*>& scopes,
const std::vector<rnn::MemoryAttr>& memories,
const std::vector<rnn::StateAttr>& memories,
const size_t step_id, const int offset) {
PADDLE_ENFORCE_LT(step_id, scopes.size(),
"step [%d] is out of range of step scopes' size [%d]",
@@ -106,26 +106,26 @@ void InitArgument(const ArgumentName& name, Argument* arg,
arg->inlinks = op.Inputs(name.inlinks);
arg->outlinks = op.Outputs(name.outlinks);
auto& boot_memories =
is_grad ? op.Outputs(name.boot_memories) : op.Inputs(name.boot_memories);
auto& boot_memories = is_grad ? op.Outputs(name.initial_states)
: op.Inputs(name.initial_states);
// attributes
auto& memories = op.Attr<std::vector<std::string>>(name.memories);
auto& pre_memories = op.Attr<std::vector<std::string>>(name.pre_memories);
auto& memories = op.Attr<std::vector<std::string>>(name.states);
auto& pre_memories = op.Attr<std::vector<std::string>>(name.ex_states);
PADDLE_ENFORCE(memories.size() == boot_memories.size(),
"the size of memories, boot_memories don't match:%d,%d",
"the sizes of states and initial_states don't match: %d, %d",
memories.size(), boot_memories.size());
PADDLE_ENFORCE(pre_memories.size() == boot_memories.size(),
"the size of pre_memories, boot_memories don't match:%d,%d",
"the sizes of ex_states and initial_states don't match: %d, %d",
pre_memories.size(), boot_memories.size());
PADDLE_ENFORCE(memories.size() > 0, "more than 1 memories should be set");
PADDLE_ENFORCE(memories.size() > 0, "at least one state should be set");
for (size_t i = 0; i < memories.size(); ++i) {
rnn::MemoryAttr mem_attr;
rnn::StateAttr mem_attr;
mem_attr.var = memories[i];
mem_attr.pre_var = pre_memories[i];
mem_attr.boot_var = boot_memories[i];
(arg->memories).push_back(mem_attr);
(arg->states).push_back(mem_attr);
}
}
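Concretely, with the attribute values the tests below use (states of "h@state", ex_states of "h@pre", initial_states of "h_boot"), the loop above fills in a single rnn::StateAttr. Written out as a Python dict purely for illustration:

# One rnn::StateAttr as InitArgument would fill it in (illustrative only).
state_attr = {
    "var": "h@state",      # from the "states" attribute
    "pre_var": "h@pre",    # from the "ex_states" attribute
    "boot_var": "h_boot",  # from the "initial_states" input
}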

@@ -31,7 +31,7 @@ using Scope = framework::Scope;
* boot memories in the parent scope. Other attributes are copied from Op's proto
* attributes.
*/
struct MemoryAttr {
struct StateAttr {
// name of current state variable
std::string var;
// name of previous step's state variable
@@ -46,7 +46,7 @@ struct Argument {
std::string step_scopes;
std::vector<std::string> inlinks;
std::vector<std::string> outlinks;
std::vector<rnn::MemoryAttr> memories;
std::vector<rnn::StateAttr> states;
};
struct ArgumentName {
@@ -54,9 +54,9 @@ struct ArgumentName {
std::string step_scopes;
std::string inlinks;
std::string outlinks;
std::string memories; // the memory name
std::string pre_memories; // the previous memory name
std::string boot_memories; // the boot memory name
std::string states;          // the state name
std::string ex_states;       // the previous-step state name
std::string initial_states;  // the initial (boot) state name
};
/**
@@ -74,7 +74,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const size_t seq_len, const platform::DeviceContext& ctx);
void LinkMemories(const std::vector<Scope*>& step_scopes,
const std::vector<MemoryAttr>& memories, const size_t step_id,
const std::vector<StateAttr>& memories, const size_t step_id,
const int offset);
void InitArgument(const ArgumentName& name, Argument* arg,

@@ -413,18 +413,18 @@ All parameter, weight, gradient are variables in Paddle.
return static_cast<operators::DynamicRecurrentOp *>(
rnn_op.release());
})
.def("set_stepnet",
.def("set_step_unit",
[](operators::DynamicRecurrentOp &self, const operators::NetOp &net)
-> void { self.SetStepNet(net.Clone()); })
-> void { self.rnn.SetStepUnit(net.Clone()); })
.def("get_state",
[](operators::DynamicRecurrentOp &self, const std::string &name)
-> const TensorArray & { return self.state(name); })
-> const TensorArray & { return self.rnn.state(name); })
.def("get_step_input",
[](operators::DynamicRecurrentOp &self, const std::string &name)
-> const TensorArray & { return self.step_input(name); })
-> const TensorArray & { return self.rnn.step_input(name); })
.def("get_step_output",
[](operators::DynamicRecurrentOp &self, const std::string &name)
-> const TensorArray & { return self.step_output(name); });
-> const TensorArray & { return self.rnn.step_output(name); });
// cond_op
py::class_<operators::CondOp, OperatorBase>(m, "CondOp")

@@ -4,6 +4,12 @@ import unittest
from paddle.v2.framework.op import Operator, DynamicRecurrentOp
import numpy as np
# for simplicity, just one-level LoD
lod_py = [[0, 4, 7, 9, 10]]
input_dim = 30
num_sents = len(lod_py[0]) - 1
weight_dim = 15
def create_tensor(scope, name, shape, np_data):
tensor = scope.var(name).get_tensor()
@@ -12,6 +18,17 @@ def create_tensor(scope, name, shape, np_data):
return tensor
class PyRNNStep(object):
def __init__(self):
self.x = np.random.normal(size=(lod_py[0][-1],
input_dim)).astype("float32")
self.W = np.random.normal(size=(input_dim, input_dim)).astype("float32")
self.U = np.random.normal(size=(input_dim, input_dim)).astype("float32")
self.h_boot = np.random.normal(size=(num_sents,
input_dim)).astype("float32")
class DynamicRecurrentOpTest(unittest.TestCase):
'''
Test RNNOp
@@ -23,17 +40,13 @@ class DynamicRecurrentOpTest(unittest.TestCase):
- U
vars:
- x
memories:
states:
- h
outputs:
- h
'''
# for siplicity, just one level LoD
lod_py = [[0, 4, 7, 9, 10]]
input_dim = 30
num_sents = len(lod_py[0]) - 1
weight_dim = 15
py = PyRNNStep()
def forward(self):
self.scope = core.Scope()
@@ -42,64 +55,55 @@ class DynamicRecurrentOpTest(unittest.TestCase):
self.create_step_net()
ctx = core.DeviceContext.create(core.CPUPlace())
self.rnnop.run(self.scope, ctx)
state = self.rnnop.get_state("h@mem")
state = self.rnnop.get_state("h@state")
print 'state size: ', state.size()
step_inputs = self.rnnop.get_step_input("x")
print "x size ", step_inputs.size()
for i in range(step_inputs.size()):
print "x %d" % i, np.array(step_inputs.read(i).get_dims())
step_outputs = self.rnnop.get_step_output('h@mem')
step_outputs = self.rnnop.get_step_output('h@state')
print 'step_outputs.size ', step_outputs.size()
output = self.scope.find_var("h@mem").get_tensor()
output = self.scope.find_var("h@state").get_tensor()
print 'output', np.array(output).shape
def create_global_variables(self):
x = np.random.normal(size=(self.lod_py[0][-1],
self.input_dim)).astype("float32")
W = np.random.normal(size=(self.input_dim,
self.input_dim)).astype("float32")
U = np.random.normal(size=(self.input_dim,
self.input_dim)).astype("float32")
h_boot = np.random.normal(size=(self.num_sents,
self.input_dim)).astype("float32")
# create inlink
x_tensor = create_tensor(self.scope, "x",
[self.num_sents, self.input_dim], x)
x_tensor.set_lod(self.lod_py)
create_tensor(self.scope, "W", [self.input_dim, self.input_dim], W)
create_tensor(self.scope, "U", [self.input_dim, self.input_dim], U)
create_tensor(self.scope, "h_boot", [self.num_sents, self.input_dim],
h_boot)
x_tensor = create_tensor(self.scope, "x", [num_sents, input_dim],
self.py.x)
x_tensor.set_lod(lod_py)
create_tensor(self.scope, "W", [input_dim, input_dim], self.py.W)
create_tensor(self.scope, "U", [input_dim, input_dim], self.py.U)
create_tensor(self.scope, "h_boot", [num_sents, input_dim],
self.py.h_boot)
self.scope.var("step_scopes")
self.scope.var("h@mem")
self.scope.var("h@state")
def create_rnn_op(self):
# create RNNOp
self.rnnop = DynamicRecurrentOp(
# inputs
inlinks=["x"],
boot_memories=["h_boot"],
step_net="stepnet",
inputs=["x"],
initial_states=["h_boot"],
step_net="step_unit",
# outputs
outlinks=["h@mem"],
outputs=["h@state"],
step_scopes="step_scopes",
# attributes
pre_memories=["h@pre"],
memories=["h@mem"])
ex_states=["h@pre"],
states=["h@state"])
def create_step_net(self):
stepnet = core.Net.create()
step_unit = core.Net.create()
x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
sig_op = Operator("sigmoid", X="sum", Y="h@mem")
sig_op = Operator("sigmoid", X="sum", Y="h@state")
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.append_op(op)
stepnet.complete_add_op(True)
self.rnnop.set_stepnet(stepnet)
step_unit.append_op(op)
step_unit.complete_add_op(True)
self.rnnop.set_step_unit(step_unit)
def test_forward(self):
print 'test recurrent op forward'
@@ -107,5 +111,58 @@ class DynamicRecurrentOpTest(unittest.TestCase):
print 'pd_output', pd_output
class RecurrentGradientOpTest(unittest.TestCase):
py = PyRNNStep()
def create_forward_op(self):
# create RNNOp
self.forward_op = DynamicRecurrentOp(
# inputs
inputs=["x"],
initial_states=["h_boot"],
step_net="step_unit",
# outputs
outputs=["h@state"],
step_scopes="step_scopes",
# attributes
ex_states=["h@pre"],
states=["h@state"])
def create_gradient_op(self):
a = set()
backward_op = core.DynamicRecurrentOp.backward(self.forward_op, a)
def create_step_net(self):
step_unit = core.Net.create()
x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
sig_op = Operator("sigmoid", X="sum", Y="h@state")
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
step_unit.append_op(op)
step_unit.complete_add_op(True)
self.forward_op.set_step_unit(step_unit)
def create_global_variables(self):
# create inlink
x_tensor = create_tensor(self.scope, "x", [num_sents, input_dim],
self.py.x)
x_tensor.set_lod(lod_py)
create_tensor(self.scope, "W", [input_dim, input_dim], self.py.W)
create_tensor(self.scope, "U", [input_dim, input_dim], self.py.U)
create_tensor(self.scope, "h_boot", [num_sents, input_dim],
self.py.h_boot)
self.scope.var("step_scopes")
self.scope.var("h@state")
def test_grad(self):
self.scope = core.Scope()
self.create_forward_op()
self.create_global_variables()
self.create_step_net()
self.create_gradient_op()
if __name__ == '__main__':
unittest.main()
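The forward test above only prints the contents of h@state. If an actual assertion is wanted, the step unit (two matmuls, a sum, and a sigmoid, applied per sequence) is easy to reproduce in NumPy. The sketch below is a hedged reference, not part of the commit: it assumes the offsets in lod_py delimit the sequences, that h_boot[i] seeds sequence i, and that the concatenated output keeps the original sequence order.

import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def reference_dynamic_rnn(x, W, U, h_boot, lod):
    # h_t = sigmoid(x_t . W + h_{t-1} . U), computed independently per sequence.
    out = np.zeros((x.shape[0], W.shape[1]), dtype=x.dtype)
    offsets = lod[0]
    for i in range(len(offsets) - 1):
        h_prev = h_boot[i]
        for t in range(offsets[i], offsets[i + 1]):
            h_prev = _sigmoid(np.dot(x[t], W) + np.dot(h_prev, U))
            out[t] = h_prev
    return out

# For example, inside test_forward after self.forward():
#   expect = reference_dynamic_rnn(self.py.x, self.py.W, self.py.U,
#                                  self.py.h_boot, lod_py)
#   actual = np.array(self.scope.find_var("h@state").get_tensor())
#   np.testing.assert_allclose(actual, expect, rtol=1e-5)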

@@ -132,15 +132,15 @@ class RecurrentOpTest(unittest.TestCase):
# create RNNOp
self.rnnop = RecurrentOp(
# inputs
inlinks=["x"],
boot_memories=["h_boot"],
inputs=["x"],
initial_states=["h_boot"],
step_net="stepnet",
# outputs
outlinks=["h@mem"],
outputs=["h@mem"],
step_scopes="step_scopes",
# attributes
pre_memories=["h@pre"],
memories=["h@mem"])
ex_states=["h@pre"],
states=["h@mem"])
def create_step_net(self):
stepnet = core.Net.create()
@@ -169,15 +169,15 @@ class RecurrentGradientOpTest(unittest.TestCase):
def create_forward_op(self):
self.forward_op = RecurrentOp(
# inputs
inlinks=["x"],
boot_memories=["h_boot"],
inputs=["x"],
initial_states=["h_boot"],
step_net="stepnet",
# outputs
outlinks=["h"],
outputs=["h"],
step_scopes="step_scopes",
# attributes
pre_memories=["h@pre"],
memories=["h@alias"])
ex_states=["h@pre"],
states=["h@alias"])
# create a stepnet for RNN
stepnet = core.Net.create()
