commit
119da44954
@ -0,0 +1,64 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stddef.h> // for size_t
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// Channel is the abstract class of buffered and un-buffered channels.
|
||||
template <typename T>
|
||||
class Channel {
|
||||
public:
|
||||
virtual void Send(T*) = 0;
|
||||
virtual void Receive(T*) = 0;
|
||||
virtual size_t Cap() = 0;
|
||||
|
||||
// Don't delete channels; instead, call Channel::Close.
|
||||
protected:
|
||||
virtual ~Channel() {}
|
||||
};
|
||||
|
||||
// Forward declaration of channel implementations.
|
||||
namespace details {
|
||||
template <typename T>
|
||||
class Buffered;
|
||||
template <typename T>
|
||||
class UnBuffered;
|
||||
} // namespace details
|
||||
|
||||
template <typename T>
|
||||
Channel<T>* MakeChannel(size_t buffer_size) {
|
||||
if (buffer_size > 0) {
|
||||
return new details::Buffered<T>(buffer_size);
|
||||
}
|
||||
return new details::UnBuffered<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CloseChannel(Channel<T>* ch) {
|
||||
if (ch->Cap() > 0) {
|
||||
delete dynamic_cast<details::Buffered<T>*>(ch);
|
||||
} else {
|
||||
delete dynamic_cast<details::UnBuffered<T>*>(ch);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
#include "paddle/framework/details/buffered_channel.h"
|
||||
#include "paddle/framework/details/unbuffered_channel.h"
|
||||
@ -0,0 +1,26 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/channel.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
TEST(Channel, MakeAndClose) {
|
||||
using paddle::framework::Channel;
|
||||
using paddle::framework::MakeChannel;
|
||||
using paddle::framework::CloseChannel;
|
||||
|
||||
Channel<int>* ch = MakeChannel<int>(10);
|
||||
CloseChannel(ch);
|
||||
}
|
||||
@ -0,0 +1,82 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
|
||||
#include "paddle/framework/channel.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
template <typename T>
|
||||
class Buffered : public paddle::framework::Channel<T> {
|
||||
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
|
||||
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
|
||||
|
||||
public:
|
||||
virtual void Send(T*);
|
||||
virtual void Receive(T*);
|
||||
virtual size_t Cap() { return cap_; }
|
||||
|
||||
private:
|
||||
size_t cap_;
|
||||
std::mutex mu_;
|
||||
std::condition_variable empty_cond_var_;
|
||||
std::condition_variable full_cond_var_;
|
||||
std::deque<T> channel_;
|
||||
|
||||
Buffered(size_t cap) : cap_(cap) {}
|
||||
virtual ~Buffered();
|
||||
|
||||
void NotifyAllSenders(std::unique_lock<std::mutex>*);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void Buffered<T>::Send(T* item) {
|
||||
std::unique_lock<std::mutex> lock(mu_);
|
||||
full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; });
|
||||
channel_.push_back(std::move(*item));
|
||||
lock.unlock();
|
||||
empty_cond_var_.notify_one();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Buffered<T>::Receive(T* item) {
|
||||
std::unique_lock<std::mutex> lock(mu_);
|
||||
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); });
|
||||
*item = std::move(channel_.front());
|
||||
channel_.pop_front();
|
||||
NotifyAllSenders(&lock);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Buffered<T>::~Buffered() {
|
||||
std::unique_lock<std::mutex> lock(mu_);
|
||||
channel_.clear();
|
||||
NotifyAllSenders(&lock);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
|
||||
lock->unlock();
|
||||
full_cond_var_.notify_one();
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
@ -0,0 +1,52 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
|
||||
#include "paddle/framework/channel.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
template <typename T>
|
||||
class UnBuffered : public paddle::framework::Channel<T> {
|
||||
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
|
||||
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
|
||||
|
||||
public:
|
||||
virtual void Send(T*);
|
||||
virtual void Receive(T*);
|
||||
virtual size_t Cap() { return 0; }
|
||||
|
||||
private:
|
||||
UnBuffered() {}
|
||||
virtual ~UnBuffered();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void UnBuffered<T>::Send(T* channel_element) {}
|
||||
|
||||
template <typename T>
|
||||
void UnBuffered<T>::Receive(T*) {}
|
||||
|
||||
template <typename T>
|
||||
UnBuffered<T>::~UnBuffered() {}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
@ -1,56 +0,0 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/block_desc.h"
|
||||
#include "paddle/framework/executor.h"
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/program_desc.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
|
||||
namespace paddle {
|
||||
|
||||
class InferenceEngine {
|
||||
public:
|
||||
InferenceEngine() : program_(nullptr), load_program_(nullptr) {}
|
||||
~InferenceEngine() {
|
||||
delete program_;
|
||||
delete load_program_;
|
||||
}
|
||||
|
||||
framework::ProgramDesc* LoadInferenceModel(framework::Executor& exe,
|
||||
framework::Scope* scope,
|
||||
const std::string& dirname);
|
||||
|
||||
const std::vector<std::string>& GetFeedVarNames() const {
|
||||
return feed_var_names_;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& GetFetchVarNames() const {
|
||||
return fetch_var_names_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsParameter(const framework::VarDesc* var);
|
||||
void GenerateLoadProgram(const std::string& dirname);
|
||||
|
||||
private:
|
||||
framework::ProgramDesc* program_;
|
||||
framework::ProgramDesc* load_program_;
|
||||
std::vector<std::string> feed_var_names_;
|
||||
std::vector<std::string> fetch_var_names_;
|
||||
};
|
||||
|
||||
} // namespace paddle
|
||||
@ -0,0 +1,41 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/framework/block_desc.h"
|
||||
#include "paddle/framework/executor.h"
|
||||
#include "paddle/framework/program_desc.h"
|
||||
#include "paddle/framework/scope.h"
|
||||
#include "paddle/framework/var_desc.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
|
||||
bool IsParameter(const framework::VarDesc* var,
|
||||
const framework::ProgramDesc* main_program);
|
||||
|
||||
void LoadPersistables(framework::Executor& executor,
|
||||
framework::Scope& scope,
|
||||
const std::string& dirname,
|
||||
framework::ProgramDesc* main_program);
|
||||
|
||||
framework::ProgramDesc* Load(framework::Executor& executor,
|
||||
framework::Scope& scope,
|
||||
const std::string& dirname);
|
||||
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue