Enhence shrink_rnn_memory_op.

update-read-source
yangyaming 8 years ago
parent 25af35d8bb
commit 66ae0a8cb2

@ -46,8 +46,19 @@ class ShrinkRNNMemoryOp : public ArrayOp {
auto *out_var = scope.FindVar(Output("Out")); auto *out_var = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set"); PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>(); auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
// should consider multiple levels
size_t height = dst_num_rows;
auto lod_level = lod_rank_table.level();
if (x_tensor.lod().size() > lod_level &&
x_tensor.lod()[lod_level].size() < dst_num_rows) {
auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(
x_tensor.lod(), 0, dst_num_rows + 1, lod_level);
height = lod_offset.second.second;
}
if (dst_num_rows != 0) { if (dst_num_rows != 0) {
out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows)); out_tensor.ShareDataWith(x_tensor.Slice(0, height));
} }
} }
}; };
@ -64,11 +75,11 @@ class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "(LoDTensor) The shrinked RNN step memory."); AddOutput("Out", "(LoDTensor) The shrinked RNN step memory.");
AddComment( AddComment(
R"DOC( R"DOC(
In dynamic RNN, we are able to handle sequences of different lengths. In dynamic RNN, we are able to handle sequences of different lengths.
Because of the multiple lengths, the size of each step input can be Because of the multiple lengths, the size of each step input can be
different, which may lead to a mismatching between the input of different, which may lead to a mismatching between the input of
the current step and the memory generated by the previous one. This the current step and the memory generated by the previous one. This
operator shrinks memory according to the size of the next step input, operator shrinks memory according to the size of the next step input,
to make sure that they can match each other. to make sure that they can match each other.
)DOC"); )DOC");
} }

@ -26,13 +26,13 @@ class TestShrinkRNNMemory(unittest.TestCase):
cpu = core.CPUPlace() cpu = core.CPUPlace()
tensor = core.LoDTensor() tensor = core.LoDTensor()
tensor.set_lod([[0, 2, 5, 6]]) tensor.set_lod([[0, 2, 5, 6]])
tensor_np = numpy.random.random(size=(3, 100)).astype('float32') tensor_np = numpy.random.random(size=(6, 100)).astype('float32')
tensor.set(tensor_np, cpu) tensor.set(tensor_np, cpu)
exe = Executor(cpu) exe = Executor(cpu)
outs = exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3]) outs = exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3])
self.assertTrue(numpy.allclose(tensor_np[0:3], outs[0])) self.assertTrue(numpy.allclose(tensor_np[0:6], outs[0]))
self.assertTrue(numpy.allclose(tensor_np[0:2], outs[1])) self.assertTrue(numpy.allclose(tensor_np[0:5], outs[1]))
self.assertTrue(numpy.allclose(tensor_np[0:1], outs[2])) self.assertTrue(numpy.allclose(tensor_np[0:2], outs[2]))
mem3_mean = layers.mean(x=mem3) mem3_mean = layers.mean(x=mem3)
append_backward(loss=mem3_mean) append_backward(loss=mem3_mean)

Loading…
Cancel
Save