Merge branch 'feature/fix_swig_dense_scanner' into feature/mnist_train_api

avx_docs
Yu Yang 8 years ago
commit efb5c10cdb

@ -15,6 +15,7 @@
import paddle.trainer.PyDataProvider2 as dp2
import collections
import swig_paddle
import numpy
__all__ = ['DataProviderConverter']
@ -35,18 +36,18 @@ class IScanner(object):
class DenseScanner(IScanner):
def __init__(self, input_type, pos):
IScanner.__init__(self, input_type, pos)
self.__mat__ = []
self.__height__ = 0
self.__mat__ = None
def scan(self, dat):
self.__mat__.extend(dat)
self.__height__ += 1
if self.__mat__ is None:
self.__mat__ = numpy.array([dat], dtype='float32')
else:
self.__mat__ = numpy.append(self.__mat__, [dat], axis=0)
def finish_scan(self, argument):
assert isinstance(argument, swig_paddle.Arguments)
assert isinstance(self.input_type, dp2.InputType)
m = swig_paddle.Matrix.createDense(self.__mat__, self.__height__,
self.input_type.dim, False)
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False)
argument.setSlotValue(self.pos, m)

Loading…
Cancel
Save