Modify tmp var name prefix in dygraph (#25280)

* Modify tmp var name prefix in dygraph test=develop

* refine comment test=develop
fix_copy_if_different
Aurelius84 5 years ago committed by GitHub
parent 82ec247a02
commit 494cb36d09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -285,6 +285,11 @@ TEST(test_tracer, test_unique_name_generator) {
auto fc_2 = tracer.GenerateUniqueName("fc");
ASSERT_STREQ("fc_0", fc_1.c_str());
ASSERT_STREQ("fc_1", fc_2.c_str());
// use `eager_tmp` as key if not specify it.
auto tmp_var_2 = tracer.GenerateUniqueName();
ASSERT_STREQ("eager_tmp_2", tmp_var_2.c_str());
auto tmp_var_3 = tracer.GenerateUniqueName("eager_tmp");
ASSERT_STREQ("eager_tmp_3", tmp_var_3.c_str());
}
TEST(test_tracer, test_current_tracer) {

@ -76,7 +76,14 @@ class Tracer {
return program_desc_tracer_.get();
}
std::string GenerateUniqueName(std::string key = "tmp") {
// Note(Aurelius84): The `tmp` is used as prefix key while naming a temporary
// intermediate var both in imperative and static mode. But the
// `UniqueNameGenerator` in C++ and `unique_name.py` in Python doesn't share
// the same auto-increment id. It will create a variable repeatedly with same
// name like `tmp_0` in some cases when transform dygraph into static layers.
// So we modify the default prefix key into `eager_tmp` to distinguish with
// static graph.
std::string GenerateUniqueName(std::string key = "eager_tmp") {
return generator_->Generate(key);
}

@ -873,7 +873,7 @@ void BindImperative(py::module *m_ptr) {
&imperative::Tracer::GetProgramDescTracer,
py::return_value_policy::reference)
.def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
py::arg("key") = "tmp")
py::arg("key") = "eager_tmp")
.def("trace",
[](imperative::Tracer &self, const std::string &type,
const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,

@ -43,3 +43,18 @@ class TestUniqueName(unittest.TestCase):
name3 = fluid.unique_name.generate('tmp')
self.assertNotEqual(name1, name2)
self.assertEqual(name1[-2:], name3[-2:])
class TestImperativeUniqueName(unittest.TestCase):
def test_name_generator(self):
with fluid.dygraph.guard():
tracer = fluid.framework._dygraph_tracer()
tmp_var_0 = tracer._generate_unique_name()
self.assertEqual(tmp_var_0, "eager_tmp_0")
tmp_var_1 = tracer._generate_unique_name("eager_tmp")
self.assertEqual(tmp_var_1, "eager_tmp_1")
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save