|
|
|
@ -3,7 +3,7 @@ import core
|
|
|
|
|
import numpy
|
|
|
|
|
import six.moves as six
|
|
|
|
|
|
|
|
|
|
from framework import Variable
|
|
|
|
|
from framework import Variable, default_main_program
|
|
|
|
|
|
|
|
|
|
__all__ = ['DataFeeder']
|
|
|
|
|
|
|
|
|
@ -53,12 +53,16 @@ class DataToLoDTensorConverter(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataFeeder(object):
|
|
|
|
|
def __init__(self, feed_list, place):
|
|
|
|
|
def __init__(self, feed_list, place, program=None):
|
|
|
|
|
self.feed_dtypes = []
|
|
|
|
|
self.feed_names = []
|
|
|
|
|
self.feed_shapes = []
|
|
|
|
|
self.feed_lod_level = []
|
|
|
|
|
if program is None:
|
|
|
|
|
program = default_main_program()
|
|
|
|
|
for each_var in feed_list:
|
|
|
|
|
if isinstance(each_var, basestring):
|
|
|
|
|
each_var = program.block(0).var(each_var)
|
|
|
|
|
if not isinstance(each_var, Variable):
|
|
|
|
|
raise TypeError("Feed list should contain a list of variable")
|
|
|
|
|
self.feed_dtypes.append(each_var.dtype)
|
|
|
|
|