fea/docker_cudnn7
fengjiayi 7 years ago
parent 44d5f42a7e
commit 649ae2700e

@ -259,17 +259,27 @@ def _copy_reader_var_(block, var):
def _copy_reader_create_op_(block, op):
def _find_vars_(block, name_list):
res = {}
for n in name_list:
var = block.var(n)
res[n] = var
return res
input_map = _find_vars_(block, op.input_names)
output_map = _find_vars_(block, op.output_names)
input_param_names = op.input_names
new_input_map = {}
for param_name in input_param_names:
new_input_map[param_name] = []
arg_names = op.input(param_name)
for arg_name in arg_names:
new_input_map[param_name].append(block.var(arg_name))
output_param_names = op.output_names
new_output_map = {}
for param_name in output_param_names:
new_output_map[param_name] = []
arg_names = op.output(param_name)
for arg_name in arg_names:
new_output_map[param_name].append(block.var(arg_name))
new_op = block.append_op(
type=op.type, inputs=input_map, outputs=output_map, attrs=op.attrs)
type=op.type,
inputs=new_input_map,
outputs=new_output_map,
attrs=op.attrs)
return new_op

@ -15,8 +15,8 @@
import unittest
import paddle.fluid as fluid
import paddle
import paddle.dataset.mnist as mnist
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
class TestRecordIO(unittest.TestCase):

Loading…
Cancel
Save