Add Go_op, Channel_create, channel_close, channel_send and channel_receive ops (#8593)
* Adding Python boilerplate code for Go op
* Add very basic test case
* Adding the python logic for go routine
* Fix syntax
* Changing test to notest
* Rename Routine to Go
* Combining GoGuard and Go in one class
* Modify test
* Adding fluid close channel
* Fixing __init__.py for calling fluid.go()
* Adding stubs for channel methods and updating test case
* Removing import *
* Adding imports from concurrency
* Initial commit of GO_OP (for varun)
* Creating local scopes and go through them
* Updated go op inputs persistability enforcement
* Add thread execution; compile failing though
* Fix go op
* Cleaned up Go op
* Fix yapf format issue
* Readd warp ctc dir for unit tests
* Updated make_channel, channel_send, channel_recv and channel_close
* Moved thread function to another method, update unit tests
* remove output var
* Add stubs for channel operators
* Updating concurrency with signatures
* Updated the signature with return status
* Fixed dtype in variables
* Updating stub of ChannelSend + add infershape
* Updating stub of ChannelRecv + add infershape
* Updated signature
* Adding the channel_create operator
* Merge channel send+receive ops
* Update concurrency tests using all operators
* Updating the create op with ChannelHolder
* Fix issues with channel_create_op
* Add the implementation for channel_close op
* Add channel close operator, fix channel close op
* Adding the channel_send op
* Comment channels C++ and Python code
* Concurrency python api comment fix
* Update unit test to add Status variable
* Adding channel receive operator
* Update concurrency test to demonstrate a complete CSP flow
* Fix clang-format issues
* Fixed "Out" parameter name
* Fixing merge conflict in framework.py
* Add channel ops to framework.py no_kernel_op_set
* Seperating channel_send and channel_recv operators
* Documenting capacity type
* Update concurrency test to create go block as child block of main program
* Changing set status implementation
parent 2edeb639e2
commit 0d878e4c09
@@ -0,0 +1,122 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thread>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"

USE_NO_KERNEL_OP(go);
USE_NO_KERNEL_OP(channel_close);
USE_NO_KERNEL_OP(channel_create);
USE_NO_KERNEL_OP(channel_recv);
USE_NO_KERNEL_OP(channel_send);
USE_NO_KERNEL_OP(elementwise_add);

namespace f = paddle::framework;
namespace p = paddle::platform;

namespace paddle {
namespace framework {

template <typename T>
void CreateIntVariable(Scope &scope, p::CPUPlace &place, std::string name,
                       T value) {
  // Create a LoDTensor<T> of dim [1,1] holding the given value.
  auto var = scope.Var(name);
  auto tensor = var->GetMutable<LoDTensor>();
  tensor->Resize({1, 1});
  T *expect = tensor->mutable_data<T>(place);
  expect[0] = value;
}

void InitTensorsInScope(Scope &scope, p::CPUPlace &place) {
  p::CPUDeviceContext ctx(place);

  // Create the channel variable.
  scope.Var("Channel");

  // Create Variables; x0 will be put into the channel,
  // result will be pulled from the channel.
  CreateIntVariable(scope, place, "Status", false);
  CreateIntVariable(scope, place, "x0", 99);
  CreateIntVariable(scope, place, "result", 0);
}

void AddOp(const std::string &type, const VariableNameMap &inputs,
           const VariableNameMap &outputs, AttributeMap attrs,
           BlockDesc *block) {
  // Append an op of the given type to the block.
  auto op = block->AppendOp();
  op->SetType(type);
  for (auto &kv : inputs) {
    op->SetInput(kv.first, kv.second);
  }
  for (auto &kv : outputs) {
    op->SetOutput(kv.first, kv.second);
  }
  op->SetAttrMap(attrs);
}

TEST(Concurrency, Go_Op) {
  Scope scope;
  p::CPUPlace place;

  // Initialize scope variables.
  InitTensorsInScope(scope, place);

  framework::Executor executor(place);
  ProgramDesc program;
  BlockDesc *block = program.MutableBlock(0);

  // Create the channel_create op.
  AddOp("channel_create", {}, {{"Out", {"Channel"}}},
        {{"capacity", 10}, {"data_type", f::proto::VarType::LOD_TENSOR}},
        block);

  // Create the Go op routine (the sub-block the go op executes).
  BlockDesc *goOpBlock = program.AppendBlock(program.Block(0));
  AddOp("channel_send", {{"Channel", {"Channel"}}, {"X", {"x0"}}},
        {{"Status", {"Status"}}}, {}, goOpBlock);

  // Create the go op.
  AddOp("go", {{"X", {"Channel", "x0"}}}, {}, {{"sub_block", goOpBlock}},
        block);

  // Create the channel_recv op.
  AddOp("channel_recv", {{"Channel", {"Channel"}}},
        {{"Status", {"Status"}}, {"Out", {"result"}}}, {}, block);

  // Create the channel_close op.
  AddOp("channel_close", {{"Channel", {"Channel"}}}, {}, {}, block);

  // Check the result tensor to make sure it starts at 0.
  const LoDTensor &tensor = (scope.FindVar("result"))->Get<LoDTensor>();
  auto *initialData = tensor.data<int>();
  EXPECT_EQ(initialData[0], 0);

  executor.Run(program, &scope, 0, true, true);

  // After executor.Run, the go operator should have performed a channel_send,
  // so channel_recv sets the "result" variable to 99.
  auto *finalData = tensor.data<int>();
  EXPECT_EQ(finalData[0], 99);
}
}  // namespace framework
}  // namespace paddle
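For reference, here is a minimal sketch (not part of this diff) of how the framework::ChannelHolder interface that these operators wrap is used directly, assuming the Reset/Send/Receive/close signatures visible in the operator implementations below; the function name ChannelHolderSketch is illustrative only. Because the test above creates the channel with capacity 10, the goroutine's channel_send can complete before channel_recv runs; with capacity 0 the channel would be unbuffered and the send would block until a matching receive.

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"

// Hedged sketch of direct ChannelHolder usage; not part of this commit.
void ChannelHolderSketch() {
  namespace f = paddle::framework;

  f::ChannelHolder holder;
  holder.Reset<f::LoDTensor>(10);  // buffered channel of capacity 10 (0 = unbuffered)

  f::LoDTensor in;
  bool sent = holder.Send(&in);  // bool status, as channel_send reports in "Status"

  f::LoDTensor out;
  bool received = holder.Receive(&out);  // bool status, as channel_recv reports

  holder.close();  // what channel_close calls
  (void)sent;
  (void)received;
}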
@@ -0,0 +1,71 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/op_registry.h"

namespace pf = paddle::framework;
static constexpr char kChannel[] = "Channel";

namespace paddle {
namespace operators {

class ChannelCloseOp : public framework::OperatorBase {
 public:
  ChannelCloseOp(const std::string &type,
                 const framework::VariableNameMap &inputs,
                 const framework::VariableNameMap &outputs,
                 const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    auto &inp = *scope.FindVar(Input(kChannel));

    // Get the mutable version of the channel variable and close it.
    pf::ChannelHolder *ch = inp.GetMutable<framework::ChannelHolder>();
    ch->close();
  }
};

class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("Channel"),
                   "The input of ChannelClose op must be set");
  }
};

class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ChannelCloseOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(kChannel,
             "The Channel Variable that should be closed by"
             " the ChannelClose Op.");
    AddComment(R"DOC(
Channel Close Operator.

This operator closes an open channel.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(channel_close, paddle::operators::ChannelCloseOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::ChannelCloseOpMaker);
@@ -0,0 +1,114 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"

namespace pf = paddle::framework;

static constexpr char kOutput[] = "Out";

namespace paddle {
namespace operators {

class ChannelCreateOp : public framework::OperatorBase {
 public:
  ChannelCreateOp(const std::string &type,
                  const framework::VariableNameMap &inputs,
                  const framework::VariableNameMap &outputs,
                  const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    auto &out = *scope.FindVar(Output(kOutput));

    // Determine the datatype and capacity of the channel to be created
    // from the attributes provided.
    auto dtype =
        static_cast<framework::proto::VarType::Type>(Attr<int>("data_type"));
    auto capacity = Attr<int>("capacity");

    // Based on the datatype, create a new channel holder initialized with
    // the given capacity. When capacity is 0, an unbuffered channel is
    // created.
    pf::ChannelHolder *ch = out.GetMutable<framework::ChannelHolder>();
    if (dtype == framework::proto::VarType::LOD_TENSOR) {
      ch->Reset<pf::LoDTensor>(capacity);
    } else if (dtype == framework::proto::VarType::SELECTED_ROWS) {
      ch->Reset<pf::SelectedRows>(capacity);
    } else if (dtype == framework::proto::VarType::LOD_RANK_TABLE) {
      ch->Reset<pf::LoDRankTable>(capacity);
    } else if (dtype == framework::proto::VarType::LOD_TENSOR_ARRAY) {
      ch->Reset<pf::LoDTensorArray>(capacity);
    } else if (dtype == framework::proto::VarType::READER) {
      ch->Reset<pf::ReaderHolder>(capacity);
    } else if (dtype == framework::proto::VarType::CHANNEL) {
      ch->Reset<pf::ChannelHolder>(capacity);
    } else if (dtype == framework::proto::VarType::BOOL) {
      ch->Reset<bool>(capacity);
    } else if (dtype == framework::proto::VarType::INT32) {
      ch->Reset<int>(capacity);
    } else if (dtype == framework::proto::VarType::INT64) {
      ch->Reset<int64_t>(capacity);
    } else if (dtype == framework::proto::VarType::FP32) {
      ch->Reset<float>(capacity);
    } else if (dtype == framework::proto::VarType::FP64) {
      ch->Reset<double>(capacity);
    } else {
      PADDLE_THROW(
          "Data type %d is not in "
          "[LOD_TENSOR, SELECTED_ROWS, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, "
          "READER, CHANNEL, BOOL, INT32, INT64, FP32, FP64]",
          dtype);
    }
  }
};

class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasOutput(kOutput),
                   "The output of ChannelCreate op must be set");
    context->SetOutputDim(kOutput, {1});
  }
};

class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ChannelCreateOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput(kOutput,
              "The object of a Channel type created by ChannelCreate Op.");
    AddAttr<int>("capacity", "The size of the buffer of Channel.")
        .SetDefault(0);
    AddAttr<int>("data_type", "The data type of elements inside the Channel.");
    AddComment(R"DOC(
Channel Create Operator.

This operator creates an object of the VarType Channel and returns it.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(channel_create, paddle::operators::ChannelCreateOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::ChannelCreateOpMaker);
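As a usage note, a hedged sketch (not part of this diff) of how a channel of a different element type would be configured purely through the data_type and capacity attributes dispatched above, reusing the AddOp helper defined in the concurrency test earlier; the variable name "IntChannel" is hypothetical.

// Hedged sketch: appending an unbuffered int32 channel_create op to a block,
// using the AddOp helper from the concurrency test above.
AddOp("channel_create", {}, {{"Out", {"IntChannel"}}},
      {{"capacity", 0},  // capacity 0 -> unbuffered channel
       {"data_type", f::proto::VarType::INT32}},
      block);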
@@ -0,0 +1,117 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/math/math_function.h"

static constexpr char Channel[] = "Channel";
static constexpr char Status[] = "Status";
static constexpr char Out[] = "Out";

namespace paddle {
namespace operators {

void SetReceiveStatus(const platform::Place &dev_place,
                      framework::Variable &status_var, bool status) {
  auto cpu = platform::CPUPlace();
  auto status_tensor =
      status_var.GetMutable<framework::LoDTensor>()->mutable_data<bool>({1},
                                                                        cpu);
  status_tensor[0] = status;
}

bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var) {
  // Get the type of the channel and use it to fetch the matching mutable
  // data from the output variable.
  auto type = framework::ToVarType(ch->Type());
  if (type == framework::proto::VarType_Type_LOD_TENSOR)
    return ch->Receive(var->GetMutable<framework::LoDTensor>());
  else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
    return ch->Receive(var->GetMutable<framework::LoDRankTable>());
  else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
    return ch->Receive(var->GetMutable<framework::LoDTensorArray>());
  else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
    return ch->Receive(var->GetMutable<framework::SelectedRows>());
  else if (type == framework::proto::VarType_Type_READER)
    return ch->Receive(var->GetMutable<framework::ReaderHolder>());
  else if (type == framework::proto::VarType_Type_CHANNEL)
    return ch->Receive(var->GetMutable<framework::ChannelHolder>());
  else
    PADDLE_THROW("ChannelReceive: Unsupported type");
}

class ChannelRecvOp : public framework::OperatorBase {
 public:
  ChannelRecvOp(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const {
    PADDLE_ENFORCE(ctx->HasInput(Channel),
                   "Input(Channel) of ChannelRecvOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(Out),
                   "Output(Out) of ChannelRecvOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(Status),
                   "Output(Status) of ChannelRecvOp should not be null.");
    ctx->SetOutputDim("Status", {1});
  }

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    // Get the channel holder created by the channel_create op, passed as
    // input.
    framework::ChannelHolder *ch =
        scope.FindVar(Input(Channel))->GetMutable<framework::ChannelHolder>();
    auto output_var = scope.FindVar(Output(Out));

    // Receive the data from the channel.
    bool ok = ChannelReceive(ch, output_var);

    // Set the status output of the `ChannelReceive` call.
    SetReceiveStatus(dev_place, *scope.FindVar(Output(Status)), ok);
  }
};

class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ChannelRecvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(Channel,
             "(Channel) A variable which \"receives\" the value sent"
             " to it by a channel_send op.")
        .AsDuplicable();
    AddOutput(Out,
              "(Variable) Output Variable that will hold the data received"
              " from the Channel")
        .AsDuplicable();
    AddOutput(Status,
              "(Tensor) An LoD Tensor that returns a boolean status of the"
              " result of the receive operation.")
        .AsDuplicable();
    AddComment(R"DOC(
Channel Receive Operator.

This operator receives a value from the given Channel, writes it to the Out
variable, and sets Status to the boolean status returned by the receive.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(channel_recv, paddle::operators::ChannelRecvOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::ChannelRecvOpMaker);
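A small sketch (not part of this diff) of how the Status output written by SetReceiveStatus could be checked if such an assertion were added to the unit test above, mirroring how that test reads the "result" tensor. Note the test wires both channel_send and channel_recv to the same "Status" variable, so the value reflects whichever op wrote it last.

// Hedged sketch: inspecting the receive status after executor.Run in the test.
const LoDTensor &status = (scope.FindVar("Status"))->Get<LoDTensor>();
bool receive_ok = status.data<bool>()[0];
EXPECT_TRUE(receive_ok);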
@@ -0,0 +1,117 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/math/math_function.h"

static constexpr char Channel[] = "Channel";
static constexpr char X[] = "X";
static constexpr char Status[] = "Status";
static constexpr char copy[] = "copy";

namespace paddle {
namespace operators {

void SetSendStatus(const platform::Place &dev_place,
                   framework::Variable &status_var, bool status) {
  auto cpu = platform::CPUPlace();
  auto status_tensor =
      status_var.GetMutable<framework::LoDTensor>()->mutable_data<bool>({1},
                                                                        cpu);
  status_tensor[0] = status;
}

bool ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) {
  // Dispatch on the type of the input variable to call the matching Send.
  auto type = framework::ToVarType(var->Type());
  if (type == framework::proto::VarType_Type_LOD_TENSOR)
    return ch->Send(var->GetMutable<framework::LoDTensor>());
  else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
    return ch->Send(var->GetMutable<framework::LoDRankTable>());
  else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
    return ch->Send(var->GetMutable<framework::LoDTensorArray>());
  else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
    return ch->Send(var->GetMutable<framework::SelectedRows>());
  else if (type == framework::proto::VarType_Type_READER)
    return ch->Send(var->GetMutable<framework::ReaderHolder>());
  else if (type == framework::proto::VarType_Type_CHANNEL)
    return ch->Send(var->GetMutable<framework::ChannelHolder>());
  else
    PADDLE_THROW("ChannelSend: Unsupported type");
}

class ChannelSendOp : public framework::OperatorBase {
 public:
  ChannelSendOp(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const {
    PADDLE_ENFORCE(ctx->HasInput(Channel),
                   "Input(Channel) of ChannelSendOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(X),
                   "Input(X) of ChannelSendOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(Status),
                   "Output(Status) of ChannelSendOp should not be null.");
    ctx->SetOutputDim("Status", {1});
  }

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    // Get the channel holder created by the channel_create op, passed as
    // input.
    framework::ChannelHolder *ch =
        scope.FindVar(Input(Channel))->GetMutable<framework::ChannelHolder>();
    auto input_var = scope.FindVar(Input(X));

    // Send the input data through the channel.
    bool ok = ChannelSend(ch, input_var);

    // Set the status output of the `ChannelSend` call.
    SetSendStatus(dev_place, *scope.FindVar(Output(Status)), ok);
  }
};

class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ChannelSendOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(Channel,
             "(Channel) A variable which \"sends\" the passed in value to "
             "a listening receiver.")
        .AsDuplicable();
    AddInput(X, "(Variable) The value which gets sent through the channel.")
        .AsDuplicable();
    AddOutput(Status,
              "(Tensor) An LoD Tensor that returns a boolean status of the"
              " result of the send operation.")
        .AsDuplicable();
    AddAttr<bool>(copy, "(bool, default false) Whether to copy the value "
                        "before sending it.")
        .SetDefault(false);
    AddComment(R"DOC(
Channel Send Operator.

This operator sends the X input into the given Channel and sets Status to the
boolean status returned by the send.
)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(channel_send, paddle::operators::ChannelSendOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::ChannelSendOpMaker);
@@ -0,0 +1,111 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thread>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using StepScopeVar = std::vector<framework::Scope *>;

static constexpr char kBlock[] = "sub_block";
static constexpr char kX[] = "X";

class GoOp : public framework::OperatorBase {
 public:
  GoOp(const std::string &type, const framework::VariableNameMap &inputs,
       const framework::VariableNameMap &outputs,
       const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}

 private:
  void ExecuteOnThread(framework::Executor *executor,
                       framework::BlockDesc *block,
                       framework::Scope *scope) const {
    framework::ProgramDesc *program = block->Program();
    executor->Run(*program, scope, block->ID(), false /*create_local_scope*/);
  }

  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    /*
     * Determine the global scope and create a new child scope under it.
     * Within the child scope, add all the local variables relevant to
     * the go block.
     *
     * Then go through all the inputs to the op to ensure that all of
     * them are reachable from the newly created scope. This is important
     * so that they don't get destroyed when a parent scope is deleted
     * while the detached thread is still running.
     */

    // TODO(varunarora): Consider moving this root scope lookup to scope.h.
    const framework::Scope *root_scope = &scope;
    const framework::Scope *parent_scope = &(root_scope->parent());

    while (parent_scope != nullptr) {
      root_scope = parent_scope;
      parent_scope = &(parent_scope->parent());
    }

    framework::BlockDesc *block = Attr<framework::BlockDesc *>(kBlock);
    framework::Scope &new_scope = root_scope->NewScope();

    for (auto &var : block->AllVars()) {
      new_scope.Var(var->Name());
    }

    auto &inputs = Inputs(kX);
    for (size_t i = 0; i < inputs.size(); i++) {
      PADDLE_ENFORCE_NOT_NULL(new_scope.FindVar(inputs.at(i)),
                              "All variables used in the go block "
                              "should be created in the global scope");
    }

    // Now execute the go block on a new, detached thread using the newly
    // created scope.
    std::thread go_thread([dev_place, block, &new_scope, this]() {
      framework::Executor executor(dev_place);
      ExecuteOnThread(&executor, block, &new_scope);
    });
    go_thread.detach();
  }
};

class GoOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  GoOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(kX,
             "A set of variables, which are required by operators inside the "
             "block of Go Op.")
        .AsDuplicable();
    AddAttr<framework::BlockDesc *>(kBlock, "The block inside GoOp");
    AddComment(R"DOC(
Go Operator.

This operator runs its sub_block on a new, detached thread, similar to a
goroutine, and returns immediately without waiting for the sub-block to
finish.
)DOC");
  }
};

// TODO(thuan): Look into Gradient Operator for GO_OP

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(go, paddle::operators::GoOp,
                  paddle::framework::EmptyGradOpMaker,
                  paddle::operators::GoOpMaker);
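The TODO in GoOp::RunImpl suggests moving the root-scope lookup into scope.h; a hedged sketch of what such a helper could look like, mirroring the loop used above (the name GetRootScope is hypothetical and not part of this commit):

// Hypothetical helper mirroring GoOp's root-scope walk.
const paddle::framework::Scope *GetRootScope(
    const paddle::framework::Scope &scope) {
  const paddle::framework::Scope *root = &scope;
  const paddle::framework::Scope *parent = &(root->parent());
  while (parent != nullptr) {
    root = parent;
    parent = &(parent->parent());
  }
  return root;
}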