You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
119 lines
4.2 KiB
119 lines
4.2 KiB
7 years ago
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||
7 years ago
|
#
|
||
7 years ago
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
7 years ago
|
#
|
||
7 years ago
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
7 years ago
|
#
|
||
7 years ago
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
7 years ago
|
from __future__ import print_function
|
||
|
import core
|
||
|
import numpy
|
||
|
import six.moves as six
|
||
|
|
||
7 years ago
|
from framework import Variable, default_main_program
|
||
7 years ago
|
|
||
|
__all__ = ['DataFeeder']
|
||
|
|
||
|
|
||
|
class DataToLoDTensorConverter(object):
    """Accumulate Python samples for one feed slot and build a LoDTensor.

    Samples are added one at a time with :meth:`feed`; :meth:`done` then
    packs everything collected so far into a single ``core.LoDTensor``.
    """

    def __init__(self, place, lod_level, shape, dtype):
        """Set up an empty converter.

        Args:
            place: Passed unchanged to ``LoDTensor.set`` in :meth:`done`
                (presumably the device the tensor lives on -- confirm
                against ``core``).
            lod_level: Number of nesting levels each fed sample has. ``0``
                means a sample is stored as-is.
            shape: Shape the flattened data is reshaped to in :meth:`done`.
            dtype: One of ``core.VarDesc.VarType.{FP32, INT64, FP64, INT32}``.

        Raises:
            ValueError: If ``dtype`` is not one of the supported types.
        """
        self.place = place
        self.lod_level = lod_level
        self.shape = shape

        # Translate the framework dtype enum into the numpy dtype name that
        # ``done`` hands to ``numpy.array``.
        if dtype == core.VarDesc.VarType.FP32:
            self.dtype = 'float32'
        elif dtype == core.VarDesc.VarType.INT64:
            self.dtype = 'int64'
        elif dtype == core.VarDesc.VarType.FP64:
            self.dtype = 'float64'
        elif dtype == core.VarDesc.VarType.INT32:
            self.dtype = 'int32'
        else:
            raise ValueError("dtype must be any of [int32, float32, int64, "
                             "float64]")

        self.data = []
        # One cumulative-offset table per LoD level, each starting at 0.
        self.lod = [[0] for _ in six.range(lod_level)]

    def feed(self, data):
        """Add one sample, nested ``self.lod_level`` levels deep."""
        self._feed_impl_(data, self.lod, self.lod_level)

    def _feed_impl_(self, data, lod, lod_level):
        """Recursively flatten ``data`` and record sequence offsets.

        Args:
            data: The (sub-)sample to consume. At ``lod_level == 0`` it is
                appended to ``self.data`` unchanged; otherwise it must
                support ``len()`` and iteration.
            lod: Offset tables for the levels still to be processed; the
                first entry belongs to the current (outermost remaining)
                level.
            lod_level: Number of nesting levels left in ``data``.
        """
        if lod_level == 0:
            self.data.append(data)
            return

        # Level 0 of a LoD is the coarsest level, so the current nesting
        # level is recorded in the FIRST remaining table. (Writing to the
        # last table instead would store multi-level LoD in reverse order.)
        current_level = lod[0]
        current_level.append(current_level[-1] + len(data))

        # The slice is a shallow copy: the inner lists are still the ones
        # held by ``self.lod``, so deeper levels update them in place.
        deeper_levels = lod[1:]
        for each_data in data:
            self._feed_impl_(each_data, deeper_levels, lod_level - 1)

    def done(self):
        """Build and return a ``core.LoDTensor`` from everything fed so far.

        The collected data is converted to a numpy array of ``self.dtype``
        and reshaped to ``self.shape``. LoD information is attached only
        when ``self.lod_level > 0``.
        """
        arr = numpy.array(self.data, dtype=self.dtype).reshape(self.shape)
        t = core.LoDTensor()
        t.set(arr, self.place)
        if self.lod_level > 0:
            t.set_lod(self.lod)
        return t
|
||
|
|
||
|
|
||
|
class DataFeeder(object):
    """Convert an iterable of samples into a ``{name: LoDTensor}`` feed dict.

    The feeder is built from a list of feed variables and remembers each
    variable's name, dtype, shape and LoD level. :meth:`feed` then turns a
    mini-batch of samples into one tensor per variable.
    """

    def __init__(self, feed_list, place, program=None):
        """Record the metadata of every variable in ``feed_list``.

        Args:
            feed_list: Iterable of ``Variable`` objects or variable names.
                Names are looked up in block 0 of ``program``.
            place: Stored and handed to every converter created in
                :meth:`feed`.
            program: Program used to resolve variable names. Defaults to
                ``default_main_program()``.

        Raises:
            TypeError: If an entry is neither a ``Variable`` nor a name that
                resolves to one.
            ValueError: If a variable's shape has no negative dimension
                (the code treats the first negative dimension as the batch
                size dimension).
        """
        # ``basestring`` only exists on Python 2; fall back to ``str`` so
        # that passing variable names also works on Python 3.
        try:
            string_types = basestring  # noqa: F821
        except NameError:
            string_types = str

        self.feed_dtypes = []
        self.feed_names = []
        self.feed_shapes = []
        self.feed_lod_level = []
        if program is None:
            program = default_main_program()

        for each_var in feed_list:
            if isinstance(each_var, string_types):
                each_var = program.block(0).var(each_var)
            if not isinstance(each_var, Variable):
                raise TypeError("Feed list should contain a list of variable")

            shape = each_var.shape
            # Every feed variable needs a dimension of unknown (negative)
            # size; validate before recording anything for this variable.
            if not any(s < 0 for s in shape):
                # Format the name into the message instead of passing it as
                # a second exception argument.
                raise ValueError(
                    "Variable {0} must has a batch size dimension".format(
                        each_var.name))

            self.feed_dtypes.append(each_var.dtype)
            self.feed_names.append(each_var.name)
            self.feed_lod_level.append(each_var.lod_level)
            self.feed_shapes.append(shape)

        self.place = place

    def feed(self, iterable):
        """Convert a mini-batch into a dict of tensors keyed by variable name.

        Args:
            iterable: Iterable of samples. Each sample must support
                ``len()`` and contain exactly one field per feed variable,
                in the order of the ``feed_list`` given to the constructor.

        Returns:
            dict: Maps each feed variable name to the tensor returned by
            its converter's ``done()``.

        Raises:
            AssertionError: If a sample's field count differs from the
                number of feed variables.
        """
        # One fresh converter per feed variable for this batch.
        converter = [
            DataToLoDTensorConverter(
                place=self.place,
                lod_level=lod_level,
                shape=shape,
                dtype=dtype)
            for lod_level, shape, dtype in six.zip(
                self.feed_lod_level, self.feed_shapes, self.feed_dtypes)
        ]

        for each_sample in iterable:
            assert len(each_sample) == len(converter), (
                "The number of fields in data (%s) does not match " +
                "len(feed_list) (%s)") % (len(each_sample), len(converter))
            for each_converter, each_slot in six.zip(converter, each_sample):
                each_converter.feed(each_slot)

        ret_dict = {}
        for each_name, each_converter in six.zip(self.feed_names, converter):
            ret_dict[each_name] = each_converter.done()
        return ret_dict
|