From 39277e9282294dc18b4c2b93aa000a15b58bea5f Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 4 Apr 2018 14:55:28 +0800 Subject: [PATCH 1/3] fix transpiler condition op in optimize --- python/paddle/fluid/distribute_transpiler.py | 32 ++++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 9311fc9904..6d76c1a8d1 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -408,11 +408,16 @@ class DistributeTranspiler: pserver_vars = pserver_program.global_block().vars created_var_map = dict() for _, var in pserver_vars.iteritems(): - tmpvar = s_prog.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + if var.type == core.VarDesc.VarType.STEP_SCOPES: + tmpvar = s_prog.global_block().create_var( + name=var.name, persistable=var.persistable, type=var.type) + else: + tmpvar = s_prog.global_block().create_var( + name=var.name, + persistable=var.persistable, + type=var.type, + dtype=var.dtype, + shape=var.shape) created_var_map[var.name] = tmpvar # 2. rename op outputs @@ -708,11 +713,18 @@ class DistributeTranspiler: varlist = [varlist] for var in varlist: - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + print("##### deal var: ", var) + if var.type == core.VarDesc.VarType.STEP_SCOPES: + program.global_block().create_var( + name=var.name, + persistable=var.persistable, + type=var.type) + else: + program.global_block().create_var( + name=var.name, + persistable=var.persistable, + dtype=var.dtype, + shape=var.shape) optimize_block.append_op( type=opt_op.type, From e0b396e7ba80738fe8c87edb80e5743b8d692cb7 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 4 Apr 2018 15:14:48 +0800 Subject: [PATCH 2/3] update by comment --- python/paddle/fluid/distribute_transpiler.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 6d76c1a8d1..134dbe573a 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -412,12 +412,7 @@ class DistributeTranspiler: tmpvar = s_prog.global_block().create_var( name=var.name, persistable=var.persistable, type=var.type) else: - tmpvar = s_prog.global_block().create_var( - name=var.name, - persistable=var.persistable, - type=var.type, - dtype=var.dtype, - shape=var.shape) + tmpvar = s_prog.global_block().clone_variable(var) created_var_map[var.name] = tmpvar # 2. rename op outputs @@ -713,18 +708,13 @@ class DistributeTranspiler: varlist = [varlist] for var in varlist: - print("##### deal var: ", var) if var.type == core.VarDesc.VarType.STEP_SCOPES: program.global_block().create_var( name=var.name, persistable=var.persistable, type=var.type) else: - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - dtype=var.dtype, - shape=var.shape) + program.global_block().clone_variable(var) optimize_block.append_op( type=opt_op.type, From a16a872783d52d9ba7d32d53848e95cc4ccaefd6 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 4 Apr 2018 15:56:21 +0800 Subject: [PATCH 3/3] update --- python/paddle/fluid/distribute_transpiler.py | 14 ++----------- python/paddle/fluid/framework.py | 21 +++++++++++++------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 134dbe573a..31bedb592f 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -408,11 +408,7 @@ class DistributeTranspiler: pserver_vars = pserver_program.global_block().vars created_var_map = dict() for _, var in pserver_vars.iteritems(): - if var.type == core.VarDesc.VarType.STEP_SCOPES: - tmpvar = s_prog.global_block().create_var( - name=var.name, persistable=var.persistable, type=var.type) - else: - tmpvar = s_prog.global_block().clone_variable(var) + tmpvar = s_prog.global_block().clone_variable(var) created_var_map[var.name] = tmpvar # 2. rename op outputs @@ -708,13 +704,7 @@ class DistributeTranspiler: varlist = [varlist] for var in varlist: - if var.type == core.VarDesc.VarType.STEP_SCOPES: - program.global_block().create_var( - name=var.name, - persistable=var.persistable, - type=var.type) - else: - program.global_block().clone_variable(var) + program.global_block().clone_variable(var) optimize_block.append_op( type=opt_op.type, diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index e15456bfc0..39d4017861 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -946,13 +946,20 @@ class Block(object): The new variable cloned from 'var' in current block. """ assert isinstance(var, Variable) - return self.create_var( - name=var.name, - shape=var.shape, - dtype=var.dtype, - type=var.type, - lod_level=var.lod_level, - persistable=True) + ret_var = None + # make STEP_SCOPES var can be safely cloned. + if var.type == core.VarDesc.VarType.STEP_SCOPES: + ret_var = self.create_var( + name=var.name, persistable=var.persistable, type=var.type) + else: + ret_var = self.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=True) + return ret_var class Program(object):