Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into mixed_layer
commit
7a8da332a1
@ -0,0 +1,53 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['np_array', 'text_file']
|
||||
|
||||
|
||||
def np_array(x):
|
||||
"""
|
||||
Creates a reader that yields elements of x, if it is a
|
||||
numpy vector. Or rows of x, if it is a numpy matrix.
|
||||
Or any sub-hyperplane indexed by the highest dimension.
|
||||
|
||||
:param x: the numpy array to create reader from.
|
||||
:returns: data reader created from x.
|
||||
"""
|
||||
|
||||
def reader():
|
||||
if x.ndim < 1:
|
||||
yield x
|
||||
|
||||
for e in x:
|
||||
yield e
|
||||
|
||||
return reader
|
||||
|
||||
|
||||
def text_file(path):
|
||||
"""
|
||||
Creates a data reader that outputs text line by line from given text file.
|
||||
Trailing new line ('\n') of each line will be removed.
|
||||
|
||||
:path: path of the text file.
|
||||
:returns: data reader of text file
|
||||
"""
|
||||
|
||||
def reader():
|
||||
f = open(path, "r")
|
||||
for l in f:
|
||||
yield l.rstrip('\n')
|
||||
f.close()
|
||||
|
||||
return reader
|
@ -0,0 +1,38 @@
|
||||
# Copyright PaddlePaddle contributors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
import paddle.reader.creator
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class TestNumpyArray(unittest.TestCase):
|
||||
def test_numpy_array(self):
|
||||
l = [[1, 2, 3], [4, 5, 6]]
|
||||
x = np.array(l, np.int32)
|
||||
reader = paddle.reader.creator.np_array(x)
|
||||
for idx, e in enumerate(reader()):
|
||||
self.assertItemsEqual(e, l[idx])
|
||||
|
||||
|
||||
class TestTextFile(unittest.TestCase):
|
||||
def test_text_file(self):
|
||||
path = os.path.join(os.path.dirname(__file__), "test_data_creator.txt")
|
||||
reader = paddle.reader.creator.text_file(path)
|
||||
for idx, e in enumerate(reader()):
|
||||
self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,3 @@
|
||||
0 1
|
||||
2 3
|
||||
4 5
|
@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from py_paddle import swig_paddle
|
||||
from py_paddle import DataProviderConverter
|
||||
import data_type
|
||||
|
||||
__all__ = ['DataFeeder']
|
||||
|
||||
|
||||
class DataFeeder(DataProviderConverter):
|
||||
"""
|
||||
DataFeeder converts the data returned by paddle.reader into a data structure
|
||||
of Arguments which is defined in the API. The paddle.reader usually returns
|
||||
a list of mini-batch data entries. Each data entry in the list is one sampe.
|
||||
Each sample is a list or a tuple with one feature or multiple features.
|
||||
DataFeeder converts this mini-batch data entries into Arguments in order
|
||||
to feed it to C++ interface.
|
||||
|
||||
The example usage:
|
||||
|
||||
data_types = [('image', paddle.data_type.dense_vector(784)),
|
||||
('label', paddle.data_type.integer_value(10))]
|
||||
reader_dict = {'image':0, 'label':1}
|
||||
feeder = DataFeeder(data_types=data_types, reader_dict=reader_dict)
|
||||
minibatch_data = [
|
||||
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ), # first sample
|
||||
( [1.0,2.0,3.0,4.0], 5, [6,7,8] ) # second sample
|
||||
]
|
||||
# or minibatch_data = [
|
||||
# [ [1.0,2.0,3.0,4.0], 5, [6,7,8] ], # first sample
|
||||
# [ [1.0,2.0,3.0,4.0], 5, [6,7,8] ] # second sample
|
||||
# ]
|
||||
arg = feeder(minibatch_data)
|
||||
"""
|
||||
|
||||
def __init__(self, data_types, reader_dict):
|
||||
"""
|
||||
:param data_types: A list to specify data name and type. Each item is
|
||||
a tuple of (data_name, data_type). For example:
|
||||
[('image', paddle.data_type.dense_vector(784)),
|
||||
('label', paddle.data_type.integer_value(10))]
|
||||
|
||||
:type data_types: A list of tuple
|
||||
:param reader_dict: A dictionary to specify the position of each data
|
||||
in the input data.
|
||||
:type reader_dict: dict()
|
||||
"""
|
||||
self.input_names = []
|
||||
input_types = []
|
||||
self.reader_dict = reader_dict
|
||||
for each in data_types:
|
||||
self.input_names.append(each[0])
|
||||
assert isinstance(each[1], data_type.InputType)
|
||||
input_types.append(each[1])
|
||||
DataProviderConverter.__init__(self, input_types)
|
||||
|
||||
def convert(self, dat, argument=None):
|
||||
"""
|
||||
:param dat: A list of mini-batch data. Each sample is a list or tuple
|
||||
one feature or multiple features.
|
||||
for example:
|
||||
[
|
||||
([0.2, 0.2], ), # first sample
|
||||
([0.8, 0.3], ), # second sample
|
||||
]
|
||||
or,
|
||||
[
|
||||
[[0.2, 0.2], ], # first sample
|
||||
[[0.8, 0.3], ], # second sample
|
||||
]
|
||||
|
||||
:type dat: List
|
||||
:param argument: An Arguments object contains this mini-batch data with
|
||||
one or multiple features. The Arguments definition is
|
||||
in the API.
|
||||
:type argument: swig_paddle.Arguments
|
||||
"""
|
||||
|
||||
def reorder_data(data):
|
||||
retv = []
|
||||
for each in data:
|
||||
reorder = []
|
||||
for name in self.input_names:
|
||||
reorder.append(each[self.reader_dict[name]])
|
||||
retv.append(reorder)
|
||||
return retv
|
||||
|
||||
return DataProviderConverter.convert(self, reorder_data(dat), argument)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue