AsyncExecutor (#14627)
	
		
	
				
					
				
			* AsyncExecutor: C++ side * Google naming conventions * Rename MultiExecutor to AsyncExecutor * pybind with async_executor * Naming convention * remove some flags and unused code * add refactored file of async_executor and data_feed * clear async executor interface and add data feed factory * split async executor into executor_thread_worker and async_executor, refactor pybind, add datafeed and corresponding proto * Fix async_executor interfaces: 1) Remove all protobufs; 2) Stop after each epoch * refine async_executor_refactor.cc * add some files about datafeed * Revert "add some files about datafeed" This reverts commit 8ee8133ab841196925a2812b76f18d2812a6701d. * Interface rework * add MultiSlotDataFeed * Creating DataFeedDesc from .proto file, then manipulate it (add/del fields etc) from python side * update data_feed for add MultiSlotDataFeed * update datafeed and async_executor to run bow_net demo * fix bug that finish_set_filelist failed in multithread * delete finish_binding_memory_(flag), because it can not be marked under the current interface * Fix bug * update async_executor.py for support set_use_slots * update async_executor.py for support set_use_slots and set set_dense_slots * fix bug that when the number of files is less than the number of threads, it will fetch nan * remove redundant code, and make executor exit when set a illegal queue size * add batch_size check * add MultiSlotDesc * Revert "add MultiSlotDesc" This reverts commit 2e72ebfad364ed6b5dcc75f38ffb2a1fdec83d8e. * add some checkpoint in DataFeedDesc * add CheckFile function in MultiSlotDataFeed * update something error info * fix deaded lock bug * Fix fetch variable * Merge error * fix code style in async_executor * using one lock blocking queue replace two lock blocking queue because of some bugs * update code style * add utest for data_feed * Fix fetch var * update utest for data_feed for multithread * update SetFileList info * fix bug in utest of data_feed * Add comments for python * Add comments for python code * Fix pybind.cc with new pybind11 version * add note for DataFeedDesc's set_use_slots function * Add save_model * update data_feed_test for multi-type * add comment for executor_thread_worker * Remove unused code * update data_feed_test for generate test data file * removed unnecessary interfaces and add comments * c++ style check * update data_feed.cc * AsyncExecutor: C++ side Google naming conventions Rename MultiExecutor to AsyncExecutor pybind with async_executor Naming convention remove some flags and unused code add refactored file of async_executor and data_feed clear async executor interface and add data feed factory split async executor into executor_thread_worker and async_executor, refactor pybind, add datafeed and corresponding proto Fix async_executor interfaces: 1) Remove all protobufs; 2) Stop after each epoch refine async_executor_refactor.cc add some files about datafeed Revert "add some files about datafeed" This reverts commit 8ee8133ab841196925a2812b76f18d2812a6701d. add MultiSlotDataFeed Interface rework Creating DataFeedDesc from .proto file, then manipulate it (add/del fields etc) from python side update datafeed and async_executor to run bow_net demo update async_executor.py for support set_use_slots Fix bug update async_executor.py for support set_use_slots and set set_dense_slots fix bug that when the number of files is less than the number of threads, it will fetch nan remove redundant code, and make executor exit when set a illegal queue size add MultiSlotDesc Revert "add MultiSlotDesc" This reverts commit 2e72ebfad364ed6b5dcc75f38ffb2a1fdec83d8e. add some checkpoint in DataFeedDesc Fix fetch variable fix code style in async_executor Fix fetch var add utest for data_feed Add comments for python update utest for data_feed for multithread fix bug in utest of data_feed Add comments for python code Fix pybind.cc with new pybind11 version add note for DataFeedDesc's set_use_slots function update data_feed_test for multi-type Add save_model update data_feed_test for generate test data file removed unnecessary interfaces and add comments add comment for executor_thread_worker Remove unused code update data_feed.cc c++ style check * commit for code style * commit for code style * commit for code style * commit for code style * Comment away __init__ in async_executor.py * clang-format fix test=develop * use PADDLE_THROW instead of exit(-1); use unique_ptr to manage scope var in data_feed_test.cc * commit for update code style * commit for update code style * Add async_executor demo; Remove some methods test=develop * commit for update code style * commit for update code style * commit for update code style * update API.spec * AsyncExecutor test=develop * AsyncExecutor test=develop * AsyncExecutor test=develop * AsyncExecutor test=develop * Fix API.spec test=develop * Fix API.spec test=develop * Fix windows build error test=develop * FIx windows build error test=develop * FIx windows build error test=develop * FIx windows build error test=develop * Fix Windows Build test=develop * Fix Windows Build test=develop * Fix Windows Build test=develop * Fix code style test=develop * Fix code style test=develop * update datafeed * Fix code style test=develop * update data_feed_test for test Tensor test=develop * Fix code style test=develop * Fix windows build failure test=develop * Fix code style and windows build failure test=develop * Fix PYTHON3.5 build failure test=develop * AsyncExecutor API test=developf7c96f079b
							parent
							
								
									78738d6c86
								
							
						
					
					
						commit
						41e19eb431
					
				@ -0,0 +1,138 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/framework/async_executor.h"
 | 
				
			||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
 | 
				
			||||
#include "google/protobuf/message.h"
 | 
				
			||||
#include "google/protobuf/text_format.h"
 | 
				
			||||
 | 
				
			||||
#include "gflags/gflags.h"
 | 
				
			||||
#include "paddle/fluid/framework/data_feed_factory.h"
 | 
				
			||||
#include "paddle/fluid/framework/executor_thread_worker.h"
 | 
				
			||||
#include "paddle/fluid/framework/feed_fetch_method.h"
 | 
				
			||||
#include "paddle/fluid/framework/feed_fetch_type.h"
 | 
				
			||||
#include "paddle/fluid/framework/lod_rank_table.h"
 | 
				
			||||
#include "paddle/fluid/framework/lod_tensor_array.h"
 | 
				
			||||
#include "paddle/fluid/framework/op_registry.h"
 | 
				
			||||
#include "paddle/fluid/framework/reader.h"
 | 
				
			||||
#include "paddle/fluid/inference/io.h"
 | 
				
			||||
#include "paddle/fluid/platform/place.h"
 | 
				
			||||
#include "paddle/fluid/pybind/pybind.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace framework {
 | 
				
			||||
AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place)
 | 
				
			||||
    : root_scope_(scope), place_(place) {}
 | 
				
			||||
 | 
				
			||||
void AsyncExecutor::CreateThreads(
 | 
				
			||||
    ExecutorThreadWorker* worker, const ProgramDesc& main_program,
 | 
				
			||||
    const std::shared_ptr<DataFeed>& reader,
 | 
				
			||||
    const std::vector<std::string>& fetch_var_names, Scope* root_scope,
 | 
				
			||||
    const int thread_index, const bool debug) {
 | 
				
			||||
  worker->SetThreadId(thread_index);
 | 
				
			||||
  worker->SetDebug(debug);
 | 
				
			||||
  worker->SetRootScope(root_scope);
 | 
				
			||||
  worker->CreateThreadResource(main_program, place_);
 | 
				
			||||
  worker->SetDataFeed(reader);
 | 
				
			||||
  worker->SetFetchVarNames(fetch_var_names);
 | 
				
			||||
  worker->BindingDataFeedMemory();
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers,  // NOLINT
 | 
				
			||||
                    const int thread_num, const DataFeedDesc& data_feed_desc,
 | 
				
			||||
                    const std::vector<std::string>& filelist) {
 | 
				
			||||
  readers.resize(thread_num);
 | 
				
			||||
  for (size_t i = 0; i < readers.size(); ++i) {
 | 
				
			||||
    readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
 | 
				
			||||
    readers[i]->Init(data_feed_desc);  // set batch_size and queue_size here
 | 
				
			||||
  }
 | 
				
			||||
  readers[0]->SetFileList(filelist);
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
 | 
				
			||||
                                const std::string& data_feed_desc_str,
 | 
				
			||||
                                const std::vector<std::string>& filelist,
 | 
				
			||||
                                const int thread_num,
 | 
				
			||||
                                const std::vector<std::string>& fetch_var_names,
 | 
				
			||||
                                const bool debug) {
 | 
				
			||||
  std::vector<std::thread> threads;
 | 
				
			||||
 | 
				
			||||
  auto& block = main_program.Block(0);
 | 
				
			||||
  for (auto var_name : fetch_var_names) {
 | 
				
			||||
    auto var_desc = block.FindVar(var_name);
 | 
				
			||||
    auto shapes = var_desc->GetShape();
 | 
				
			||||
    PADDLE_ENFORCE(shapes[shapes.size() - 1] == 1,
 | 
				
			||||
                   "var %s: Fetched var has wrong shape, "
 | 
				
			||||
                   "only variables with the last dimension size 1 supported",
 | 
				
			||||
                   var_name);
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  DataFeedDesc data_feed_desc;
 | 
				
			||||
  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
 | 
				
			||||
                                                &data_feed_desc);
 | 
				
			||||
 | 
				
			||||
  int actual_thread_num = thread_num;
 | 
				
			||||
  int file_cnt = filelist.size();
 | 
				
			||||
  PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
 | 
				
			||||
 | 
				
			||||
  if (actual_thread_num > file_cnt) {
 | 
				
			||||
    VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
 | 
				
			||||
            << ". Changing thread_num = " << file_cnt;
 | 
				
			||||
    actual_thread_num = file_cnt;
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  /*
 | 
				
			||||
    readerDesc: protobuf description for reader initlization
 | 
				
			||||
    argument: class_name, batch_size, use_slot, queue_size, buffer_size,
 | 
				
			||||
    padding_index
 | 
				
			||||
 | 
				
			||||
    reader:
 | 
				
			||||
    1) each thread has a reader, reader will read input data and
 | 
				
			||||
    put it into input queue
 | 
				
			||||
    2) each reader has a Next() iterface, that can fetch an instance
 | 
				
			||||
    from the input queue
 | 
				
			||||
   */
 | 
				
			||||
  // todo: should be factory method for creating datafeed
 | 
				
			||||
  std::vector<std::shared_ptr<DataFeed>> readers;
 | 
				
			||||
  PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
 | 
				
			||||
 | 
				
			||||
  std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
 | 
				
			||||
  workers.resize(actual_thread_num);
 | 
				
			||||
  for (auto& worker : workers) {
 | 
				
			||||
    worker.reset(new ExecutorThreadWorker);
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  // prepare thread resource here
 | 
				
			||||
  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
 | 
				
			||||
    CreateThreads(workers[thidx].get(), main_program, readers[thidx],
 | 
				
			||||
                  fetch_var_names, root_scope_, thidx, debug);
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  // start executing ops in multiple threads
 | 
				
			||||
  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
 | 
				
			||||
    threads.push_back(
 | 
				
			||||
        std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get()));
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  for (auto& th : threads) {
 | 
				
			||||
    th.join();
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  root_scope_->DropKids();
 | 
				
			||||
 | 
				
			||||
  return;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
}  // einit_modelnd namespace framework
 | 
				
			||||
}  // end namespace paddle
 | 
				
			||||
@ -0,0 +1,58 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
  http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
 | 
				
			||||
#include <map>
 | 
				
			||||
#include <memory>
 | 
				
			||||
#include <mutex>  // NOLINT
 | 
				
			||||
#include <set>
 | 
				
			||||
#include <string>
 | 
				
			||||
#include <thread>  // NOLINT
 | 
				
			||||
#include <typeinfo>
 | 
				
			||||
#include <vector>
 | 
				
			||||
#include "paddle/fluid/framework/data_feed.pb.h"
 | 
				
			||||
#include "paddle/fluid/framework/executor.h"
 | 
				
			||||
#include "paddle/fluid/framework/executor_thread_worker.h"
 | 
				
			||||
#include "paddle/fluid/framework/program_desc.h"
 | 
				
			||||
#include "paddle/fluid/framework/scope.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace framework {
 | 
				
			||||
class AsyncExecutor {
 | 
				
			||||
 public:
 | 
				
			||||
  AsyncExecutor(Scope* scope, const platform::Place& place);
 | 
				
			||||
  virtual ~AsyncExecutor() {}
 | 
				
			||||
  void RunFromFile(const ProgramDesc& main_program,
 | 
				
			||||
                   const std::string& data_feed_desc_str,
 | 
				
			||||
                   const std::vector<std::string>& filelist,
 | 
				
			||||
                   const int thread_num,
 | 
				
			||||
                   const std::vector<std::string>& fetch_names,
 | 
				
			||||
                   const bool debug = false);
 | 
				
			||||
 | 
				
			||||
 private:
 | 
				
			||||
  void CreateThreads(ExecutorThreadWorker* worker,
 | 
				
			||||
                     const ProgramDesc& main_program,
 | 
				
			||||
                     const std::shared_ptr<DataFeed>& reader,
 | 
				
			||||
                     const std::vector<std::string>& fetch_var_names,
 | 
				
			||||
                     Scope* root_scope, const int thread_index,
 | 
				
			||||
                     const bool debug);
 | 
				
			||||
 | 
				
			||||
 public:
 | 
				
			||||
  Scope* root_scope_;
 | 
				
			||||
  platform::Place place_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // namespace framework
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,30 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
syntax = "proto2";
 | 
				
			||||
package paddle.framework;
 | 
				
			||||
 | 
				
			||||
message Slot {
 | 
				
			||||
  required string name = 1;
 | 
				
			||||
  required string type = 2;
 | 
				
			||||
  optional bool is_dense = 3 [ default = false ];
 | 
				
			||||
  optional bool is_used = 4 [ default = false ];
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
message MultiSlotDesc { repeated Slot slots = 1; }
 | 
				
			||||
 | 
				
			||||
message DataFeedDesc {
 | 
				
			||||
  optional string name = 1;
 | 
				
			||||
  optional int32 batch_size = 2 [ default = 32 ];
 | 
				
			||||
  optional MultiSlotDesc multi_slot_desc = 3;
 | 
				
			||||
}
 | 
				
			||||
@ -0,0 +1,64 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
  http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/framework/data_feed_factory.h"
 | 
				
			||||
#include <memory>
 | 
				
			||||
#include <string>
 | 
				
			||||
#include <unordered_map>
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/framework/data_feed.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace framework {
 | 
				
			||||
typedef std::shared_ptr<DataFeed> (*Createdata_feedFunction)();
 | 
				
			||||
typedef std::unordered_map<std::string, Createdata_feedFunction> data_feedMap;
 | 
				
			||||
data_feedMap g_data_feed_map;
 | 
				
			||||
 | 
				
			||||
#define REGISTER_DATAFEED_CLASS(data_feed_class)                      \
 | 
				
			||||
  namespace {                                                         \
 | 
				
			||||
  std::shared_ptr<DataFeed> Creator_##data_feed_class() {             \
 | 
				
			||||
    return std::shared_ptr<DataFeed>(new data_feed_class);            \
 | 
				
			||||
  }                                                                   \
 | 
				
			||||
  class __Registerer_##data_feed_class {                              \
 | 
				
			||||
   public:                                                            \
 | 
				
			||||
    __Registerer_##data_feed_class() {                                \
 | 
				
			||||
      g_data_feed_map[#data_feed_class] = &Creator_##data_feed_class; \
 | 
				
			||||
    }                                                                 \
 | 
				
			||||
  };                                                                  \
 | 
				
			||||
  __Registerer_##data_feed_class g_registerer_##data_feed_class;      \
 | 
				
			||||
  }  // namespace
 | 
				
			||||
 | 
				
			||||
std::string DataFeedFactory::DataFeedTypeList() {
 | 
				
			||||
  std::string data_feed_types;
 | 
				
			||||
  for (auto iter = g_data_feed_map.begin(); iter != g_data_feed_map.end();
 | 
				
			||||
       ++iter) {
 | 
				
			||||
    if (iter != g_data_feed_map.begin()) {
 | 
				
			||||
      data_feed_types += ", ";
 | 
				
			||||
    }
 | 
				
			||||
    data_feed_types += iter->first;
 | 
				
			||||
  }
 | 
				
			||||
  return data_feed_types;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
 | 
				
			||||
    std::string data_feed_class) {
 | 
				
			||||
  if (g_data_feed_map.count(data_feed_class) < 1) {
 | 
				
			||||
    exit(-1);
 | 
				
			||||
  }
 | 
				
			||||
  return g_data_feed_map[data_feed_class]();
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
 | 
				
			||||
}  // namespace framework
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,29 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
  http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
 | 
				
			||||
#include <memory>
 | 
				
			||||
#include <string>
 | 
				
			||||
#include "paddle/fluid/framework/data_feed.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace framework {
 | 
				
			||||
class DataFeedFactory {
 | 
				
			||||
 public:
 | 
				
			||||
  static std::string DataFeedTypeList();
 | 
				
			||||
  static std::shared_ptr<DataFeed> CreateDataFeed(std::string data_feed_class);
 | 
				
			||||
};
 | 
				
			||||
}  // namespace framework
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								@ -0,0 +1,223 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/framework/executor_thread_worker.h"
 | 
				
			||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
 | 
				
			||||
#include "google/protobuf/message.h"
 | 
				
			||||
#include "google/protobuf/text_format.h"
 | 
				
			||||
 | 
				
			||||
#include "gflags/gflags.h"
 | 
				
			||||
#include "paddle/fluid/framework/feed_fetch_method.h"
 | 
				
			||||
#include "paddle/fluid/framework/feed_fetch_type.h"
 | 
				
			||||
#include "paddle/fluid/framework/lod_rank_table.h"
 | 
				
			||||
#include "paddle/fluid/framework/lod_tensor_array.h"
 | 
				
			||||
#include "paddle/fluid/framework/op_registry.h"
 | 
				
			||||
#include "paddle/fluid/framework/reader.h"
 | 
				
			||||
#include "paddle/fluid/framework/variable_helper.h"
 | 
				
			||||
#include "paddle/fluid/inference/io.h"
 | 
				
			||||
#include "paddle/fluid/platform/place.h"
 | 
				
			||||
#include "paddle/fluid/pybind/pybind.h"
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace framework {
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
 | 
				
			||||
  auto& block = program.Block(0);
 | 
				
			||||
  op_names_.clear();
 | 
				
			||||
  for (auto& op_desc : block.AllOps()) {
 | 
				
			||||
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
 | 
				
			||||
    op_names_.push_back(op_desc->Type());
 | 
				
			||||
    OperatorBase* local_op_ptr = local_op.release();
 | 
				
			||||
    ops_.push_back(local_op_ptr);
 | 
				
			||||
    continue;
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::CreateThreadResource(
 | 
				
			||||
    const framework::ProgramDesc& program,
 | 
				
			||||
    const paddle::platform::Place& place) {
 | 
				
			||||
  CreateThreadScope(program);
 | 
				
			||||
  CreateThreadOperators(program);
 | 
				
			||||
  SetMainProgram(program);
 | 
				
			||||
  SetPlace(place);
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
 | 
				
			||||
  auto& block = program.Block(0);
 | 
				
			||||
 | 
				
			||||
  PADDLE_ENFORCE_NOT_NULL(
 | 
				
			||||
      root_scope_, "root_scope should be set before creating thread scope");
 | 
				
			||||
 | 
				
			||||
  thread_scope_ = &root_scope_->NewScope();
 | 
				
			||||
  for (auto& var : block.AllVars()) {
 | 
				
			||||
    if (var->Persistable()) {
 | 
				
			||||
      auto* ptr = root_scope_->Var(var->Name());
 | 
				
			||||
      InitializeVariable(ptr, var->GetType());
 | 
				
			||||
    } else {
 | 
				
			||||
      auto* ptr = thread_scope_->Var(var->Name());
 | 
				
			||||
      InitializeVariable(ptr, var->GetType());
 | 
				
			||||
    }
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::SetDataFeed(
 | 
				
			||||
    const std::shared_ptr<DataFeed>& datafeed) {
 | 
				
			||||
  thread_reader_ = datafeed;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::BindingDataFeedMemory() {
 | 
				
			||||
  const std::vector<std::string>& input_feed =
 | 
				
			||||
      thread_reader_->GetUseSlotAlias();
 | 
				
			||||
  for (auto name : input_feed) {
 | 
				
			||||
    thread_reader_->AddFeedVar(thread_scope_->Var(name), name);
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::SetFetchVarNames(
 | 
				
			||||
    const std::vector<std::string>& fetch_var_names) {
 | 
				
			||||
  fetch_var_names_.clear();
 | 
				
			||||
  fetch_var_names_.insert(fetch_var_names_.end(), fetch_var_names.begin(),
 | 
				
			||||
                          fetch_var_names.end());
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::SetDevice() {
 | 
				
			||||
#if defined _WIN32 || defined __APPLE__
 | 
				
			||||
  return;
 | 
				
			||||
#else
 | 
				
			||||
  static unsigned concurrency_cap = std::thread::hardware_concurrency();
 | 
				
			||||
  int thread_id = this->thread_id_;
 | 
				
			||||
 | 
				
			||||
  if (thread_id < concurrency_cap) {
 | 
				
			||||
    unsigned proc = thread_id;
 | 
				
			||||
 | 
				
			||||
    cpu_set_t mask;
 | 
				
			||||
    CPU_ZERO(&mask);
 | 
				
			||||
    CPU_SET(proc, &mask);
 | 
				
			||||
 | 
				
			||||
    if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
 | 
				
			||||
      VLOG(1) << "WARNING: Failed to set thread affinity for thread "
 | 
				
			||||
              << thread_id;
 | 
				
			||||
    } else {
 | 
				
			||||
      CPU_ZERO(&mask);
 | 
				
			||||
      if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) ||
 | 
				
			||||
          (CPU_ISSET(proc, &mask) == 0)) {
 | 
				
			||||
        VLOG(3) << "WARNING: Failed to set thread affinity for thread "
 | 
				
			||||
                << thread_id;
 | 
				
			||||
      }
 | 
				
			||||
    }
 | 
				
			||||
  } else {
 | 
				
			||||
    VLOG(1) << "WARNING: Failed to set thread affinity for thread "
 | 
				
			||||
            << thread_id;
 | 
				
			||||
  }
 | 
				
			||||
#endif
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
template <typename T>
 | 
				
			||||
void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
 | 
				
			||||
  auto inspect = lod_tensor.data<T>();
 | 
				
			||||
  auto element_num = lod_tensor.numel();
 | 
				
			||||
 | 
				
			||||
  std::ostringstream sstream;
 | 
				
			||||
  sstream << var_name << " (element num " << element_num << "): [";
 | 
				
			||||
  sstream << inspect[0];
 | 
				
			||||
  for (int j = 1; j < element_num; ++j) {
 | 
				
			||||
    sstream << " " << inspect[j];
 | 
				
			||||
  }
 | 
				
			||||
  sstream << "]";
 | 
				
			||||
 | 
				
			||||
  std::cout << sstream.str() << std::endl;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void print_fetch_var(Scope* scope, std::string var_name) {
 | 
				
			||||
  const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
 | 
				
			||||
 | 
				
			||||
  if (std::type_index(tensor.type()) ==
 | 
				
			||||
      std::type_index(typeid(platform::float16))) {
 | 
				
			||||
    print_lod_tensor<platform::float16>(var_name, tensor);
 | 
				
			||||
  } else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) {
 | 
				
			||||
    print_lod_tensor<float>(var_name, tensor);
 | 
				
			||||
  } else if (std::type_index(tensor.type()) ==
 | 
				
			||||
             std::type_index(typeid(double))) {
 | 
				
			||||
    print_lod_tensor<double>(var_name, tensor);
 | 
				
			||||
  } else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) {
 | 
				
			||||
    print_lod_tensor<int>(var_name, tensor);
 | 
				
			||||
  } else if (std::type_index(tensor.type()) ==
 | 
				
			||||
             std::type_index(typeid(int64_t))) {
 | 
				
			||||
    print_lod_tensor<int64_t>(var_name, tensor);
 | 
				
			||||
  } else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
 | 
				
			||||
    print_lod_tensor<bool>(var_name, tensor);
 | 
				
			||||
  } else if (std::type_index(tensor.type()) ==
 | 
				
			||||
             std::type_index(typeid(uint8_t))) {
 | 
				
			||||
    print_lod_tensor<uint8_t>(var_name, tensor);
 | 
				
			||||
  } else if (std::type_index(tensor.type()) ==
 | 
				
			||||
             std::type_index(typeid(int16_t))) {
 | 
				
			||||
    print_lod_tensor<int16_t>(var_name, tensor);
 | 
				
			||||
  } else if (std::type_index(tensor.type()) ==
 | 
				
			||||
             std::type_index(typeid(int8_t))) {
 | 
				
			||||
    print_lod_tensor<int8_t>(var_name, tensor);
 | 
				
			||||
  } else {
 | 
				
			||||
    VLOG(1) << "print_fetch_var: unrecognized data type:"
 | 
				
			||||
            << tensor.type().name();
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  return;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::TrainFiles() {
 | 
				
			||||
  // todo: configurable
 | 
				
			||||
  SetDevice();
 | 
				
			||||
 | 
				
			||||
  int fetch_var_num = fetch_var_names_.size();
 | 
				
			||||
  fetch_values_.clear();
 | 
				
			||||
  fetch_values_.resize(fetch_var_num);
 | 
				
			||||
 | 
				
			||||
  thread_reader_->Start();
 | 
				
			||||
 | 
				
			||||
  int cur_batch;
 | 
				
			||||
  int batch_cnt = 0;
 | 
				
			||||
  while ((cur_batch = thread_reader_->Next()) > 0) {
 | 
				
			||||
    // executor run here
 | 
				
			||||
    for (auto& op : ops_) {
 | 
				
			||||
      op->Run(*thread_scope_, place_);
 | 
				
			||||
    }
 | 
				
			||||
 | 
				
			||||
    ++batch_cnt;
 | 
				
			||||
    thread_scope_->DropKids();
 | 
				
			||||
 | 
				
			||||
    if (debug_ == false || thread_id_ != 0) {
 | 
				
			||||
      continue;
 | 
				
			||||
    }
 | 
				
			||||
 | 
				
			||||
    for (int i = 0; i < fetch_var_num; ++i) {
 | 
				
			||||
      print_fetch_var(thread_scope_, fetch_var_names_[i]);
 | 
				
			||||
    }  // end for (int i = 0...)
 | 
				
			||||
  }    // end while ()
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::SetThreadId(int tid) { thread_id_ = tid; }
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
 | 
				
			||||
  place_ = place;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::SetMainProgram(
 | 
				
			||||
    const ProgramDesc& main_program_desc) {
 | 
				
			||||
  main_program_.reset(new ProgramDesc(main_program_desc));
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
 | 
				
			||||
  root_scope_ = g_scope;
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
}  // einit_modelnd namespace framework
 | 
				
			||||
}  // end namespace paddle
 | 
				
			||||
@ -0,0 +1,88 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
  http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
 | 
				
			||||
#include <map>
 | 
				
			||||
#include <memory>
 | 
				
			||||
#include <mutex>  // NOLINT
 | 
				
			||||
#include <set>
 | 
				
			||||
#include <string>
 | 
				
			||||
#include <thread>  // NOLINT
 | 
				
			||||
#include <vector>
 | 
				
			||||
#include "paddle/fluid/framework/data_feed.h"
 | 
				
			||||
#include "paddle/fluid/framework/executor.h"
 | 
				
			||||
#include "paddle/fluid/framework/program_desc.h"
 | 
				
			||||
#include "paddle/fluid/framework/scope.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace framework {
 | 
				
			||||
void CreateTensor(Variable* var, proto::VarType::Type var_type);
 | 
				
			||||
 | 
				
			||||
class ExecutorThreadWorker {
 | 
				
			||||
 public:
 | 
				
			||||
  ExecutorThreadWorker()
 | 
				
			||||
      : thread_id_(-1), root_scope_(NULL), thread_scope_(NULL), debug_(false) {}
 | 
				
			||||
  ~ExecutorThreadWorker() {}
 | 
				
			||||
 | 
				
			||||
  void CreateThreadResource(const framework::ProgramDesc& program,
 | 
				
			||||
                            const paddle::platform::Place& place);
 | 
				
			||||
  void SetThreadId(int tid);
 | 
				
			||||
  void SetDebug(const bool debug) { debug_ = debug; }
 | 
				
			||||
  void SetRootScope(Scope* g_scope);
 | 
				
			||||
  // set cpu device in this function
 | 
				
			||||
  // cpu binding is used by default
 | 
				
			||||
  void SetDevice();
 | 
				
			||||
  // since we read data into memory that can not be accessed by program
 | 
				
			||||
  // we need to bind memory of data with corresponding variables in program
 | 
				
			||||
  // this function should be called after data feed is set
 | 
				
			||||
  void BindingDataFeedMemory();
 | 
				
			||||
  // set data feed declared in executor
 | 
				
			||||
  void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
 | 
				
			||||
  // A multi-thread training function
 | 
				
			||||
  void TrainFiles();
 | 
				
			||||
  // set fetch variable names from python interface assigned by users
 | 
				
			||||
  void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
 | 
				
			||||
 | 
				
			||||
 private:
 | 
				
			||||
  void CreateThreadScope(const framework::ProgramDesc& program);
 | 
				
			||||
  void CreateThreadOperators(const framework::ProgramDesc& program);
 | 
				
			||||
  void SetMainProgram(const ProgramDesc& main_program_desc);
 | 
				
			||||
  void SetPlace(const paddle::platform::Place& place);
 | 
				
			||||
 | 
				
			||||
 protected:
 | 
				
			||||
  // thread index
 | 
				
			||||
  std::shared_ptr<DataFeed> thread_reader_;  // shared queue, thread buffer
 | 
				
			||||
  int thread_id_;
 | 
				
			||||
  // operator name
 | 
				
			||||
  std::vector<std::string> op_names_;
 | 
				
			||||
  // thread level, local operators for forward and backward
 | 
				
			||||
  std::vector<OperatorBase*> ops_;
 | 
				
			||||
  // main program for training
 | 
				
			||||
  std::unique_ptr<framework::ProgramDesc> main_program_;
 | 
				
			||||
  // execution place
 | 
				
			||||
  platform::Place place_;
 | 
				
			||||
  // root scope for model parameters
 | 
				
			||||
  Scope* root_scope_;
 | 
				
			||||
  // a thread scope, father scope is global score which is shared
 | 
				
			||||
  Scope* thread_scope_;
 | 
				
			||||
 | 
				
			||||
 private:
 | 
				
			||||
  std::vector<std::string> fetch_var_names_;
 | 
				
			||||
  std::vector<std::vector<float>> fetch_values_;
 | 
				
			||||
  bool debug_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // namespace framework
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,60 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/framework/variable_helper.h"
 | 
				
			||||
 | 
				
			||||
#include <vector>
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/framework/feed_fetch_type.h"
 | 
				
			||||
#include "paddle/fluid/framework/lod_rank_table.h"
 | 
				
			||||
#include "paddle/fluid/framework/lod_tensor.h"
 | 
				
			||||
#include "paddle/fluid/framework/lod_tensor_array.h"
 | 
				
			||||
#include "paddle/fluid/framework/reader.h"
 | 
				
			||||
#include "paddle/fluid/framework/scope.h"
 | 
				
			||||
#include "paddle/fluid/framework/selected_rows.h"
 | 
				
			||||
#include "paddle/fluid/platform/place.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace framework {
 | 
				
			||||
void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
 | 
				
			||||
  if (var_type == proto::VarType::LOD_TENSOR) {
 | 
				
			||||
    var->GetMutable<LoDTensor>();
 | 
				
			||||
  } else if (var_type == proto::VarType::SELECTED_ROWS) {
 | 
				
			||||
    var->GetMutable<SelectedRows>();
 | 
				
			||||
  } else if (var_type == proto::VarType::FEED_MINIBATCH) {
 | 
				
			||||
    var->GetMutable<FeedFetchList>();
 | 
				
			||||
  } else if (var_type == proto::VarType::FETCH_LIST) {
 | 
				
			||||
    var->GetMutable<FeedFetchList>();
 | 
				
			||||
  } else if (var_type == proto::VarType::STEP_SCOPES) {
 | 
				
			||||
    var->GetMutable<std::vector<framework::Scope*>>();
 | 
				
			||||
  } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
 | 
				
			||||
    var->GetMutable<LoDRankTable>();
 | 
				
			||||
  } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
 | 
				
			||||
    var->GetMutable<LoDTensorArray>();
 | 
				
			||||
  } else if (var_type == proto::VarType::PLACE_LIST) {
 | 
				
			||||
    var->GetMutable<platform::PlaceList>();
 | 
				
			||||
  } else if (var_type == proto::VarType::READER) {
 | 
				
			||||
    var->GetMutable<ReaderHolder>();
 | 
				
			||||
  } else if (var_type == proto::VarType::RAW) {
 | 
				
			||||
    // GetMutable will be called in operator
 | 
				
			||||
  } else {
 | 
				
			||||
    PADDLE_THROW(
 | 
				
			||||
        "Variable type %d is not in "
 | 
				
			||||
        "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
 | 
				
			||||
        "LOD_RANK_TABLE, PLACE_LIST, READER, RAW]",
 | 
				
			||||
        var_type);
 | 
				
			||||
  }
 | 
				
			||||
}
 | 
				
			||||
}  // namespace framework
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,22 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
#pragma once
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/framework/framework.pb.h"
 | 
				
			||||
#include "paddle/fluid/framework/variable.h"
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace framework {
 | 
				
			||||
void InitializeVariable(Variable *var, proto::VarType::Type var_type);
 | 
				
			||||
}
 | 
				
			||||
}
 | 
				
			||||
@ -0,0 +1,53 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
#include <fcntl.h>
 | 
				
			||||
 | 
				
			||||
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
 | 
				
			||||
#ifdef _POSIX_C_SOURCE
 | 
				
			||||
#undef _POSIX_C_SOURCE
 | 
				
			||||
#endif
 | 
				
			||||
 | 
				
			||||
#ifdef _XOPEN_SOURCE
 | 
				
			||||
#undef _XOPEN_SOURCE
 | 
				
			||||
#endif
 | 
				
			||||
#include <string>
 | 
				
			||||
#include <vector>
 | 
				
			||||
 | 
				
			||||
#include "google/protobuf/io/zero_copy_stream_impl.h"
 | 
				
			||||
#include "google/protobuf/text_format.h"
 | 
				
			||||
#include "paddle/fluid/framework/async_executor.h"
 | 
				
			||||
#include "paddle/fluid/framework/data_feed.h"
 | 
				
			||||
#include "paddle/fluid/framework/data_feed.pb.h"
 | 
				
			||||
#include "paddle/fluid/framework/scope.h"
 | 
				
			||||
#include "paddle/fluid/inference/io.h"
 | 
				
			||||
#include "paddle/fluid/platform/place.h"
 | 
				
			||||
#include "paddle/fluid/platform/variant.h"
 | 
				
			||||
#include "paddle/fluid/pybind/async_executor_py.h"
 | 
				
			||||
 | 
				
			||||
namespace py = pybind11;
 | 
				
			||||
namespace pd = paddle::framework;
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace pybind {
 | 
				
			||||
using set_name_func = void (pd::DataFeedDesc::*)(const std::string&);
 | 
				
			||||
void BindAsyncExecutor(py::module* m) {
 | 
				
			||||
  py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
 | 
				
			||||
      .def(py::init([](framework::Scope* scope, const platform::Place& place) {
 | 
				
			||||
        return std::unique_ptr<framework::AsyncExecutor>(
 | 
				
			||||
            new framework::AsyncExecutor(scope, place));
 | 
				
			||||
      }))
 | 
				
			||||
      .def("run_from_files", &framework::AsyncExecutor::RunFromFile);
 | 
				
			||||
}  // end BindAsyncExecutor
 | 
				
			||||
}  // end namespace pybind
 | 
				
			||||
}  // end namespace paddle
 | 
				
			||||
@ -0,0 +1,28 @@
 | 
				
			||||
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
//
 | 
				
			||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
// you may not use this file except in compliance with the License.
 | 
				
			||||
// You may obtain a copy of the License at
 | 
				
			||||
//
 | 
				
			||||
//     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
//
 | 
				
			||||
// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
// See the License for the specific language governing permissions and
 | 
				
			||||
// limitations under the License.
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
 | 
				
			||||
#include "pybind11/pybind11.h"
 | 
				
			||||
#include "pybind11/stl.h"
 | 
				
			||||
 | 
				
			||||
namespace py = pybind11;
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace pybind {
 | 
				
			||||
 | 
				
			||||
void BindAsyncExecutor(py::module* m);
 | 
				
			||||
 | 
				
			||||
}  // namespace pybind
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,151 @@
 | 
				
			||||
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
 | 
				
			||||
from __future__ import print_function
 | 
				
			||||
 | 
				
			||||
import numpy as np
 | 
				
			||||
import contextlib
 | 
				
			||||
import six
 | 
				
			||||
from .framework import Program, default_main_program, Variable
 | 
				
			||||
from . import core
 | 
				
			||||
from .executor import global_scope, Executor
 | 
				
			||||
from paddle.fluid.proto import data_feed_pb2
 | 
				
			||||
from google.protobuf import text_format
 | 
				
			||||
from . import io
 | 
				
			||||
from .data_feed_desc import DataFeedDesc
 | 
				
			||||
 | 
				
			||||
__all__ = ['AsyncExecutor']
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class AsyncExecutor(object):
 | 
				
			||||
    """
 | 
				
			||||
    An asynchronous Executor in Python. Through exploiting the power of
 | 
				
			||||
    multi-core processor and data queueing, AsyncExecutor makes data reading
 | 
				
			||||
    and cosuming decoupled, each run in multiple threads in parallel.
 | 
				
			||||
 | 
				
			||||
    Instead of reading data in python side, AsyncExecutor accepts a training
 | 
				
			||||
    file list, which will be retrieved in C++, then training inputs will be
 | 
				
			||||
    read, parsed and fed to training network within C++ code.
 | 
				
			||||
 | 
				
			||||
    AsyncExecutor is in active development and the API might change in the near
 | 
				
			||||
    future.
 | 
				
			||||
 | 
				
			||||
    Example:
 | 
				
			||||
        >>> data_feed = fluid.DataFeedDesc('data.proto')
 | 
				
			||||
        >>> startup_program = fluid.default_startup_program()
 | 
				
			||||
        >>> main_program = fluid.default_main_program()
 | 
				
			||||
        >>> filelist = ["train_data/part-%d" % i for i in range(100)]
 | 
				
			||||
        >>> thread_num = len(filelist) / 4
 | 
				
			||||
        >>>
 | 
				
			||||
        >>> place = fluid.CPUPlace()
 | 
				
			||||
        >>> async_executor = fluid.AsyncExecutor(place)
 | 
				
			||||
        >>>
 | 
				
			||||
        >>> async_executor.run_startup_program(startup_program)
 | 
				
			||||
        >>>
 | 
				
			||||
        >>> epoch = 10
 | 
				
			||||
        >>> for i in range(epoch):
 | 
				
			||||
        >>>     async_executor.run(main_program,
 | 
				
			||||
        >>>                        data_feed,
 | 
				
			||||
        >>>                        filelist,
 | 
				
			||||
        >>>                        thread_num,
 | 
				
			||||
        >>>                        [acc],
 | 
				
			||||
        >>>                        debug=False)
 | 
				
			||||
 | 
				
			||||
    Args:
 | 
				
			||||
        place(fluid.CPUPlace|None): indicate the executor run on which device.
 | 
				
			||||
                                   Only CPUPlace supported
 | 
				
			||||
 | 
				
			||||
    Note:
 | 
				
			||||
        For debugging complicated network in parallel-GPUs, you can test it
 | 
				
			||||
        on the executor. They has the exactly same arguments, and expected
 | 
				
			||||
        the same results.
 | 
				
			||||
 | 
				
			||||
    Note: Only running on CPUPlace supported.
 | 
				
			||||
    """
 | 
				
			||||
 | 
				
			||||
    def __init__(self, place=None):
 | 
				
			||||
        if place is None:
 | 
				
			||||
            place = core.CPUPlace()
 | 
				
			||||
        if not isinstance(place, core.CPUPlace):
 | 
				
			||||
            raise ValueError("AsyncExecutor only supports CPU device")
 | 
				
			||||
 | 
				
			||||
        p = core.Place()
 | 
				
			||||
        p.set_place(place)
 | 
				
			||||
 | 
				
			||||
        scope = global_scope()
 | 
				
			||||
        self.executor = core.AsyncExecutor(scope, p)
 | 
				
			||||
 | 
				
			||||
    def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
 | 
				
			||||
        """
 | 
				
			||||
        Run program by this AsyncExecutor. Training dataset will be in filelist.
 | 
				
			||||
        Users can also inspect certain variables by naming them in parameter
 | 
				
			||||
        :code:`fetch`, like in fluid.Executor. Unlike fluid.Executor, however,
 | 
				
			||||
        AsyncExecutor doesn't return fetched variables, instead, it will dump
 | 
				
			||||
        the values of each fetched variable to stdandard output.
 | 
				
			||||
 | 
				
			||||
        Running the dataset will be on multiple threads, within each a thread
 | 
				
			||||
        local scope will be created, then all OPs also created in that scope.
 | 
				
			||||
        Parameters are updated by all the OPs simultaneously.
 | 
				
			||||
 | 
				
			||||
        Args:
 | 
				
			||||
            program(Program): the program that need to run, if not provied,
 | 
				
			||||
                              then default_main_program will be used.
 | 
				
			||||
            data_feed(DataFeedDesc): A DataFeedDesc object
 | 
				
			||||
            filelist(str): a file containing the training dataset file list
 | 
				
			||||
            thread_num(int): number of concurrent training threads. See
 | 
				
			||||
                             :code:`Note` for how to set this properly
 | 
				
			||||
            fetch(str|list): the var name or a list of var names to inspect
 | 
				
			||||
            debug(bool): When set to True, fetch vars will be printed to
 | 
				
			||||
                         standard output after each minibatch
 | 
				
			||||
 | 
				
			||||
        Note:
 | 
				
			||||
            the executor will run all operators in the program but not only
 | 
				
			||||
            the operators dependent by the fetch_list.
 | 
				
			||||
 | 
				
			||||
        Note:
 | 
				
			||||
            Running AsyncExecutor will be on multiple threads, each bound to a
 | 
				
			||||
            CPU core. To achieve best performance, it's suggested to set thread
 | 
				
			||||
            num to be equal or slightly less than that of CPU cores.
 | 
				
			||||
        """
 | 
				
			||||
        if program is None:
 | 
				
			||||
            program = default_main_program()
 | 
				
			||||
        program_desc = program.desc
 | 
				
			||||
 | 
				
			||||
        if data_feed is None:
 | 
				
			||||
            raise ValueError('ValueError: data_feed should be provided')
 | 
				
			||||
 | 
				
			||||
        if filelist is None:
 | 
				
			||||
            raise ValueError('ValueError: filelist should be provided')
 | 
				
			||||
 | 
				
			||||
        if isinstance(filelist, str):
 | 
				
			||||
            filelist = [filelist]
 | 
				
			||||
 | 
				
			||||
        if not isinstance(thread_num, int):
 | 
				
			||||
            raise TypeError('TypeError: thread_num should be a positive number')
 | 
				
			||||
 | 
				
			||||
        if fetch is not None:
 | 
				
			||||
            if isinstance(fetch, Variable):
 | 
				
			||||
                fetch = [fetch]
 | 
				
			||||
            fetch_var_names = [var.name for var in fetch]
 | 
				
			||||
            for fetch_var in fetch:
 | 
				
			||||
                shape = fetch_var.shape
 | 
				
			||||
                if shape[len(shape) - 1] != 1:
 | 
				
			||||
                    raise AssertionError(
 | 
				
			||||
                        "%s: Fetch variable has wrong shape. Only varibles "
 | 
				
			||||
                        "with the last dimension size 1 supported." %
 | 
				
			||||
                        (fetch_var.name))
 | 
				
			||||
 | 
				
			||||
        self.executor.run_from_files(program_desc,
 | 
				
			||||
                                     data_feed.desc(), filelist, thread_num,
 | 
				
			||||
                                     fetch_var_names, debug)
 | 
				
			||||
Some files were not shown because too many files have changed in this diff Show More
					Loading…
					
					
				
		Reference in new issue