|
|
|
@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include <boost/tokenizer.hpp>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <thread>
|
|
|
|
|
#include <thread> // NOLINT
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/channel.h"
|
|
|
|
|
#include "paddle/fluid/framework/executor.h"
|
|
|
|
@ -22,6 +21,8 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/concurrency/channel_util.h"
|
|
|
|
|
|
|
|
|
|
#include <boost/tokenizer.hpp>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -254,8 +255,8 @@ class SelectOp : public framework::OperatorBase {
|
|
|
|
|
auto selectCond = std::make_shared<std::condition_variable_any>();
|
|
|
|
|
|
|
|
|
|
std::recursive_mutex callbackMutex;
|
|
|
|
|
pushThreadOnChannelQueues(scope, cases, selectCond, caseToExecute,
|
|
|
|
|
completed, callbackMutex);
|
|
|
|
|
pushThreadOnChannelQueues(scope, cases, selectCond, &caseToExecute,
|
|
|
|
|
&completed, &callbackMutex);
|
|
|
|
|
|
|
|
|
|
// TODO(thuan): Atomically unlock all channels and sleep current thread
|
|
|
|
|
unlockChannels(channels);
|
|
|
|
@ -302,8 +303,8 @@ class SelectOp : public framework::OperatorBase {
|
|
|
|
|
const framework::Scope *scope,
|
|
|
|
|
std::vector<std::shared_ptr<SelectOpCase>> *cases,
|
|
|
|
|
std::shared_ptr<std::condition_variable_any> rCond,
|
|
|
|
|
std::atomic<int> &caseToExecute, std::atomic<bool> &completed,
|
|
|
|
|
std::recursive_mutex &callbackMutex) const {
|
|
|
|
|
std::atomic<int> *caseToExecute, std::atomic<bool> *completed,
|
|
|
|
|
std::recursive_mutex *callbackMutex) const {
|
|
|
|
|
std::vector<std::shared_ptr<SelectOpCase>>::iterator it = cases->begin();
|
|
|
|
|
while (it != cases->end()) {
|
|
|
|
|
std::shared_ptr<SelectOpCase> c = *it;
|
|
|
|
@ -315,17 +316,17 @@ class SelectOp : public framework::OperatorBase {
|
|
|
|
|
std::function<bool(framework::ChannelAction channelAction)> cb =
|
|
|
|
|
[&caseToExecute, &completed, &callbackMutex,
|
|
|
|
|
c](framework::ChannelAction channelAction) {
|
|
|
|
|
std::lock_guard<std::recursive_mutex> lock{callbackMutex};
|
|
|
|
|
std::lock_guard<std::recursive_mutex> lock{*callbackMutex};
|
|
|
|
|
|
|
|
|
|
bool canProcess = false;
|
|
|
|
|
if (!completed) {
|
|
|
|
|
if (!(*completed)) {
|
|
|
|
|
// If the channel wasn't closed, we set the caseToExecute index
|
|
|
|
|
// as this current case
|
|
|
|
|
if (channelAction != framework::ChannelAction::CLOSE) {
|
|
|
|
|
caseToExecute = c->caseIndex;
|
|
|
|
|
*caseToExecute = c->caseIndex;
|
|
|
|
|
}
|
|
|
|
|
// This will allow our conditional variable to break out of wait
|
|
|
|
|
completed = true;
|
|
|
|
|
*completed = true;
|
|
|
|
|
canProcess = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|