Merge branch 'feature/fix_swig_dense_scanner' into feature/mnist_train_api

avx_docs
Yu Yang 8 years ago
commit efb5c10cdb

@ -15,6 +15,7 @@
import paddle.trainer.PyDataProvider2 as dp2 import paddle.trainer.PyDataProvider2 as dp2
import collections import collections
import swig_paddle import swig_paddle
import numpy
__all__ = ['DataProviderConverter'] __all__ = ['DataProviderConverter']
@ -35,18 +36,18 @@ class IScanner(object):
class DenseScanner(IScanner): class DenseScanner(IScanner):
def __init__(self, input_type, pos): def __init__(self, input_type, pos):
IScanner.__init__(self, input_type, pos) IScanner.__init__(self, input_type, pos)
self.__mat__ = [] self.__mat__ = None
self.__height__ = 0
def scan(self, dat): def scan(self, dat):
self.__mat__.extend(dat) if self.__mat__ is None:
self.__height__ += 1 self.__mat__ = numpy.array([dat], dtype='float32')
else:
self.__mat__ = numpy.append(self.__mat__, [dat], axis=0)
def finish_scan(self, argument): def finish_scan(self, argument):
assert isinstance(argument, swig_paddle.Arguments) assert isinstance(argument, swig_paddle.Arguments)
assert isinstance(self.input_type, dp2.InputType) assert isinstance(self.input_type, dp2.InputType)
m = swig_paddle.Matrix.createDense(self.__mat__, self.__height__, m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False)
self.input_type.dim, False)
argument.setSlotValue(self.pos, m) argument.setSlotValue(self.pos, m)

Loading…
Cancel
Save