add dataset_generator.py

dataset_generator.py is a framework for generating data with Python.
The generated data, in a fixed format, will be fed into the C++ reader.
test=develop
revert-16555-model_data_cryption_link_all_lib
dongdaxiang 6 years ago
parent be757096da
commit c28bbdf8ba

@ -60,6 +60,7 @@ class DataFeed {
// Otherwise, Init() function will init finish_set_filelist_ flag. // Otherwise, Init() function will init finish_set_filelist_ flag.
virtual bool SetFileList(const std::vector<std::string>& files); virtual bool SetFileList(const std::vector<std::string>& files);
virtual bool Start() = 0; virtual bool Start() = 0;
// The trainer calls the Next() function, and the DataFeed will load a new // The trainer calls the Next() function, and the DataFeed will load a new
// batch to the feed_vec. The return value of this function is the batch // batch to the feed_vec. The return value of this function is the batch
// size of the current batch. // size of the current batch.

@ -28,4 +28,5 @@ message DataFeedDesc {
optional int32 batch_size = 2 [ default = 32 ]; optional int32 batch_size = 2 [ default = 32 ];
optional MultiSlotDesc multi_slot_desc = 3; optional MultiSlotDesc multi_slot_desc = 3;
optional string pipe_command = 4; optional string pipe_command = 4;
optional int32 thread_num = 5;
} }

@ -284,6 +284,7 @@ void ExecutorThreadWorker::TrainFilesWithTimer() {
for (int i = 0; i < fetch_var_num; ++i) { for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]); print_fetch_var(thread_scope_, fetch_var_names_[i]);
} }
fprintf(stderr, "IO percent: %f\n", read_time / total_time);
} }
} }
timeline.Start(); timeline.Start();

File diff suppressed because it is too large Load Diff

@ -139,6 +139,10 @@ class DataFeedDesc(object):
self.proto_desc.multi_slot_desc.slots[self.__name_to_index[ self.proto_desc.multi_slot_desc.slots[self.__name_to_index[
name]].is_used = True name]].is_used = True
def global_shuffle(self):
    """
    Globally shuffle the underlying dataset by delegating to it.

    NOTE(review): assumes ``self.data`` exposes ``global_shuffle()`` —
    confirm against the dataset implementation.
    """
    # The original body ended with a redundant `pass` after a real
    # statement; it has no effect and was removed.
    self.data.global_shuffle()
def desc(self): def desc(self):
""" """
Returns a protobuf message for this DataFeedDesc Returns a protobuf message for this DataFeedDesc

@ -0,0 +1,109 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
__all__ = ['DatasetFactory']
class DatasetFactory(object):
    """
    Factory that creates a dataset instance from its class name.

    The name is capitalized and looked up among the dataset classes
    defined in this module (e.g. "InMemoryDataset", "QueueDataset").
    """

    def __init__(self):
        pass

    def create_dataset(self, datafeed_class):
        """
        Create and return a dataset instance.

        Args:
            datafeed_class (str): name of the dataset class to create;
                case of the first letter is normalized via capitalize().

        Returns:
            An instance of the requested dataset class.

        Raises:
            ValueError: if no dataset class with that name exists.
        """
        datafeed_class = datafeed_class.capitalize()
        # Only a missing class name maps to ValueError; errors raised by
        # the class's own constructor propagate unchanged (the original
        # bare `except:` wrongly swallowed those too).
        try:
            dataset_class = globals()[datafeed_class]
        except KeyError:
            raise ValueError("datafeed class %s does not exist" %
                             datafeed_class)
        # BUGFIX: the original never returned the created instance.
        return dataset_class()
class DatasetBase(object):
    """
    Base class for datasets.

    Wraps a ``DataFeedDesc`` protobuf message and exposes setters for its
    fields; ``desc()`` serializes the message to text format.
    """

    def __init__(self):
        # define class name here
        # to decide whether we need create in memory instance
        self.proto_desc = data_feed_pb2.DataFeedDesc()
        # Default pipe command is a pass-through.
        self.proto_desc.pipe_command = "cat"

    def set_pipe_command(self, pipe_command):
        """
        Set pipe command of current dataset
        A pipe command is a UNIX pipeline command that can be used only

        Args:
            pipe_command (str): the shell pipeline command to run.
        """
        self.proto_desc.pipe_command = pipe_command

    def set_batch_size(self, batch_size):
        """
        Set batch size. Will be effective during training

        Example:
            >>> data_feed = fluid.DataFeedDesc('data.proto')
            >>> data_feed.set_batch_size(128)

        Args:
            batch_size: batch size
        """
        self.proto_desc.batch_size = batch_size

    def set_use_var(self, var_list):
        """
        Register the variables this dataset will feed.

        Args:
            var_list: list of variables; each must expose ``name``,
                ``lod_level`` and ``dtype`` attributes.

        Raises:
            ValueError: if a variable's dtype is not float32 or int64.
        """
        # BUGFIX: a protobuf message field is an attribute, not a
        # callable, and new entries are added to its repeated `slots`
        # field — the original `self.proto_desc.multi_slot_desc()` and
        # `multi_slot.add()` both raised at runtime.
        multi_slot = self.proto_desc.multi_slot_desc
        for var in var_list:
            slot_var = multi_slot.slots.add()
            slot_var.is_used = True
            slot_var.name = var.name
            if var.lod_level == 0:
                slot_var.is_dense = True
            if var.dtype == core.VarType.FP32:
                slot_var.type = "float32"
            elif var.dtype == core.VarType.INT64:
                # int64 variables are transported as uint64 slots.
                slot_var.type = "uint64"
            else:
                raise ValueError(
                    "Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
                )

    def desc(self):
        """
        Returns a protobuf message for this DataFeedDesc

        Example:
            >>> data_feed = fluid.DataFeedDesc('data.proto')
            >>> print(data_feed.desc())

        Returns:
            A string message
        """
        return text_format.MessageToString(self.proto_desc)
class InMemoryDataset(DatasetBase):
    """
    Dataset that is loaded into memory; uses the "InMemoryDataFeed"
    data feed on the C++ side.
    """

    def __init__(self):
        # BUGFIX: the original `super(InMemoryDataset.__init__())` called
        # __init__ on the class (missing self) and passed the result to
        # super() — a TypeError at runtime.
        super(InMemoryDataset, self).__init__()
        self.proto_desc.name = "InMemoryDataFeed"

    def local_shuffle(self):
        """Shuffle data within this process. Not yet implemented."""
        pass

    def global_shuffle(self):
        """Shuffle data across all processes. Not yet implemented."""
        pass
class QueueDataset(DatasetBase):
    """
    Dataset streamed through a queue; uses the "MultiSlotDataFeed"
    data feed on the C++ side.
    """

    def __init__(self):
        # BUGFIX: the original `super(QueueDataset.__init__())` called
        # __init__ on the class (missing self) and passed the result to
        # super() — a TypeError at runtime.
        super(QueueDataset, self).__init__()
        self.proto_desc.name = "MultiSlotDataFeed"
Loading…
Cancel
Save