Add a new DataFeed named PaddleBoxDataFeed (#23321)

* add paddleboxdatafeed
* add ifdef linux and boxps
* add untest for datafeed
* fix untest of test_paddlebox_datafeed
* fix untest
* rename function
revert-23830-2.0-beta
ShenLiang 5 years ago committed by GitHub
parent 75bd350710
commit 5223e2bbc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large Load Diff

@ -58,6 +58,51 @@ namespace framework {
// while (reader->Next()) {
// // trainer do something
// }
union FeatureKey {
uint64_t uint64_feasign_;
float float_feasign_;
};
struct FeatureItem {
FeatureItem() {}
FeatureItem(FeatureKey sign, uint16_t slot) {
this->sign() = sign;
this->slot() = slot;
}
FeatureKey& sign() { return *(reinterpret_cast<FeatureKey*>(sign_buffer())); }
const FeatureKey& sign() const {
const FeatureKey* ret = reinterpret_cast<FeatureKey*>(sign_buffer());
return *ret;
}
uint16_t& slot() { return slot_; }
const uint16_t& slot() const { return slot_; }
private:
char* sign_buffer() const { return const_cast<char*>(sign_); }
char sign_[sizeof(FeatureKey)];
uint16_t slot_;
};
// sizeof Record is much less than std::vector<MultiSlotType>
struct Record {
std::vector<FeatureItem> uint64_feasigns_;
std::vector<FeatureItem> float_feasigns_;
std::string ins_id_;
std::string content_;
uint64_t search_id;
uint32_t rank;
uint32_t cmatch;
};
struct PvInstanceObject {
std::vector<Record*> ads;
void merge_instance(Record* ins) { ads.push_back(ins); }
};
using PvInstance = PvInstanceObject*;
inline PvInstance make_pv_instance() { return new PvInstanceObject(); }
class DataFeed {
public:
DataFeed() {
@ -93,6 +138,13 @@ class DataFeed {
// This function is used for binding feed_vec memory in a given scope
virtual void AssignFeedVar(const Scope& scope);
// This function will do nothing at default
virtual void SetInputPvChannel(void* channel) {}
// This function will do nothing at default
virtual void SetOutputPvChannel(void* channel) {}
// This function will do nothing at default
virtual void SetConsumePvChannel(void* channel) {}
// This function will do nothing at default
virtual void SetInputChannel(void* channel) {}
// This function will do nothing at default
@ -106,6 +158,9 @@ class DataFeed {
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
virtual void SetCurrentPhase(int current_phase) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
@ -163,6 +218,8 @@ class DataFeed {
// The data read by DataFeed will be stored here
std::vector<LoDTensor*> feed_vec_;
LoDTensor* rank_offset_;
// the batch size defined by user
int default_batch_size_;
// current batch size
@ -226,6 +283,10 @@ class InMemoryDataFeed : public DataFeed {
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool Start();
virtual int Next();
virtual void SetInputPvChannel(void* channel);
virtual void SetOutputPvChannel(void* channel);
virtual void SetConsumePvChannel(void* channel);
virtual void SetInputChannel(void* channel);
virtual void SetOutputChannel(void* channel);
virtual void SetConsumeChannel(void* channel);
@ -233,6 +294,9 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetCurrentPhase(int current_phase);
virtual void LoadIntoMemory();
protected:
@ -244,11 +308,18 @@ class InMemoryDataFeed : public DataFeed {
int thread_num_;
bool parse_ins_id_;
bool parse_content_;
bool parse_logkey_;
bool enable_pv_merge_;
int current_phase_{-1}; // only for untest
std::ifstream file_;
std::shared_ptr<FILE> fp_;
paddle::framework::ChannelObject<T>* input_channel_;
paddle::framework::ChannelObject<T>* output_channel_;
paddle::framework::ChannelObject<T>* consume_channel_;
paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* output_pv_channel_;
paddle::framework::ChannelObject<PvInstance>* consume_pv_channel_;
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
@ -408,39 +479,6 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
return ar;
}
union FeatureKey {
uint64_t uint64_feasign_;
float float_feasign_;
};
struct FeatureItem {
FeatureItem() {}
FeatureItem(FeatureKey sign, uint16_t slot) {
this->sign() = sign;
this->slot() = slot;
}
FeatureKey& sign() { return *(reinterpret_cast<FeatureKey*>(sign_buffer())); }
const FeatureKey& sign() const {
const FeatureKey* ret = reinterpret_cast<FeatureKey*>(sign_buffer());
return *ret;
}
uint16_t& slot() { return slot_; }
const uint16_t& slot() const { return slot_; }
private:
char* sign_buffer() const { return const_cast<char*>(sign_); }
char sign_[sizeof(FeatureKey)];
uint16_t slot_;
};
// sizeof Record is much less than std::vector<MultiSlotType>
struct Record {
std::vector<FeatureItem> uint64_feasigns_;
std::vector<FeatureItem> float_feasigns_;
std::string ins_id_;
std::string content_;
};
struct RecordCandidate {
std::string ins_id_;
std::unordered_multimap<uint16_t, FeatureKey> feas;
@ -557,6 +595,27 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
virtual bool ParseOneInstance(Record* instance);
virtual bool ParseOneInstanceFromPipe(Record* instance);
virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id,
uint32_t* cmatch, uint32_t* rank);
};
class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
public:
PaddleBoxDataFeed() {}
virtual ~PaddleBoxDataFeed() {}
protected:
virtual void Init(const DataFeedDesc& data_feed_desc);
virtual bool Start();
virtual int Next();
virtual void AssignFeedVar(const Scope& scope);
virtual void PutToFeedVec(const std::vector<PvInstance>& pv_vec);
virtual void PutToFeedVec(const std::vector<Record*>& ins_vec);
virtual int GetCurrentPhase();
virtual void GetRankOffset(const std::vector<PvInstance>& pv_vec,
int ins_number);
std::string rank_offset_name_;
int pv_batch_size_;
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)

@ -30,4 +30,6 @@ message DataFeedDesc {
optional MultiSlotDesc multi_slot_desc = 3;
optional string pipe_command = 4;
optional int32 thread_num = 5;
optional string rank_offset = 6;
optional int32 pv_batch_size = 7 [ default = 32 ];
}

@ -64,6 +64,7 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
REGISTER_DATAFEED_CLASS(PaddleBoxDataFeed);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed);
#endif

File diff suppressed because it is too large Load Diff

@ -65,6 +65,9 @@ class Dataset {
// set parse ins id
virtual void SetParseInsId(bool parse_ins_id) = 0;
virtual void SetParseContent(bool parse_content) = 0;
virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
@ -115,10 +118,18 @@ class Dataset {
virtual void DestroyReaders() = 0;
// get memory data size
virtual int64_t GetMemoryDataSize() = 0;
// get memory data size in input_pv_channel_
virtual int64_t GetPvDataSize() = 0;
// get shuffle data size
virtual int64_t GetShuffleDataSize() = 0;
// merge by ins id
virtual void MergeByInsId() = 0;
// merge pv instance
virtual void PreprocessInstance() = 0;
// divide pv instance
virtual void PostprocessInstance() = 0;
// only for untest
virtual void SetCurrentPhase(int current_phase) = 0;
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
@ -161,6 +172,10 @@ class DatasetImpl : public Dataset {
virtual void SetChannelNum(int channel_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetParseLogKey(bool parse_logkey);
virtual void SetEnablePvMerge(bool enable_pv_merge);
virtual void SetMergeBySid(bool is_merge);
virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
@ -192,8 +207,12 @@ class DatasetImpl : public Dataset {
virtual void CreateReaders();
virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
virtual int64_t GetPvDataSize();
virtual int64_t GetShuffleDataSize();
virtual void MergeByInsId() {}
virtual void PreprocessInstance() {}
virtual void PostprocessInstance() {}
virtual void SetCurrentPhase(int current_phase) {}
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
@ -213,6 +232,10 @@ class DatasetImpl : public Dataset {
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
paddle::framework::Channel<T> input_channel_;
paddle::framework::Channel<PvInstance> input_pv_channel_;
std::vector<paddle::framework::Channel<PvInstance>> multi_pv_output_;
std::vector<paddle::framework::Channel<PvInstance>> multi_pv_consume_;
int channel_num_;
std::vector<paddle::framework::Channel<T>> multi_output_channel_;
std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
@ -238,6 +261,10 @@ class DatasetImpl : public Dataset {
bool merge_by_insid_;
bool parse_ins_id_;
bool parse_content_;
bool parse_logkey_;
bool merge_by_sid_;
bool enable_pv_merge_; // True means to merge pv
int current_phase_; // 1 join, 0 update
size_t merge_size_;
bool slots_shuffle_fea_eval_ = false;
bool gen_uni_feasigns_ = false;
@ -252,6 +279,9 @@ class MultiSlotDataset : public DatasetImpl<Record> {
public:
MultiSlotDataset() {}
virtual void MergeByInsId();
virtual void PreprocessInstance();
virtual void PostprocessInstance();
virtual void SetCurrentPhase(int current_phase);
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num, int shard_num);
@ -266,6 +296,9 @@ class MultiSlotDataset : public DatasetImpl<Record> {
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {}
protected:
std::vector<Record> input_records_; // the real data
};
} // end namespace framework

@ -239,6 +239,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("get_memory_data_size", &framework::Dataset::GetMemoryDataSize,
py::call_guard<py::gil_scoped_release>())
.def("get_pv_data_size", &framework::Dataset::GetPvDataSize,
py::call_guard<py::gil_scoped_release>())
.def("get_shuffle_data_size", &framework::Dataset::GetShuffleDataSize,
py::call_guard<py::gil_scoped_release>())
.def("set_queue_num", &framework::Dataset::SetChannelNum,
@ -247,6 +249,19 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_parse_content", &framework::Dataset::SetParseContent,
py::call_guard<py::gil_scoped_release>())
.def("set_parse_logkey", &framework::Dataset::SetParseLogKey,
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_sid", &framework::Dataset::SetMergeBySid,
py::call_guard<py::gil_scoped_release>())
.def("preprocess_instance", &framework::Dataset::PreprocessInstance,
py::call_guard<py::gil_scoped_release>())
.def("postprocess_instance", &framework::Dataset::PostprocessInstance,
py::call_guard<py::gil_scoped_release>())
.def("set_current_phase", &framework::Dataset::SetCurrentPhase,
py::call_guard<py::gil_scoped_release>())
.def("set_enable_pv_merge", &framework::Dataset::SetEnablePvMerge,
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId,
py::call_guard<py::gil_scoped_release>())
.def("merge_by_lineid", &framework::Dataset::MergeByInsId,

@ -92,6 +92,23 @@ class DatasetBase(object):
"""
self.proto_desc.pipe_command = pipe_command
def set_rank_offset(self, rank_offset):
"""
Set rank_offset for merge_pv. It set the message of Pv.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_rank_offset("rank_offset")
Args:
rank_offset(str): rank_offset's name
"""
self.proto_desc.rank_offset = rank_offset
def set_fea_eval(self, record_candidate_size, fea_eval=True):
"""
set fea eval mode for slots shuffle to debug the importance level of
@ -154,6 +171,22 @@ class DatasetBase(object):
"""
self.proto_desc.batch_size = batch_size
def set_pv_batch_size(self, pv_batch_size):
"""
Set pv batch size. It will be effective during enable_pv_merge
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_pv_batch(128)
Args:
pv_batch_size(int): pv batch size
"""
self.proto_desc.pv_batch_size = pv_batch_size
def set_thread(self, thread_num):
"""
Set thread num, it is the num of readers.
@ -308,9 +341,18 @@ class InMemoryDataset(DatasetBase):
self.queue_num = None
self.parse_ins_id = False
self.parse_content = False
self.parse_logkey = False
self.merge_by_sid = True
self.enable_pv_merge = False
self.merge_by_lineid = False
self.fleet_send_sleep_seconds = None
def set_feed_type(self, data_feed_type):
"""
Set data_feed_desc
"""
self.proto_desc.name = data_feed_type
def _prepare_to_run(self):
"""
Set data_feed_desc before load or shuffle,
@ -324,6 +366,9 @@ class InMemoryDataset(DatasetBase):
self.dataset.set_queue_num(self.queue_num)
self.dataset.set_parse_ins_id(self.parse_ins_id)
self.dataset.set_parse_content(self.parse_content)
self.dataset.set_parse_logkey(self.parse_logkey)
self.dataset.set_merge_by_sid(self.merge_by_sid)
self.dataset.set_enable_pv_merge(self.enable_pv_merge)
self.dataset.set_data_feed_desc(self.desc())
self.dataset.create_channel()
self.dataset.create_readers()
@ -390,6 +435,112 @@ class InMemoryDataset(DatasetBase):
"""
self.parse_content = parse_content
def set_parse_logkey(self, parse_logkey):
"""
Set if Dataset need to parse logkey
Args:
parse_content(bool): if parse logkey or not
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_parse_logkey(True)
"""
self.parse_logkey = parse_logkey
def set_merge_by_sid(self, merge_by_sid):
"""
Set if Dataset need to merge sid. If not, one ins means one Pv.
Args:
merge_by_sid(bool): if merge sid or not
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_sid(True)
"""
self.merge_by_sid = merge_by_sid
def set_enable_pv_merge(self, enable_pv_merge):
"""
Set if Dataset need to merge pv.
Args:
enable_pv_merge(bool): if enable_pv_merge or not
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_enable_pv_merge(True)
"""
self.enable_pv_merge = enable_pv_merge
def preprocess_instance(self):
"""
Merge pv instance and convey it from input_channel to input_pv_channel.
It will be effective when enable_pv_merge_ is True.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.preprocess_instance()
"""
self.dataset.preprocess_instance()
def set_current_phase(self, current_phase):
"""
Set current phase in train. It is useful for untest.
current_phase : 1 for join, 0 for update.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.set_current_phase(1)
"""
self.dataset.set_current_phase(current_phase)
def postprocess_instance(self):
"""
Divide pv instance and convey it to input_channel.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
dataset.preprocess_instance()
exe.train_from_dataset(dataset)
dataset.postprocess_instance()
"""
self.dataset.postprocess_instance()
def set_fleet_send_batch_size(self, fleet_send_batch_size=1024):
"""
Set fleet send batch size, default is 1024
@ -594,6 +745,30 @@ class InMemoryDataset(DatasetBase):
"""
self.dataset.release_memory()
def get_pv_data_size(self):
"""
Get memory data size of Pv, user can call this function to know the pv num
of ins in all workers after load into memory.
Note:
This function may cause bad performance, because it has barrier
Returns:
The size of memory pv data.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
filelist = ["a.txt", "b.txt"]
dataset.set_filelist(filelist)
dataset.load_into_memory()
print dataset.get_pv_data_size()
"""
return self.dataset.get_pv_data_size()
def get_memory_data_size(self, fleet=None):
"""
Get memory data size, user can call this function to know the num
@ -808,6 +983,7 @@ class BoxPSDataset(InMemoryDataset):
"""
super(BoxPSDataset, self).__init__()
self.boxps = core.BoxPS(self.dataset)
self.proto_desc.name = "PaddleBoxDataFeed"
def set_date(self, date):
"""
@ -895,3 +1071,6 @@ class BoxPSDataset(InMemoryDataset):
if not self.is_user_set_queue_num:
self.dataset.dynamic_adjust_channel_num(thread_num, True)
self.dataset.dynamic_adjust_readers_num(thread_num)
def _dynamic_adjust_after_train(self):
pass

@ -44,6 +44,7 @@ endif()
if(WIN32)
LIST(REMOVE_ITEM TEST_OPS test_boxps)
LIST(REMOVE_ITEM TEST_OPS test_paddlebox_datafeed)
LIST(REMOVE_ITEM TEST_OPS test_trainer_desc)
LIST(REMOVE_ITEM TEST_OPS test_multiprocess_reader_exception)
LIST(REMOVE_ITEM TEST_OPS test_avoid_twice_initialization)
@ -59,6 +60,7 @@ endif()
if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_pipeline)
LIST(REMOVE_ITEM TEST_OPS test_boxps)
LIST(REMOVE_ITEM TEST_OPS test_paddlebox_datafeed)
endif()
list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185

@ -0,0 +1,147 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.core as core
import os
import unittest
import paddle.fluid.layers as layers
from paddle.fluid.layers.nn import _pull_box_sparse
class TestDataFeed(unittest.TestCase):
""" TestBaseCase(Merge PV) """
def setUp(self):
self.batch_size = 10
self.pv_batch_size = 10
self.enable_pv_merge = True
self.merge_by_sid = True
def set_data_config(self):
self.dataset = fluid.DatasetFactory().create_dataset("BoxPSDataset")
self.dataset.set_feed_type("PaddleBoxDataFeed")
self.dataset.set_parse_logkey(True)
self.dataset.set_thread(1)
self.dataset.set_enable_pv_merge(self.enable_pv_merge)
self.dataset.set_batch_size(self.batch_size)
if self.enable_pv_merge:
self.dataset.set_merge_by_sid(self.merge_by_sid)
self.dataset.set_rank_offset("rank_offset")
self.dataset.set_pv_batch_size(self.pv_batch_size)
def test_pboxdatafeed(self):
self.run_dataset(False)
def test_pboxdatafeed(self):
self.run_dataset(True)
def run_dataset(self, is_cpu):
x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
y = fluid.layers.data(name='y', shape=[1], dtype='int64', lod_level=0)
rank_offset = fluid.layers.data(
name="rank_offset",
shape=[-1, 7],
dtype="int32",
lod_level=0,
append_batch_size=False)
emb_x, emb_y = _pull_box_sparse([x, y], size=2)
emb_xp = _pull_box_sparse(x, size=2)
concat = layers.concat([emb_x, emb_y], axis=1)
fc = layers.fc(input=concat,
name="fc",
size=1,
num_flatten_dims=1,
bias_attr=False)
loss = layers.reduce_mean(fc)
place = fluid.CPUPlace() if is_cpu or not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
with open("test_run_with_dump_a.txt", "w") as f:
data = "1 1702f830eee19501ad7429505f714c1d 1 1 1 9\n"
data += "1 1702f830eee19502ad7429505f714c1d 1 2 1 8\n"
data += "1 1702f830eee19503ad7429505f714c1d 1 3 1 7\n"
data += "1 1702f830eee0de01ad7429505f714c2d 1 4 1 6\n"
data += "1 1702f830eee0df01ad7429505f714c3d 1 5 1 5\n"
data += "1 1702f830eee0df02ad7429505f714c3d 1 6 1 4\n"
f.write(data)
with open("test_run_with_dump_b.txt", "w") as f:
data = "1 1702f830fff22201ad7429505f715c1d 1 1 1 1\n"
data += "1 1702f830fff22202ad7429505f715c1d 1 2 1 2\n"
data += "1 1702f830fff22203ad7429505f715c1d 1 3 1 3\n"
data += "1 1702f830fff22101ad7429505f714ccd 1 4 1 4\n"
data += "1 1702f830fff22102ad7429505f714ccd 1 5 1 5\n"
data += "1 1702f830fff22103ad7429505f714ccd 1 6 1 6\n"
data += "1 1702f830fff22104ad7429505f714ccd 1 6 1 7\n"
f.write(data)
self.set_data_config()
self.dataset.set_use_var([x, y])
self.dataset.set_filelist(
["test_run_with_dump_a.txt", "test_run_with_dump_b.txt"])
optimizer = fluid.optimizer.SGD(learning_rate=0.5)
optimizer = fluid.optimizer.PipelineOptimizer(
optimizer,
cut_list=[],
place_list=[place],
concurrency_list=[1],
queue_size=1,
sync_steps=-1)
optimizer.minimize(loss)
exe.run(fluid.default_startup_program())
self.dataset.set_current_phase(1)
self.dataset.load_into_memory()
self.dataset.preprocess_instance()
self.dataset.begin_pass()
pv_num = self.dataset.get_pv_data_size()
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=self.dataset,
print_period=1)
self.dataset.set_current_phase(0)
self.dataset.postprocess_instance()
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=self.dataset,
print_period=1)
self.dataset.end_pass(True)
os.remove("test_run_with_dump_a.txt")
os.remove("test_run_with_dump_b.txt")
class TestDataFeed2(TestDataFeed):
""" TestBaseCase(Merge PV not merge by sid) """
def setUp(self):
self.batch_size = 10
self.pv_batch_size = 10
self.enable_pv_merge = True
self.merge_by_sid = False
class TestDataFeed3(TestDataFeed):
""" TestBaseCase(Not Merge PV) """
def setUp(self):
self.batch_size = 10
self.pv_batch_size = 10
self.enable_pv_merge = False
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save