add data_generator into paddle.fluid.incubate.data_generator, add op run log in hogwild_device_worker and downpour_device_worker
test=developrevert-16555-model_data_cryption_link_all_lib
parent
73544e8b8d
commit
73b1f396d7
@ -0,0 +1,226 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__all__ = ['MultiSlotDataGenerator']
|
||||
|
||||
|
||||
class DataGenerator(object):
|
||||
def __init__(self):
|
||||
self._proto_info = None
|
||||
self.batch_size_ = 32
|
||||
|
||||
def _set_line_limit(self, line_limit):
|
||||
if not isinstance(line_limit, int):
|
||||
raise ValueError("line_limit%s must be in int type" %
|
||||
type(line_limit))
|
||||
if line_limit < 1:
|
||||
raise ValueError("line_limit can not less than 1")
|
||||
self._line_limit = line_limit
|
||||
|
||||
def set_batch(self, batch_size):
|
||||
self.batch_size_ = batch_size
|
||||
|
||||
def run_from_memory(self):
|
||||
'''
|
||||
This function generator data from memory, it is usually used for
|
||||
debug and benchmarking
|
||||
'''
|
||||
batch_samples = []
|
||||
line_iter = self.generate_sample(None)
|
||||
for user_parsed_line in line_iter():
|
||||
if user_parsed_line == None:
|
||||
continue
|
||||
batch_samples.append(user_parsed_line)
|
||||
if len(batch_samples) == self.batch_size_:
|
||||
batch_iter = self.generate_batch(batch_samples)
|
||||
for sample in batch_iter():
|
||||
sys.stdout.write(self._gen_str(sample))
|
||||
batch_samples = []
|
||||
if len(batch_samples) > 0:
|
||||
batch_iter = self.generate_batch(batch_samples)
|
||||
for sample in batch_iter():
|
||||
sys.stdout.write(self._gen_str(sample))
|
||||
|
||||
def run_from_stdin(self):
|
||||
'''
|
||||
This function reads the data row from stdin, parses it with the
|
||||
process function, and further parses the return value of the
|
||||
process function with the _gen_str function. The parsed data will
|
||||
be wrote to stdout and the corresponding protofile will be
|
||||
generated.
|
||||
|
||||
'''
|
||||
batch_samples = []
|
||||
for line in sys.stdin:
|
||||
line_iter = self.generate_sample(line)
|
||||
for user_parsed_line in line_iter():
|
||||
if user_parsed_line == None:
|
||||
continue
|
||||
batch_samples.append(user_parsed_line)
|
||||
if len(batch_samples) == self.batch_size_:
|
||||
batch_iter = self.generate_batch(batch_samples)
|
||||
for sample in batch_iter():
|
||||
sys.stdout.write(self._gen_str(sample))
|
||||
batch_samples = []
|
||||
if len(batch_samples) > 0:
|
||||
batch_iter = self.generate_batch(batch_samples)
|
||||
for sample in batch_iter():
|
||||
sys.stdout.write(self._gen_str(sample))
|
||||
|
||||
def _gen_str(self, line):
|
||||
'''
|
||||
Further processing the output of the process() function rewritten by
|
||||
user, outputting data that can be directly read by the datafeed,and
|
||||
updating proto_info infomation.
|
||||
|
||||
Args:
|
||||
line(str): the output of the process() function rewritten by user.
|
||||
|
||||
Returns:
|
||||
Return a string data that can be read directly by the datafeed.
|
||||
'''
|
||||
raise NotImplementedError(
|
||||
"pls use MultiSlotDataGenerator or PairWiseDataGenerator")
|
||||
|
||||
def generate_sample(self, line):
|
||||
'''
|
||||
This function needs to be overridden by the user to process the
|
||||
original data row into a list or tuple.
|
||||
|
||||
Args:
|
||||
line(str): the original data row
|
||||
|
||||
Returns:
|
||||
Returns the data processed by the user.
|
||||
The data format is list or tuple:
|
||||
[(name, [feasign, ...]), ...]
|
||||
or ((name, [feasign, ...]), ...)
|
||||
|
||||
For example:
|
||||
[("words", [1926, 08, 17]), ("label", [1])]
|
||||
or (("words", [1926, 08, 17]), ("label", [1]))
|
||||
|
||||
Note:
|
||||
The type of feasigns must be in int or float. Once the float
|
||||
element appears in the feasign, the type of that slot will be
|
||||
processed into a float.
|
||||
'''
|
||||
raise NotImplementedError(
|
||||
"Please rewrite this function to return a list or tuple: " +
|
||||
"[(name, [feasign, ...]), ...] or ((name, [feasign, ...]), ...)")
|
||||
|
||||
def generate_batch(self, samples):
|
||||
def local_iter():
|
||||
for sample in samples:
|
||||
yield sample
|
||||
|
||||
return local_iter
|
||||
|
||||
|
||||
class MultiSlotDataGenerator(DataGenerator):
|
||||
def _gen_str(self, line):
|
||||
'''
|
||||
Further processing the output of the process() function rewritten by
|
||||
user, outputting data that can be directly read by the MultiSlotDataFeed,
|
||||
and updating proto_info infomation.
|
||||
|
||||
The input line will be in this format:
|
||||
>>> [(name, [feasign, ...]), ...]
|
||||
>>> or ((name, [feasign, ...]), ...)
|
||||
The output will be in this format:
|
||||
>>> [ids_num id1 id2 ...] ...
|
||||
The proto_info will be in this format:
|
||||
>>> [(name, type), ...]
|
||||
|
||||
For example, if the input is like this:
|
||||
>>> [("words", [1926, 08, 17]), ("label", [1])]
|
||||
>>> or (("words", [1926, 08, 17]), ("label", [1]))
|
||||
the output will be:
|
||||
>>> 3 1234 2345 3456 1 1
|
||||
the proto_info will be:
|
||||
>>> [("words", "uint64"), ("label", "uint64")]
|
||||
|
||||
Args:
|
||||
line(str): the output of the process() function rewritten by user.
|
||||
|
||||
Returns:
|
||||
Return a string data that can be read directly by the MultiSlotDataFeed.
|
||||
'''
|
||||
if not isinstance(line, list) and not isinstance(line, tuple):
|
||||
raise ValueError(
|
||||
"the output of process() must be in list or tuple type")
|
||||
output = ""
|
||||
|
||||
if self._proto_info is None:
|
||||
self._proto_info = []
|
||||
for item in line:
|
||||
name, elements = item
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("name%s must be in str type" % type(name))
|
||||
if not isinstance(elements, list):
|
||||
raise ValueError("elements%s must be in list type" %
|
||||
type(elements))
|
||||
if not elements:
|
||||
raise ValueError(
|
||||
"the elements of each field can not be empty, you need padding it in process()."
|
||||
)
|
||||
self._proto_info.append((name, "uint64"))
|
||||
if output:
|
||||
output += " "
|
||||
output += str(len(elements))
|
||||
for elem in elements:
|
||||
if isinstance(elem, float):
|
||||
self._proto_info[-1] = (name, "float")
|
||||
elif not isinstance(elem, int) and not isinstance(elem,
|
||||
long):
|
||||
raise ValueError(
|
||||
"the type of element%s must be in int or float" %
|
||||
type(elem))
|
||||
output += " " + str(elem)
|
||||
else:
|
||||
if len(line) != len(self._proto_info):
|
||||
raise ValueError(
|
||||
"the complete field set of two given line are inconsistent.")
|
||||
for index, item in enumerate(line):
|
||||
name, elements = item
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("name%s must be in str type" % type(name))
|
||||
if not isinstance(elements, list):
|
||||
raise ValueError("elements%s must be in list type" %
|
||||
type(elements))
|
||||
if not elements:
|
||||
raise ValueError(
|
||||
"the elements of each field can not be empty, you need padding it in process()."
|
||||
)
|
||||
if name != self._proto_info[index][0]:
|
||||
raise ValueError(
|
||||
"the field name of two given line are not match: require<%s>, get<%d>."
|
||||
% (self._proto_info[index][0], name))
|
||||
if output:
|
||||
output += " "
|
||||
output += str(len(elements))
|
||||
for elem in elements:
|
||||
if self._proto_info[index][1] != "float":
|
||||
if isinstance(elem, float):
|
||||
self._proto_info[index] = (name, "float")
|
||||
elif not isinstance(elem, int) and not isinstance(elem,
|
||||
long):
|
||||
raise ValueError(
|
||||
"the type of element%s must be in int or float"
|
||||
% type(elem))
|
||||
output += " " + str(elem)
|
||||
return output + "\n"
|
@ -0,0 +1,26 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
from __init__ import *
|
||||
|
||||
|
||||
class SyntheticData(MultiSlotDataGenerator):
|
||||
def generate_sample(self, line):
|
||||
def data_iter():
|
||||
for i in range(10000):
|
||||
yield ("words", [1, 2, 3, 4]), ("label", [0])
|
||||
|
||||
return data_iter
|
||||
|
||||
|
||||
sd = SyntheticData()
|
||||
sd.run_from_memory()
|
Loading…
Reference in new issue