AsyncExecutor (#14627)

* AsyncExecutor: C++ side

* Google naming conventions

* Rename MultiExecutor to AsyncExecutor

* pybind with async_executor

* Naming convention

* remove some flags and unused code

* add refactored file of async_executor and data_feed

* clear async executor interface and add data feed factory

* split async executor into executor_thread_worker and async_executor, refactor pybind, add datafeed and corresponding proto

* Fix async_executor interfaces: 1) Remove all protobufs; 2) Stop after each epoch

* refine async_executor_refactor.cc

* add some files about datafeed

* Revert "add some files about datafeed"

This reverts commit 8ee8133ab841196925a2812b76f18d2812a6701d.

* Interface rework

* add MultiSlotDataFeed

* Creating DataFeedDesc from .proto file, then manipulate it (add/del fields etc) from python side

* update data_feed for add MultiSlotDataFeed

* update datafeed and async_executor to run bow_net demo

* fix bug that finish_set_filelist failed in multithread

* delete finish_binding_memory_(flag), because it can not be marked under the current interface

* Fix bug

* update async_executor.py for support set_use_slots

* update async_executor.py for support set_use_slots and set set_dense_slots

* fix bug that when the number of files is less than the number of threads, it will fetch nan

* remove redundant code, and make executor exit when set a illegal queue size

* add batch_size check

* add MultiSlotDesc

* Revert "add MultiSlotDesc"

This reverts commit 2e72ebfad364ed6b5dcc75f38ffb2a1fdec83d8e.

* add some checkpoint in DataFeedDesc

* add CheckFile function in MultiSlotDataFeed

* update something error info

* fix deaded lock bug

* Fix fetch variable

* Merge error

* fix code style in async_executor

* using one lock blocking queue replace two lock blocking queue because of some bugs

* update code style

* add utest for data_feed

* Fix fetch var

* update utest for data_feed for multithread

* update SetFileList info

* fix bug in utest of data_feed

* Add comments for python

* Add comments for python code

* Fix pybind.cc with new pybind11 version

* add note for DataFeedDesc's set_use_slots function

* Add save_model

* update data_feed_test for multi-type

* add comment for executor_thread_worker

* Remove unused code

* update data_feed_test for generate test data file

* removed unnecessary interfaces and add comments

* c++ style check

* update data_feed.cc

* AsyncExecutor: C++ side

Google naming conventions

Rename MultiExecutor to AsyncExecutor

pybind with async_executor

Naming convention

remove some flags and unused code

add refactored file of async_executor and data_feed

clear async executor interface and add data feed factory

split async executor into executor_thread_worker and async_executor, refactor pybind, add datafeed and corresponding proto

Fix async_executor interfaces: 1) Remove all protobufs; 2) Stop after each epoch

refine async_executor_refactor.cc

add some files about datafeed

Revert "add some files about datafeed"

This reverts commit 8ee8133ab841196925a2812b76f18d2812a6701d.

add MultiSlotDataFeed

Interface rework

Creating DataFeedDesc from .proto file, then manipulate it (add/del fields etc) from python side

update datafeed and async_executor to run bow_net demo

update async_executor.py for support set_use_slots

Fix bug

update async_executor.py for support set_use_slots and set set_dense_slots

fix bug that when the number of files is less than the number of threads, it will fetch nan

remove redundant code, and make executor exit when set a illegal queue size

add MultiSlotDesc

Revert "add MultiSlotDesc"

This reverts commit 2e72ebfad364ed6b5dcc75f38ffb2a1fdec83d8e.

add some checkpoint in DataFeedDesc

Fix fetch variable

fix code style in async_executor

Fix fetch var

add utest for data_feed

Add comments for python

update utest for data_feed for multithread

fix bug in utest of data_feed

Add comments for python code

Fix pybind.cc with new pybind11 version

add note for DataFeedDesc's set_use_slots function

update data_feed_test for multi-type

Add save_model

update data_feed_test for generate test data file

removed unnecessary interfaces and add comments

add comment for executor_thread_worker

Remove unused code

update data_feed.cc

c++ style check

* commit for code style

* commit for code style

* commit for code style

* commit for code style

* Comment away __init__ in async_executor.py

* clang-format fix test=develop

* use PADDLE_THROW instead of exit(-1); use unique_ptr to manage scope var in data_feed_test.cc

* commit for update code style

* commit for update code style

* Add async_executor demo; Remove some methods
test=develop

* commit for update code style

* commit for update code style

* commit for update code style

* update API.spec

* AsyncExecutor
test=develop

* AsyncExecutor
test=develop

* AsyncExecutor
test=develop

* AsyncExecutor
test=develop

* Fix API.spec
test=develop

* Fix API.spec
test=develop

* Fix windows build error
test=develop

* FIx windows build error
test=develop

* FIx windows build error
test=develop

* FIx windows build error
test=develop

* Fix Windows Build
test=develop

* Fix Windows Build
test=develop

* Fix Windows Build
test=develop

* Fix code style
test=develop

* Fix code style
test=develop

* update datafeed

* Fix code style
test=develop

* update data_feed_test for test Tensor test=develop

* Fix code style
test=develop

* Fix windows build failure
test=develop

* Fix code style and windows build failure
test=develop

* Fix PYTHON3.5 build failure
test=develop

* AsyncExecutor API
test=develop
f7c96f079b
Wang Guibao 6 years ago committed by GitHub
parent 78738d6c86
commit 41e19eb431
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -32,6 +32,13 @@ paddle.fluid.BuildStrategy.ReduceStrategy.__init__ __init__(self: paddle.fluid.c
paddle.fluid.BuildStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy) -> None
paddle.fluid.create_lod_tensor ArgSpec(args=['data', 'recursive_seq_lens', 'place'], varargs=None, keywords=None, defaults=None)
paddle.fluid.create_random_int_lodtensor ArgSpec(args=['recursive_seq_lens', 'base_shape', 'place', 'low', 'high'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DataFeedDesc.__init__ ArgSpec(args=['self', 'proto_file'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DataFeedDesc.desc ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DataFeedDesc.set_batch_size ArgSpec(args=['self', 'batch_size'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DataFeedDesc.set_dense_slots ArgSpec(args=['self', 'dense_slots_name'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DataFeedDesc.set_use_slots ArgSpec(args=['self', 'use_slots_name'], varargs=None, keywords=None, defaults=None)
paddle.fluid.AsyncExecutor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.AsyncExecutor.run ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'debug'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.io.save_vars ArgSpec(args=['executor', 'dirname', 'main_program', 'vars', 'predicate', 'filename'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.io.save_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.save_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))

@ -34,6 +34,7 @@ add_subdirectory(ir)
add_subdirectory(details)
# ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(async_executor_proto SRCS data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
@ -135,7 +136,7 @@ endif(NOT WIN32)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
py_proto_compile(framework_py_proto SRCS framework.proto)
py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
@ -157,18 +158,19 @@ endif(NOT WIN32)
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
if(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper)
else(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
endif(NOT WIN32)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
@ -176,8 +178,11 @@ endif()
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph build_strategy
fast_threaded_ssa_graph_executor)
fast_threaded_ssa_graph_executor variable_helper)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry

@ -0,0 +1,138 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/async_executor.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
namespace paddle {
namespace framework {
AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place)
: root_scope_(scope), place_(place) {}
void AsyncExecutor::CreateThreads(
ExecutorThreadWorker* worker, const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names, Scope* root_scope,
const int thread_index, const bool debug) {
worker->SetThreadId(thread_index);
worker->SetDebug(debug);
worker->SetRootScope(root_scope);
worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
}
void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
const int thread_num, const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist) {
readers.resize(thread_num);
for (size_t i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
readers[i]->Init(data_feed_desc); // set batch_size and queue_size here
}
readers[0]->SetFileList(filelist);
}
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist,
const int thread_num,
const std::vector<std::string>& fetch_var_names,
const bool debug) {
std::vector<std::thread> threads;
auto& block = main_program.Block(0);
for (auto var_name : fetch_var_names) {
auto var_desc = block.FindVar(var_name);
auto shapes = var_desc->GetShape();
PADDLE_ENFORCE(shapes[shapes.size() - 1] == 1,
"var %s: Fetched var has wrong shape, "
"only variables with the last dimension size 1 supported",
var_name);
}
DataFeedDesc data_feed_desc;
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc);
int actual_thread_num = thread_num;
int file_cnt = filelist.size();
PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
if (actual_thread_num > file_cnt) {
VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
<< ". Changing thread_num = " << file_cnt;
actual_thread_num = file_cnt;
}
/*
readerDesc: protobuf description for reader initlization
argument: class_name, batch_size, use_slot, queue_size, buffer_size,
padding_index
reader:
1) each thread has a reader, reader will read input data and
put it into input queue
2) each reader has a Next() iterface, that can fetch an instance
from the input queue
*/
// todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed>> readers;
PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
workers.resize(actual_thread_num);
for (auto& worker : workers) {
worker.reset(new ExecutorThreadWorker);
}
// prepare thread resource here
for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
CreateThreads(workers[thidx].get(), main_program, readers[thidx],
fetch_var_names, root_scope_, thidx, debug);
}
// start executing ops in multiple threads
for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
threads.push_back(
std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get()));
}
for (auto& th : threads) {
th.join();
}
root_scope_->DropKids();
return;
}
} // einit_modelnd namespace framework
} // end namespace paddle

@ -0,0 +1,58 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <string>
#include <thread> // NOLINT
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
class AsyncExecutor {
public:
AsyncExecutor(Scope* scope, const platform::Place& place);
virtual ~AsyncExecutor() {}
void RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str,
const std::vector<std::string>& filelist,
const int thread_num,
const std::vector<std::string>& fetch_names,
const bool debug = false);
private:
void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names,
Scope* root_scope, const int thread_index,
const bool debug);
public:
Scope* root_scope_;
platform::Place place_;
};
} // namespace framework
} // namespace paddle

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,30 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle.framework;
message Slot {
required string name = 1;
required string type = 2;
optional bool is_dense = 3 [ default = false ];
optional bool is_used = 4 [ default = false ];
}
message MultiSlotDesc { repeated Slot slots = 1; }
message DataFeedDesc {
optional string name = 1;
optional int32 batch_size = 2 [ default = 32 ];
optional MultiSlotDesc multi_slot_desc = 3;
}

@ -0,0 +1,64 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<DataFeed> (*Createdata_feedFunction)();
typedef std::unordered_map<std::string, Createdata_feedFunction> data_feedMap;
data_feedMap g_data_feed_map;
#define REGISTER_DATAFEED_CLASS(data_feed_class) \
namespace { \
std::shared_ptr<DataFeed> Creator_##data_feed_class() { \
return std::shared_ptr<DataFeed>(new data_feed_class); \
} \
class __Registerer_##data_feed_class { \
public: \
__Registerer_##data_feed_class() { \
g_data_feed_map[#data_feed_class] = &Creator_##data_feed_class; \
} \
}; \
__Registerer_##data_feed_class g_registerer_##data_feed_class; \
} // namespace
std::string DataFeedFactory::DataFeedTypeList() {
std::string data_feed_types;
for (auto iter = g_data_feed_map.begin(); iter != g_data_feed_map.end();
++iter) {
if (iter != g_data_feed_map.begin()) {
data_feed_types += ", ";
}
data_feed_types += iter->first;
}
return data_feed_types;
}
std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
std::string data_feed_class) {
if (g_data_feed_map.count(data_feed_class) < 1) {
exit(-1);
}
return g_data_feed_map[data_feed_class]();
}
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
} // namespace framework
} // namespace paddle

@ -0,0 +1,29 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
class DataFeedFactory {
public:
static std::string DataFeedTypeList();
static std::shared_ptr<DataFeed> CreateDataFeed(std::string data_feed_class);
};
} // namespace framework
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -16,7 +16,7 @@
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/details/reference_count_op_handle.h"

@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
@ -114,36 +115,6 @@ void Executor::Close() {
#endif
}
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope*>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, RAW]",
var_type);
}
}
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
int block_id) {
auto& global_block = pdesc.Block(block_id);

@ -26,7 +26,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
extern void InitializeVariable(Variable* var, proto::VarType::Type var_type);
template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount(

@ -0,0 +1,223 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
namespace paddle {
namespace framework {
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr);
continue;
}
}
void ExecutorThreadWorker::CreateThreadResource(
const framework::ProgramDesc& program,
const paddle::platform::Place& place) {
CreateThreadScope(program);
CreateThreadOperators(program);
SetMainProgram(program);
SetPlace(place);
}
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL(
root_scope_, "root_scope should be set before creating thread scope");
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
} else {
auto* ptr = thread_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
}
}
}
void ExecutorThreadWorker::SetDataFeed(
const std::shared_ptr<DataFeed>& datafeed) {
thread_reader_ = datafeed;
}
void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed =
thread_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
thread_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
void ExecutorThreadWorker::SetFetchVarNames(
const std::vector<std::string>& fetch_var_names) {
fetch_var_names_.clear();
fetch_var_names_.insert(fetch_var_names_.end(), fetch_var_names.begin(),
fetch_var_names.end());
}
void ExecutorThreadWorker::SetDevice() {
#if defined _WIN32 || defined __APPLE__
return;
#else
static unsigned concurrency_cap = std::thread::hardware_concurrency();
int thread_id = this->thread_id_;
if (thread_id < concurrency_cap) {
unsigned proc = thread_id;
cpu_set_t mask;
CPU_ZERO(&mask);
CPU_SET(proc, &mask);
if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
VLOG(1) << "WARNING: Failed to set thread affinity for thread "
<< thread_id;
} else {
CPU_ZERO(&mask);
if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) ||
(CPU_ISSET(proc, &mask) == 0)) {
VLOG(3) << "WARNING: Failed to set thread affinity for thread "
<< thread_id;
}
}
} else {
VLOG(1) << "WARNING: Failed to set thread affinity for thread "
<< thread_id;
}
#endif
}
template <typename T>
void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
auto inspect = lod_tensor.data<T>();
auto element_num = lod_tensor.numel();
std::ostringstream sstream;
sstream << var_name << " (element num " << element_num << "): [";
sstream << inspect[0];
for (int j = 1; j < element_num; ++j) {
sstream << " " << inspect[j];
}
sstream << "]";
std::cout << sstream.str() << std::endl;
}
void print_fetch_var(Scope* scope, std::string var_name) {
const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
if (std::type_index(tensor.type()) ==
std::type_index(typeid(platform::float16))) {
print_lod_tensor<platform::float16>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) {
print_lod_tensor<float>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(double))) {
print_lod_tensor<double>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) {
print_lod_tensor<int>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int64_t))) {
print_lod_tensor<int64_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
print_lod_tensor<bool>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(uint8_t))) {
print_lod_tensor<uint8_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int16_t))) {
print_lod_tensor<int16_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int8_t))) {
print_lod_tensor<int8_t>(var_name, tensor);
} else {
VLOG(1) << "print_fetch_var: unrecognized data type:"
<< tensor.type().name();
}
return;
}
void ExecutorThreadWorker::TrainFiles() {
// todo: configurable
SetDevice();
int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear();
fetch_values_.resize(fetch_var_num);
thread_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) {
// executor run here
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
}
++batch_cnt;
thread_scope_->DropKids();
if (debug_ == false || thread_id_ != 0) {
continue;
}
for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]);
} // end for (int i = 0...)
} // end while ()
}
void ExecutorThreadWorker::SetThreadId(int tid) { thread_id_ = tid; }
void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
place_ = place;
}
void ExecutorThreadWorker::SetMainProgram(
const ProgramDesc& main_program_desc) {
main_program_.reset(new ProgramDesc(main_program_desc));
}
void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope;
}
} // einit_modelnd namespace framework
} // end namespace paddle

@ -0,0 +1,88 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type);
class ExecutorThreadWorker {
public:
ExecutorThreadWorker()
: thread_id_(-1), root_scope_(NULL), thread_scope_(NULL), debug_(false) {}
~ExecutorThreadWorker() {}
void CreateThreadResource(const framework::ProgramDesc& program,
const paddle::platform::Place& place);
void SetThreadId(int tid);
void SetDebug(const bool debug) { debug_ = debug; }
void SetRootScope(Scope* g_scope);
// set cpu device in this function
// cpu binding is used by default
void SetDevice();
// since we read data into memory that can not be accessed by program
// we need to bind memory of data with corresponding variables in program
// this function should be called after data feed is set
void BindingDataFeedMemory();
// set data feed declared in executor
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
// A multi-thread training function
void TrainFiles();
// set fetch variable names from python interface assigned by users
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
private:
void CreateThreadScope(const framework::ProgramDesc& program);
void CreateThreadOperators(const framework::ProgramDesc& program);
void SetMainProgram(const ProgramDesc& main_program_desc);
void SetPlace(const paddle::platform::Place& place);
protected:
// thread index
std::shared_ptr<DataFeed> thread_reader_; // shared queue, thread buffer
int thread_id_;
// operator name
std::vector<std::string> op_names_;
// thread level, local operators for forward and backward
std::vector<OperatorBase*> ops_;
// main program for training
std::unique_ptr<framework::ProgramDesc> main_program_;
// execution place
platform::Place place_;
// root scope for model parameters
Scope* root_scope_;
// a thread scope, father scope is global score which is shared
Scope* thread_scope_;
private:
std::vector<std::string> fetch_var_names_;
std::vector<std::vector<float>> fetch_values_;
bool debug_;
};
} // namespace framework
} // namespace paddle

@ -21,42 +21,11 @@
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
// These code can be shared with Executor.
static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope *>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
var_type);
}
}
void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
int block_id, bool with_feed_fetch_ops) {
if (!scope) {

@ -0,0 +1,60 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/variable_helper.h"
#include <vector>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope*>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, RAW]",
var_type);
}
}
} // namespace framework
} // namespace paddle

@ -0,0 +1,22 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace framework {
void InitializeVariable(Variable *var, proto::VarType::Type var_type);
}
}

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include <iostream>
#include <string>
#include <vector>
@ -20,7 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h"

@ -1,6 +1,6 @@
set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder parallel_executor profiler)
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc)
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler)
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc)
if(WITH_PYTHON)
if(WITH_AMD_GPU)
hip_library(paddle_pybind SHARED

@ -0,0 +1,53 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <string>
#include <vector>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/async_executor_py.h"
namespace py = pybind11;
namespace pd = paddle::framework;
namespace paddle {
namespace pybind {
using set_name_func = void (pd::DataFeedDesc::*)(const std::string&);
void BindAsyncExecutor(py::module* m) {
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init([](framework::Scope* scope, const platform::Place& place) {
return std::unique_ptr<framework::AsyncExecutor>(
new framework::AsyncExecutor(scope, place));
}))
.def("run_from_files", &framework::AsyncExecutor::RunFromFile);
} // end BindAsyncExecutor
} // end namespace pybind
} // end namespace paddle

@ -0,0 +1,28 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindAsyncExecutor(py::module* m);
} // namespace pybind
} // namespace paddle

@ -42,6 +42,7 @@ limitations under the License. */
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/async_executor_py.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/protobuf.h"
@ -932,6 +933,7 @@ All parameter, weight, gradient are variables in Paddle.
});
BindRecordIOWriter(&m);
BindAsyncExecutor(&m);
}
} // namespace pybind
} // namespace paddle

@ -20,6 +20,13 @@ from .framework import *
# import all class inside executor into fluid module
from . import executor
from .executor import *
from . import data_feed_desc
from .data_feed_desc import *
from . import async_executor
from .async_executor import *
from . import trainer
from . import inferencer
@ -54,7 +61,8 @@ Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \
trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + [
parallel_executor.__all__ + lod_tensor.__all__ + \
data_feed_desc.__all__ + async_executor.__all__ + [
'io',
'initializer',
'layers',

@ -0,0 +1,151 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import contextlib
import six
from .framework import Program, default_main_program, Variable
from . import core
from .executor import global_scope, Executor
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import io
from .data_feed_desc import DataFeedDesc
__all__ = ['AsyncExecutor']
class AsyncExecutor(object):
"""
An asynchronous Executor in Python. Through exploiting the power of
multi-core processor and data queueing, AsyncExecutor makes data reading
and cosuming decoupled, each run in multiple threads in parallel.
Instead of reading data in python side, AsyncExecutor accepts a training
file list, which will be retrieved in C++, then training inputs will be
read, parsed and fed to training network within C++ code.
AsyncExecutor is in active development and the API might change in the near
future.
Example:
>>> data_feed = fluid.DataFeedDesc('data.proto')
>>> startup_program = fluid.default_startup_program()
>>> main_program = fluid.default_main_program()
>>> filelist = ["train_data/part-%d" % i for i in range(100)]
>>> thread_num = len(filelist) / 4
>>>
>>> place = fluid.CPUPlace()
>>> async_executor = fluid.AsyncExecutor(place)
>>>
>>> async_executor.run_startup_program(startup_program)
>>>
>>> epoch = 10
>>> for i in range(epoch):
>>> async_executor.run(main_program,
>>> data_feed,
>>> filelist,
>>> thread_num,
>>> [acc],
>>> debug=False)
Args:
place(fluid.CPUPlace|None): indicate the executor run on which device.
Only CPUPlace supported
Note:
For debugging complicated network in parallel-GPUs, you can test it
on the executor. They has the exactly same arguments, and expected
the same results.
Note: Only running on CPUPlace supported.
"""
def __init__(self, place=None):
if place is None:
place = core.CPUPlace()
if not isinstance(place, core.CPUPlace):
raise ValueError("AsyncExecutor only supports CPU device")
p = core.Place()
p.set_place(place)
scope = global_scope()
self.executor = core.AsyncExecutor(scope, p)
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
"""
Run program by this AsyncExecutor. Training dataset will be in filelist.
Users can also inspect certain variables by naming them in parameter
:code:`fetch`, like in fluid.Executor. Unlike fluid.Executor, however,
AsyncExecutor doesn't return fetched variables, instead, it will dump
the values of each fetched variable to stdandard output.
Running the dataset will be on multiple threads, within each a thread
local scope will be created, then all OPs also created in that scope.
Parameters are updated by all the OPs simultaneously.
Args:
program(Program): the program that need to run, if not provied,
then default_main_program will be used.
data_feed(DataFeedDesc): A DataFeedDesc object
filelist(str): a file containing the training dataset file list
thread_num(int): number of concurrent training threads. See
:code:`Note` for how to set this properly
fetch(str|list): the var name or a list of var names to inspect
debug(bool): When set to True, fetch vars will be printed to
standard output after each minibatch
Note:
the executor will run all operators in the program but not only
the operators dependent by the fetch_list.
Note:
Running AsyncExecutor will be on multiple threads, each bound to a
CPU core. To achieve best performance, it's suggested to set thread
num to be equal or slightly less than that of CPU cores.
"""
if program is None:
program = default_main_program()
program_desc = program.desc
if data_feed is None:
raise ValueError('ValueError: data_feed should be provided')
if filelist is None:
raise ValueError('ValueError: filelist should be provided')
if isinstance(filelist, str):
filelist = [filelist]
if not isinstance(thread_num, int):
raise TypeError('TypeError: thread_num should be a positive number')
if fetch is not None:
if isinstance(fetch, Variable):
fetch = [fetch]
fetch_var_names = [var.name for var in fetch]
for fetch_var in fetch:
shape = fetch_var.shape
if shape[len(shape) - 1] != 1:
raise AssertionError(
"%s: Fetch variable has wrong shape. Only varibles "
"with the last dimension size 1 supported." %
(fetch_var.name))
self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num,
fetch_var_names, debug)

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save