Implement Select OP (#9088)
* Fix old documentation for channel_recv * Initial design of CSP select * Redesign channel implementation for Select Op * Remove unecessary header * Initial checkin of select op, currently will read all the conditional_op in the cases block and also pull out all channels involved in the select. * Init python select op API * Python select bug fix when checking op creates block * Add case_to_execute as (a) input to select, (b) into the passed inputs into the select op * Add in addition code for select op * Init fibonacci test from python * implement fibonnaci sequence test * update fib unit test * Improve select test cases * Shorten non-pep-8-ed lines * Add methods on channel needed by select op * Fix compile issues, finish implementation, still need to debug code * Fix issue with fibonncci test, it works now! * Change QueueMessage callback to take in an ChannelAction enum, fix select unit test * Fix case attributes * Fix issue with select control flow * Make cases - previously on each selectcase conditional_block - attributes to select * Use class constants for type of channel * Change select op to take in "cases" attribute * return boolean from select callback function to tell Channel if this RECV or SEND should be executed * Improve attributes and inputs comments on select op * Fix issues with python unit test * Assert fibonacci final output * Fix issue when channel name / channel var is null for "default" case in select op * Assert base select test output * Make QueueMessage use shared pointer and modify the order of the callback * Fixing the order in which the callback is called * Move channel utility methods to paddle/fluid/operators/concurrency/channel_util * Create channel_util and move channel util methods * Fix crash when calling select_op * Fix deadlock * Fix issue of channel destructor deadlock * Fix precommit issues * Accidentally checked in changes to beam_search_op, reverting change. * Fix dependency issue in concurrency cmake * add device_context dependency for concurrency targetshanyi15-patch-2
parent
45073b7c39
commit
1e4c504e60
@ -0,0 +1 @@
|
||||
cc_library(concurrency SRCS channel_util.cc DEPS device_context framework_proto boost eigen3)
|
@ -0,0 +1,111 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "channel_util.h"
|
||||
#include "paddle/fluid/framework/var_type.h"
|
||||
|
||||
namespace poc = paddle::operators::concurrency;
|
||||
|
||||
bool poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) {
|
||||
auto type = framework::ToVarType(var->Type());
|
||||
if (type == framework::proto::VarType_Type_LOD_TENSOR)
|
||||
return ch->Send(var->GetMutable<framework::LoDTensor>());
|
||||
else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
|
||||
return ch->Send(var->GetMutable<framework::LoDRankTable>());
|
||||
else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
|
||||
return ch->Send(var->GetMutable<framework::LoDTensorArray>());
|
||||
else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
|
||||
return ch->Send(var->GetMutable<framework::SelectedRows>());
|
||||
else if (type == framework::proto::VarType_Type_READER)
|
||||
return ch->Send(var->GetMutable<framework::ReaderHolder>());
|
||||
else if (type == framework::proto::VarType_Type_CHANNEL)
|
||||
return ch->Send(var->GetMutable<framework::ChannelHolder>());
|
||||
else
|
||||
PADDLE_THROW("ChannelSend:Unsupported type");
|
||||
}
|
||||
|
||||
bool poc::ChannelReceive(framework::ChannelHolder *ch,
|
||||
framework::Variable *var) {
|
||||
// Get type of channel and use that to call mutable data for Variable
|
||||
auto type = framework::ToVarType(ch->Type());
|
||||
if (type == framework::proto::VarType_Type_LOD_TENSOR)
|
||||
return ch->Receive(var->GetMutable<framework::LoDTensor>());
|
||||
else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
|
||||
return ch->Receive(var->GetMutable<framework::LoDRankTable>());
|
||||
else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
|
||||
return ch->Receive(var->GetMutable<framework::LoDTensorArray>());
|
||||
else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
|
||||
return ch->Receive(var->GetMutable<framework::SelectedRows>());
|
||||
else if (type == framework::proto::VarType_Type_READER)
|
||||
return ch->Receive(var->GetMutable<framework::ReaderHolder>());
|
||||
else if (type == framework::proto::VarType_Type_CHANNEL)
|
||||
return ch->Receive(var->GetMutable<framework::ChannelHolder>());
|
||||
else
|
||||
PADDLE_THROW("ChannelReceive:Unsupported type");
|
||||
}
|
||||
|
||||
void poc::ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer,
|
||||
framework::Variable *var,
|
||||
std::shared_ptr<std::condition_variable_any> cond,
|
||||
std::function<bool(framework::ChannelAction)> cb) {
|
||||
auto type = framework::ToVarType(var->Type());
|
||||
if (type == framework::proto::VarType_Type_LOD_TENSOR) {
|
||||
ch->AddToSendQ(referrer, var->GetMutable<framework::LoDTensor>(), cond, cb);
|
||||
} else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) {
|
||||
ch->AddToSendQ(referrer, var->GetMutable<framework::LoDRankTable>(), cond,
|
||||
cb);
|
||||
} else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) {
|
||||
ch->AddToSendQ(referrer, var->GetMutable<framework::LoDTensorArray>(), cond,
|
||||
cb);
|
||||
} else if (type == framework::proto::VarType_Type_SELECTED_ROWS) {
|
||||
ch->AddToSendQ(referrer, var->GetMutable<framework::SelectedRows>(), cond,
|
||||
cb);
|
||||
} else if (type == framework::proto::VarType_Type_READER) {
|
||||
ch->AddToSendQ(referrer, var->GetMutable<framework::ReaderHolder>(), cond,
|
||||
cb);
|
||||
} else if (type == framework::proto::VarType_Type_CHANNEL) {
|
||||
ch->AddToSendQ(referrer, var->GetMutable<framework::ChannelHolder>(), cond,
|
||||
cb);
|
||||
} else {
|
||||
PADDLE_THROW("ChannelAddToSendQ:Unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
void poc::ChannelAddToReceiveQ(
|
||||
framework::ChannelHolder *ch, const void *referrer,
|
||||
framework::Variable *var, std::shared_ptr<std::condition_variable_any> cond,
|
||||
std::function<bool(framework::ChannelAction)> cb) {
|
||||
auto type = framework::ToVarType(var->Type());
|
||||
if (type == framework::proto::VarType_Type_LOD_TENSOR) {
|
||||
ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDTensor>(), cond,
|
||||
cb);
|
||||
} else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) {
|
||||
ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDRankTable>(),
|
||||
cond, cb);
|
||||
} else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) {
|
||||
ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDTensorArray>(),
|
||||
cond, cb);
|
||||
} else if (type == framework::proto::VarType_Type_SELECTED_ROWS) {
|
||||
ch->AddToReceiveQ(referrer, var->GetMutable<framework::SelectedRows>(),
|
||||
cond, cb);
|
||||
} else if (type == framework::proto::VarType_Type_READER) {
|
||||
ch->AddToReceiveQ(referrer, var->GetMutable<framework::ReaderHolder>(),
|
||||
cond, cb);
|
||||
} else if (type == framework::proto::VarType_Type_CHANNEL) {
|
||||
ch->AddToReceiveQ(referrer, var->GetMutable<framework::ChannelHolder>(),
|
||||
cond, cb);
|
||||
} else {
|
||||
PADDLE_THROW("ChannelAddToReceiveQ:Unsupported type");
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/channel.h"
|
||||
#include "paddle/fluid/framework/variable.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace concurrency {
|
||||
|
||||
bool ChannelSend(framework::ChannelHolder *ch, framework::Variable *var);
|
||||
bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var);
|
||||
|
||||
void ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer,
|
||||
framework::Variable *var,
|
||||
std::shared_ptr<std::condition_variable_any> cond,
|
||||
std::function<bool(framework::ChannelAction)> cb);
|
||||
void ChannelAddToReceiveQ(framework::ChannelHolder *ch, const void *referrer,
|
||||
framework::Variable *var,
|
||||
std::shared_ptr<std::condition_variable_any> cond,
|
||||
std::function<bool(framework::ChannelAction)> cb);
|
||||
|
||||
} // namespace concurrency
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue