Add VarType::STEP_SCOPES for RNN (#5056)

revert-4814-Add_sequence_project_op
Yu Yang 7 years ago committed by GitHub
parent ee998a9c44
commit 6c0b383672

@ -115,6 +115,7 @@ message VarDesc {
SELECTED_ROWS = 2;
FEED_MINIBATCH = 3;
FETCH_LIST = 4;
STEP_SCOPES = 5;
}
required string name = 1;
required VarType type = 2;

@ -224,7 +224,8 @@ void BindVarDsec(py::module &m) {
.value("LOD_TENSOR", VarDesc::LOD_TENSOR)
.value("SELECTED_ROWS", VarDesc::SELECTED_ROWS)
.value("FEED_MINIBATCH", VarDesc::FEED_MINIBATCH)
.value("FETCH_LIST", VarDesc::FETCH_LIST);
.value("FETCH_LIST", VarDesc::FETCH_LIST)
.value("STEP_SCOPES", VarDesc::STEP_SCOPES);
}
void BindOpDesc(py::module &m) {

@ -1,5 +1,5 @@
import unittest
from paddle.v2.framework.framework import Variable, g_program
from paddle.v2.framework.framework import Variable, g_program, Program
import paddle.v2.framework.core as core
import numpy as np
@ -36,6 +36,13 @@ class TestVariable(unittest.TestCase):
self.assertRaises(ValueError,
lambda: b.create_var(name="fc.w", shape=(24, 100)))
def test_step_scopes(self):
prog = Program()
b = prog.current_block()
var = b.create_var(
name='step_scopes', type=core.VarDesc.VarType.STEP_SCOPES)
self.assertEqual(core.VarDesc.VarType.STEP_SCOPES, var.type)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save