|
|
@ -1,6 +1,8 @@
|
|
|
|
import paddle.v2.framework.core as core
|
|
|
|
import paddle.v2.framework.core as core
|
|
|
|
from paddle.v2.framework.framework import Block, Program
|
|
|
|
from paddle.v2.framework.framework import Block, Program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
g_scope = core.Scope()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Executor(object):
|
|
|
|
class Executor(object):
|
|
|
|
def __init__(self, places):
|
|
|
|
def __init__(self, places):
|
|
|
@ -20,10 +22,14 @@ class Executor(object):
|
|
|
|
feed,
|
|
|
|
feed,
|
|
|
|
fetch_list,
|
|
|
|
fetch_list,
|
|
|
|
feed_var_name='feed',
|
|
|
|
feed_var_name='feed',
|
|
|
|
fetch_var_name='fetch'):
|
|
|
|
fetch_var_name='fetch',
|
|
|
|
|
|
|
|
scope=None):
|
|
|
|
if not isinstance(program, Program):
|
|
|
|
if not isinstance(program, Program):
|
|
|
|
raise TypeError()
|
|
|
|
raise TypeError()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if scope is None:
|
|
|
|
|
|
|
|
scope = g_scope
|
|
|
|
|
|
|
|
|
|
|
|
program = program.clone()
|
|
|
|
program = program.clone()
|
|
|
|
global_block = program.global_block()
|
|
|
|
global_block = program.global_block()
|
|
|
|
feed_var = global_block.create_var(
|
|
|
|
feed_var = global_block.create_var(
|
|
|
@ -38,7 +44,7 @@ class Executor(object):
|
|
|
|
inputs={'X': [feed_var]},
|
|
|
|
inputs={'X': [feed_var]},
|
|
|
|
outputs={'Out': [out]},
|
|
|
|
outputs={'Out': [out]},
|
|
|
|
attrs={'col': i})
|
|
|
|
attrs={'col': i})
|
|
|
|
core.set_feed_variable(feed[name], feed_var.name, i)
|
|
|
|
core.set_feed_variable(scope, feed[name], feed_var.name, i)
|
|
|
|
|
|
|
|
|
|
|
|
fetch_var = global_block.create_var(
|
|
|
|
fetch_var = global_block.create_var(
|
|
|
|
name=fetch_var_name,
|
|
|
|
name=fetch_var_name,
|
|
|
@ -51,8 +57,8 @@ class Executor(object):
|
|
|
|
outputs={'Out': [fetch_var]},
|
|
|
|
outputs={'Out': [fetch_var]},
|
|
|
|
attrs={'col': i})
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
|
|
self.executor.run(program.desc, 0)
|
|
|
|
self.executor.run(program.desc, scope, 0)
|
|
|
|
return [
|
|
|
|
return [
|
|
|
|
core.get_fetch_variable(fetch_var_name, i)
|
|
|
|
core.get_fetch_variable(scope, fetch_var_name, i)
|
|
|
|
for i in xrange(len(fetch_list))
|
|
|
|
for i in xrange(len(fetch_list))
|
|
|
|
]
|
|
|
|
]
|
|
|
|