Merge branch 'develop' of https://github.com/paddlepaddle/paddle into add_prelu_gpu
test=develop

commit e7abe6b654

paddle/fluid/framework/async_executor.cc
@@ -0,0 +1,138 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/async_executor.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"

namespace paddle {
namespace framework {
AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place)
    : root_scope_(scope), place_(place) {}

void AsyncExecutor::CreateThreads(
    ExecutorThreadWorker* worker, const ProgramDesc& main_program,
    const std::shared_ptr<DataFeed>& reader,
    const std::vector<std::string>& fetch_var_names, Scope* root_scope,
    const int thread_index, const bool debug) {
  worker->SetThreadId(thread_index);
  worker->SetDebug(debug);
  worker->SetRootScope(root_scope);
  worker->CreateThreadResource(main_program, place_);
  worker->SetDataFeed(reader);
  worker->SetFetchVarNames(fetch_var_names);
  worker->BindingDataFeedMemory();
}

void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers,  // NOLINT
                    const int thread_num, const DataFeedDesc& data_feed_desc,
                    const std::vector<std::string>& filelist) {
  readers.resize(thread_num);
  for (size_t i = 0; i < readers.size(); ++i) {
    readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
    readers[i]->Init(data_feed_desc);  // set batch_size and queue_size here
  }
  readers[0]->SetFileList(filelist);
}

void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                const std::string& data_feed_desc_str,
                                const std::vector<std::string>& filelist,
                                const int thread_num,
                                const std::vector<std::string>& fetch_var_names,
                                const bool debug) {
  std::vector<std::thread> threads;

  auto& block = main_program.Block(0);
  for (const auto& var_name : fetch_var_names) {
    auto var_desc = block.FindVar(var_name);
    auto shapes = var_desc->GetShape();
    PADDLE_ENFORCE(shapes[shapes.size() - 1] == 1,
                   "var %s: Fetched var has wrong shape, "
                   "only variables whose last dimension is 1 are supported",
                   var_name);
  }

  DataFeedDesc data_feed_desc;
  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
                                                &data_feed_desc);

  int actual_thread_num = thread_num;
  int file_cnt = filelist.size();
  PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");

  if (actual_thread_num > file_cnt) {
    VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
            << ". Changing thread_num = " << file_cnt;
    actual_thread_num = file_cnt;
  }

  /*
    readerDesc: protobuf description for reader initialization
    argument: class_name, batch_size, use_slot, queue_size, buffer_size,
    padding_index

    reader:
    1) each thread has a reader; the reader reads input data and
       puts it into the input queue
    2) each reader has a Next() interface that can fetch an instance
       from the input queue
  */
  // TODO: should be a factory method for creating the DataFeed
  std::vector<std::shared_ptr<DataFeed>> readers;
  PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);

  std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
  workers.resize(actual_thread_num);
  for (auto& worker : workers) {
    worker.reset(new ExecutorThreadWorker);
  }

  // prepare thread resource here
  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
    CreateThreads(workers[thidx].get(), main_program, readers[thidx],
                  fetch_var_names, root_scope_, thidx, debug);
  }

  // start executing ops in multiple threads
  for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
    threads.push_back(
        std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get()));
  }

  for (auto& th : threads) {
    th.join();
  }

  root_scope_->DropKids();

  return;
}

}  // end namespace framework
}  // end namespace paddle
paddle/fluid/framework/async_executor.h
@@ -0,0 +1,58 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <map>
#include <memory>
#include <mutex>   // NOLINT
#include <set>
#include <string>
#include <thread>  // NOLINT
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
class AsyncExecutor {
 public:
  AsyncExecutor(Scope* scope, const platform::Place& place);
  virtual ~AsyncExecutor() {}
  void RunFromFile(const ProgramDesc& main_program,
                   const std::string& data_feed_desc_str,
                   const std::vector<std::string>& filelist,
                   const int thread_num,
                   const std::vector<std::string>& fetch_names,
                   const bool debug = false);

 private:
  void CreateThreads(ExecutorThreadWorker* worker,
                     const ProgramDesc& main_program,
                     const std::shared_ptr<DataFeed>& reader,
                     const std::vector<std::string>& fetch_var_names,
                     Scope* root_scope, const int thread_index,
                     const bool debug);

 public:
  Scope* root_scope_;
  platform::Place place_;
};

}  // namespace framework
}  // namespace paddle
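For orientation, a minimal usage sketch of the API declared above, assuming a valid ProgramDesc and a text-format DataFeedDesc string; the helper and file names are illustrative, not part of this diff:

#include <string>
#include <vector>
#include "paddle/fluid/framework/async_executor.h"

void RunExample(const paddle::framework::ProgramDesc& program,
                const std::string& data_feed_desc_str) {
  paddle::framework::Scope scope;
  paddle::platform::CPUPlace place;
  paddle::framework::AsyncExecutor exe(&scope, place);
  // One reader and one worker thread per file; RunFromFile caps thread_num
  // at the number of files in the list.
  std::vector<std::string> filelist = {"train_part-000", "train_part-001"};
  std::vector<std::string> fetch_names;  // fetched vars need last dim == 1
  exe.RunFromFile(program, data_feed_desc_str, filelist,
                  /*thread_num=*/2, fetch_names, /*debug=*/false);
}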
File diff suppressed because it is too large
File diff suppressed because it is too large
paddle/fluid/framework/data_feed.proto
@@ -0,0 +1,30 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle.framework;

message Slot {
  required string name = 1;
  required string type = 2;
  optional bool is_dense = 3 [ default = false ];
  optional bool is_used = 4 [ default = false ];
}

message MultiSlotDesc { repeated Slot slots = 1; }

message DataFeedDesc {
  optional string name = 1;
  optional int32 batch_size = 2 [ default = 32 ];
  optional MultiSlotDesc multi_slot_desc = 3;
}
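The data_feed_desc_str argument that AsyncExecutor::RunFromFile parses is this DataFeedDesc message in protobuf text format. A hedged example follows; the slot names and types are assumptions for illustration:

// Illustrative only: parsed via google::protobuf::TextFormat::ParseFromString.
const char* kExampleDataFeedDesc = R"(
  name: "MultiSlotDataFeed"
  batch_size: 32
  multi_slot_desc {
    slots { name: "words" type: "uint64" is_dense: false is_used: true }
    slots { name: "label" type: "uint64" is_dense: false is_used: true }
  }
)";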
paddle/fluid/framework/data_feed_factory.cc
@@ -0,0 +1,64 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_feed_factory.h"
#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {
typedef std::shared_ptr<DataFeed> (*Createdata_feedFunction)();
typedef std::unordered_map<std::string, Createdata_feedFunction> data_feedMap;
data_feedMap g_data_feed_map;

#define REGISTER_DATAFEED_CLASS(data_feed_class)                      \
  namespace {                                                         \
  std::shared_ptr<DataFeed> Creator_##data_feed_class() {             \
    return std::shared_ptr<DataFeed>(new data_feed_class);            \
  }                                                                   \
  class __Registerer_##data_feed_class {                              \
   public:                                                            \
    __Registerer_##data_feed_class() {                                \
      g_data_feed_map[#data_feed_class] = &Creator_##data_feed_class; \
    }                                                                 \
  };                                                                  \
  __Registerer_##data_feed_class g_registerer_##data_feed_class;      \
  }  // namespace

std::string DataFeedFactory::DataFeedTypeList() {
  std::string data_feed_types;
  for (auto iter = g_data_feed_map.begin(); iter != g_data_feed_map.end();
       ++iter) {
    if (iter != g_data_feed_map.begin()) {
      data_feed_types += ", ";
    }
    data_feed_types += iter->first;
  }
  return data_feed_types;
}

std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
    std::string data_feed_class) {
  if (g_data_feed_map.count(data_feed_class) < 1) {
    exit(-1);
  }
  return g_data_feed_map[data_feed_class]();
}

REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
}  // namespace framework
}  // namespace paddle
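For orientation, a hedged sketch of how the registration macro above would be used for a new feed type; MyDataFeed is a hypothetical name, not part of this diff:

// class MyDataFeed : public DataFeed { /* ... */ };  // hypothetical subclass
// REGISTER_DATAFEED_CLASS(MyDataFeed);  // the static registerer's constructor
//                                       // adds a creator to g_data_feed_map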
paddle/fluid/framework/data_feed_factory.h
@@ -0,0 +1,29 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {
class DataFeedFactory {
 public:
  static std::string DataFeedTypeList();
  static std::shared_ptr<DataFeed> CreateDataFeed(std::string data_feed_class);
};
}  // namespace framework
}  // namespace paddle
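A minimal sketch (assuming the factory is linked into the binary) of creating a feed by class name, as PrepareReaders in async_executor.cc does; the helper name is illustrative:

#include <memory>
#include "paddle/fluid/framework/data_feed_factory.h"

std::shared_ptr<paddle::framework::DataFeed> MakeExampleFeed() {
  // "MultiSlotDataFeed" is the one class registered in this diff;
  // CreateDataFeed exits the process if the name is not registered.
  return paddle::framework::DataFeedFactory::CreateDataFeed(
      "MultiSlotDataFeed");
}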
File diff suppressed because it is too large
paddle/fluid/framework/executor_thread_worker.cc
@@ -0,0 +1,223 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/executor_thread_worker.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"

#include "gflags/gflags.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
namespace paddle {
namespace framework {

void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
  auto& block = program.Block(0);
  op_names_.clear();
  for (auto& op_desc : block.AllOps()) {
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
    OperatorBase* local_op_ptr = local_op.release();
    ops_.push_back(local_op_ptr);
  }
}

void ExecutorThreadWorker::CreateThreadResource(
    const framework::ProgramDesc& program,
    const paddle::platform::Place& place) {
  CreateThreadScope(program);
  CreateThreadOperators(program);
  SetMainProgram(program);
  SetPlace(place);
}

void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
  auto& block = program.Block(0);

  PADDLE_ENFORCE_NOT_NULL(
      root_scope_, "root_scope should be set before creating thread scope");

  thread_scope_ = &root_scope_->NewScope();
  for (auto& var : block.AllVars()) {
    if (var->Persistable()) {
      auto* ptr = root_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
    } else {
      auto* ptr = thread_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
    }
  }
}

void ExecutorThreadWorker::SetDataFeed(
    const std::shared_ptr<DataFeed>& datafeed) {
  thread_reader_ = datafeed;
}

void ExecutorThreadWorker::BindingDataFeedMemory() {
  const std::vector<std::string>& input_feed =
      thread_reader_->GetUseSlotAlias();
  for (const auto& name : input_feed) {
    thread_reader_->AddFeedVar(thread_scope_->Var(name), name);
  }
}

void ExecutorThreadWorker::SetFetchVarNames(
    const std::vector<std::string>& fetch_var_names) {
  fetch_var_names_.clear();
  fetch_var_names_.insert(fetch_var_names_.end(), fetch_var_names.begin(),
                          fetch_var_names.end());
}

void ExecutorThreadWorker::SetDevice() {
#if defined _WIN32 || defined __APPLE__
  return;
#else
  static unsigned concurrency_cap = std::thread::hardware_concurrency();
  int thread_id = this->thread_id_;

  if (static_cast<unsigned>(thread_id) < concurrency_cap) {
    unsigned proc = thread_id;

    cpu_set_t mask;
    CPU_ZERO(&mask);
    CPU_SET(proc, &mask);

    if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
      VLOG(1) << "WARNING: Failed to set thread affinity for thread "
              << thread_id;
    } else {
      CPU_ZERO(&mask);
      if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) ||
          (CPU_ISSET(proc, &mask) == 0)) {
        VLOG(3) << "WARNING: Failed to set thread affinity for thread "
                << thread_id;
      }
    }
  } else {
    VLOG(1) << "WARNING: Failed to set thread affinity for thread "
            << thread_id;
  }
#endif
}

template <typename T>
void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
  auto inspect = lod_tensor.data<T>();
  auto element_num = lod_tensor.numel();

  std::ostringstream sstream;
  sstream << var_name << " (element num " << element_num << "): [";
  if (element_num > 0) {
    sstream << inspect[0];
    for (int64_t j = 1; j < element_num; ++j) {
      sstream << " " << inspect[j];
    }
  }
  sstream << "]";

  std::cout << sstream.str() << std::endl;
}

void print_fetch_var(Scope* scope, std::string var_name) {
  const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>();

  if (std::type_index(tensor.type()) ==
      std::type_index(typeid(platform::float16))) {
    print_lod_tensor<platform::float16>(var_name, tensor);
  } else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) {
    print_lod_tensor<float>(var_name, tensor);
  } else if (std::type_index(tensor.type()) ==
             std::type_index(typeid(double))) {
    print_lod_tensor<double>(var_name, tensor);
  } else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) {
    print_lod_tensor<int>(var_name, tensor);
  } else if (std::type_index(tensor.type()) ==
             std::type_index(typeid(int64_t))) {
    print_lod_tensor<int64_t>(var_name, tensor);
  } else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
    print_lod_tensor<bool>(var_name, tensor);
  } else if (std::type_index(tensor.type()) ==
             std::type_index(typeid(uint8_t))) {
    print_lod_tensor<uint8_t>(var_name, tensor);
  } else if (std::type_index(tensor.type()) ==
             std::type_index(typeid(int16_t))) {
    print_lod_tensor<int16_t>(var_name, tensor);
  } else if (std::type_index(tensor.type()) ==
             std::type_index(typeid(int8_t))) {
    print_lod_tensor<int8_t>(var_name, tensor);
  } else {
    VLOG(1) << "print_fetch_var: unrecognized data type: "
            << tensor.type().name();
  }

  return;
}

void ExecutorThreadWorker::TrainFiles() {
  // TODO: make device placement configurable
  SetDevice();

  int fetch_var_num = fetch_var_names_.size();
  fetch_values_.clear();
  fetch_values_.resize(fetch_var_num);

  thread_reader_->Start();

  int cur_batch;
  int batch_cnt = 0;
  while ((cur_batch = thread_reader_->Next()) > 0) {
    // executor run here
    for (auto& op : ops_) {
      op->Run(*thread_scope_, place_);
    }

    ++batch_cnt;
    thread_scope_->DropKids();

    if (debug_ == false || thread_id_ != 0) {
      continue;
    }

    for (int i = 0; i < fetch_var_num; ++i) {
      print_fetch_var(thread_scope_, fetch_var_names_[i]);
    }  // end for (int i = 0...)
  }    // end while ()
}

void ExecutorThreadWorker::SetThreadId(int tid) { thread_id_ = tid; }

void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
  place_ = place;
}

void ExecutorThreadWorker::SetMainProgram(
    const ProgramDesc& main_program_desc) {
  main_program_.reset(new ProgramDesc(main_program_desc));
}

void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
  root_scope_ = g_scope;
}

}  // end namespace framework
}  // end namespace paddle
paddle/fluid/framework/executor_thread_worker.h
@@ -0,0 +1,88 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <map>
#include <memory>
#include <mutex>   // NOLINT
#include <set>
#include <string>
#include <thread>  // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type);

class ExecutorThreadWorker {
 public:
  ExecutorThreadWorker()
      : thread_id_(-1), root_scope_(NULL), thread_scope_(NULL), debug_(false) {}
  ~ExecutorThreadWorker() {}

  void CreateThreadResource(const framework::ProgramDesc& program,
                            const paddle::platform::Place& place);
  void SetThreadId(int tid);
  void SetDebug(const bool debug) { debug_ = debug; }
  void SetRootScope(Scope* g_scope);
  // set the CPU device in this function;
  // CPU binding is used by default
  void SetDevice();
  // since data is read into memory that the program cannot access directly,
  // the memory of the data must be bound to the corresponding variables
  // in the program; this function should be called after the data feed is set
  void BindingDataFeedMemory();
  // set the data feed declared in the executor
  void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
  // the per-thread training loop; one such loop runs on each worker thread
  void TrainFiles();
  // set fetch variable names assigned by users from the Python interface
  void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);

 private:
  void CreateThreadScope(const framework::ProgramDesc& program);
  void CreateThreadOperators(const framework::ProgramDesc& program);
  void SetMainProgram(const ProgramDesc& main_program_desc);
  void SetPlace(const paddle::platform::Place& place);

 protected:
  // thread index
  std::shared_ptr<DataFeed> thread_reader_;  // shared queue, thread buffer
  int thread_id_;
  // operator names
  std::vector<std::string> op_names_;
  // thread-level, local operators for forward and backward
  std::vector<OperatorBase*> ops_;
  // main program for training
  std::unique_ptr<framework::ProgramDesc> main_program_;
  // execution place
  platform::Place place_;
  // root scope for model parameters
  Scope* root_scope_;
  // a thread scope; its parent is the global scope, which is shared
  Scope* thread_scope_;

 private:
  std::vector<std::string> fetch_var_names_;
  std::vector<std::vector<float>> fetch_values_;
  bool debug_;
};

}  // namespace framework
}  // namespace paddle
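A minimal sketch of the setup order a worker needs before training, mirroring AsyncExecutor::CreateThreads in async_executor.cc; the helper name is illustrative:

#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor_thread_worker.h"

void SetUpWorker(paddle::framework::ExecutorThreadWorker* worker,
                 const paddle::framework::ProgramDesc& program,
                 const std::shared_ptr<paddle::framework::DataFeed>& reader,
                 const std::vector<std::string>& fetch_var_names,
                 paddle::framework::Scope* root_scope,
                 const paddle::platform::Place& place) {
  worker->SetThreadId(0);                        // thread index
  worker->SetDebug(false);
  worker->SetRootScope(root_scope);              // must precede resource creation
  worker->CreateThreadResource(program, place);  // scope, ops, program, place
  worker->SetDataFeed(reader);
  worker->SetFetchVarNames(fetch_var_names);
  worker->BindingDataFeedMemory();               // requires the data feed
  worker->TrainFiles();                          // usually run on its own std::thread
}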
Some files were not shown because too many files have changed in this diff