commit
119da44954
@ -0,0 +1,64 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stddef.h> // for size_t
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
// Channel is the abstract class of buffered and un-buffered channels.
|
||||||
|
template <typename T>
|
||||||
|
class Channel {
|
||||||
|
public:
|
||||||
|
virtual void Send(T*) = 0;
|
||||||
|
virtual void Receive(T*) = 0;
|
||||||
|
virtual size_t Cap() = 0;
|
||||||
|
|
||||||
|
// Don't delete channels; instead, call Channel::Close.
|
||||||
|
protected:
|
||||||
|
virtual ~Channel() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Forward declaration of channel implementations.
|
||||||
|
namespace details {
|
||||||
|
template <typename T>
|
||||||
|
class Buffered;
|
||||||
|
template <typename T>
|
||||||
|
class UnBuffered;
|
||||||
|
} // namespace details
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Channel<T>* MakeChannel(size_t buffer_size) {
|
||||||
|
if (buffer_size > 0) {
|
||||||
|
return new details::Buffered<T>(buffer_size);
|
||||||
|
}
|
||||||
|
return new details::UnBuffered<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CloseChannel(Channel<T>* ch) {
|
||||||
|
if (ch->Cap() > 0) {
|
||||||
|
delete dynamic_cast<details::Buffered<T>*>(ch);
|
||||||
|
} else {
|
||||||
|
delete dynamic_cast<details::UnBuffered<T>*>(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
#include "paddle/framework/details/buffered_channel.h"
|
||||||
|
#include "paddle/framework/details/unbuffered_channel.h"
|
||||||
@ -0,0 +1,26 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/framework/channel.h"
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
TEST(Channel, MakeAndClose) {
|
||||||
|
using paddle::framework::Channel;
|
||||||
|
using paddle::framework::MakeChannel;
|
||||||
|
using paddle::framework::CloseChannel;
|
||||||
|
|
||||||
|
Channel<int>* ch = MakeChannel<int>(10);
|
||||||
|
CloseChannel(ch);
|
||||||
|
}
|
||||||
@ -0,0 +1,82 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <condition_variable>
|
||||||
|
#include <deque>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
#include "paddle/framework/channel.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace details {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class Buffered : public paddle::framework::Channel<T> {
|
||||||
|
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
|
||||||
|
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual void Send(T*);
|
||||||
|
virtual void Receive(T*);
|
||||||
|
virtual size_t Cap() { return cap_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t cap_;
|
||||||
|
std::mutex mu_;
|
||||||
|
std::condition_variable empty_cond_var_;
|
||||||
|
std::condition_variable full_cond_var_;
|
||||||
|
std::deque<T> channel_;
|
||||||
|
|
||||||
|
Buffered(size_t cap) : cap_(cap) {}
|
||||||
|
virtual ~Buffered();
|
||||||
|
|
||||||
|
void NotifyAllSenders(std::unique_lock<std::mutex>*);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void Buffered<T>::Send(T* item) {
|
||||||
|
std::unique_lock<std::mutex> lock(mu_);
|
||||||
|
full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; });
|
||||||
|
channel_.push_back(std::move(*item));
|
||||||
|
lock.unlock();
|
||||||
|
empty_cond_var_.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void Buffered<T>::Receive(T* item) {
|
||||||
|
std::unique_lock<std::mutex> lock(mu_);
|
||||||
|
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); });
|
||||||
|
*item = std::move(channel_.front());
|
||||||
|
channel_.pop_front();
|
||||||
|
NotifyAllSenders(&lock);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Buffered<T>::~Buffered() {
|
||||||
|
std::unique_lock<std::mutex> lock(mu_);
|
||||||
|
channel_.clear();
|
||||||
|
NotifyAllSenders(&lock);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
|
||||||
|
lock->unlock();
|
||||||
|
full_cond_var_.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace details
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
@ -0,0 +1,52 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <condition_variable>
|
||||||
|
#include <deque>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
#include "paddle/framework/channel.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
namespace details {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class UnBuffered : public paddle::framework::Channel<T> {
|
||||||
|
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
|
||||||
|
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual void Send(T*);
|
||||||
|
virtual void Receive(T*);
|
||||||
|
virtual size_t Cap() { return 0; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
UnBuffered() {}
|
||||||
|
virtual ~UnBuffered();
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void UnBuffered<T>::Send(T* channel_element) {}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void UnBuffered<T>::Receive(T*) {}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
UnBuffered<T>::~UnBuffered() {}
|
||||||
|
|
||||||
|
} // namespace details
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
@ -1,24 +1,95 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
You may obtain a copy of the License at
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
Unless required by applicable law or agreed to in writing, software
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License. */
|
limitations under the License. */
|
||||||
|
|
||||||
#include "paddle/framework/threadpool.h"
|
#include "paddle/framework/threadpool.h"
|
||||||
|
|
||||||
|
#include "paddle/platform/enforce.h"
|
||||||
|
|
||||||
namespace paddle {
|
namespace paddle {
|
||||||
namespace framework {
|
namespace framework {
|
||||||
|
|
||||||
std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
|
std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
|
||||||
std::once_flag ThreadPool::init_flag;
|
std::once_flag ThreadPool::init_flag_;
|
||||||
|
|
||||||
|
ThreadPool* ThreadPool::GetInstance() {
|
||||||
|
std::call_once(init_flag_, &ThreadPool::Init);
|
||||||
|
return threadpool_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadPool::Init() {
|
||||||
|
if (threadpool_.get() == nullptr) {
|
||||||
|
// TODO(Yancey1989): specify the max threads number
|
||||||
|
int num_threads = std::thread::hardware_concurrency();
|
||||||
|
PADDLE_ENFORCE_GT(num_threads, 0);
|
||||||
|
threadpool_.reset(new ThreadPool(num_threads));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ThreadPool::ThreadPool(int num_threads)
|
||||||
|
: total_threads_(num_threads), idle_threads_(num_threads), running_(true) {
|
||||||
|
threads_.resize(num_threads);
|
||||||
|
for (auto& thread : threads_) {
|
||||||
|
// TODO(Yancey1989): binding the thread on the specify CPU number
|
||||||
|
thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ThreadPool::~ThreadPool() {
|
||||||
|
{
|
||||||
|
// notify all threads to stop running
|
||||||
|
running_ = false;
|
||||||
|
scheduled_.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& t : threads_) {
|
||||||
|
t->join();
|
||||||
|
t.reset(nullptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadPool::Wait() {
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
completed_.wait(lock, [=] { return Done() == true; });
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadPool::TaskLoop() {
|
||||||
|
while (running_) {
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
|
||||||
|
|
||||||
|
if (!running_) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// pop a task from the task queue
|
||||||
|
auto task = std::move(tasks_.front());
|
||||||
|
tasks_.pop();
|
||||||
|
|
||||||
|
--idle_threads_;
|
||||||
|
lock.unlock();
|
||||||
|
|
||||||
|
// run the task
|
||||||
|
task();
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
++idle_threads_;
|
||||||
|
if (Done()) {
|
||||||
|
completed_.notify_all();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace framework
|
} // namespace framework
|
||||||
} // namespace paddle
|
} // namespace paddle
|
||||||
|
|||||||
@ -1,56 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "paddle/framework/block_desc.h"
|
|
||||||
#include "paddle/framework/executor.h"
|
|
||||||
#include "paddle/framework/lod_tensor.h"
|
|
||||||
#include "paddle/framework/program_desc.h"
|
|
||||||
#include "paddle/framework/scope.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
|
|
||||||
class InferenceEngine {
|
|
||||||
public:
|
|
||||||
InferenceEngine() : program_(nullptr), load_program_(nullptr) {}
|
|
||||||
~InferenceEngine() {
|
|
||||||
delete program_;
|
|
||||||
delete load_program_;
|
|
||||||
}
|
|
||||||
|
|
||||||
framework::ProgramDesc* LoadInferenceModel(framework::Executor& exe,
|
|
||||||
framework::Scope* scope,
|
|
||||||
const std::string& dirname);
|
|
||||||
|
|
||||||
const std::vector<std::string>& GetFeedVarNames() const {
|
|
||||||
return feed_var_names_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<std::string>& GetFetchVarNames() const {
|
|
||||||
return fetch_var_names_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool IsParameter(const framework::VarDesc* var);
|
|
||||||
void GenerateLoadProgram(const std::string& dirname);
|
|
||||||
|
|
||||||
private:
|
|
||||||
framework::ProgramDesc* program_;
|
|
||||||
framework::ProgramDesc* load_program_;
|
|
||||||
std::vector<std::string> feed_var_names_;
|
|
||||||
std::vector<std::string> fetch_var_names_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace paddle
|
|
||||||
@ -0,0 +1,41 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/framework/block_desc.h"
|
||||||
|
#include "paddle/framework/executor.h"
|
||||||
|
#include "paddle/framework/program_desc.h"
|
||||||
|
#include "paddle/framework/scope.h"
|
||||||
|
#include "paddle/framework/var_desc.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace inference {
|
||||||
|
|
||||||
|
bool IsParameter(const framework::VarDesc* var,
|
||||||
|
const framework::ProgramDesc* main_program);
|
||||||
|
|
||||||
|
void LoadPersistables(framework::Executor& executor,
|
||||||
|
framework::Scope& scope,
|
||||||
|
const std::string& dirname,
|
||||||
|
framework::ProgramDesc* main_program);
|
||||||
|
|
||||||
|
framework::ProgramDesc* Load(framework::Executor& executor,
|
||||||
|
framework::Scope& scope,
|
||||||
|
const std::string& dirname);
|
||||||
|
|
||||||
|
} // namespace inference
|
||||||
|
} // namespace paddle
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue