clean up channel

test=develop
revert-13637-optimize-opyreader
Xin Pan 7 years ago
parent 35b713c3fd
commit ddd60581b7

@ -169,15 +169,8 @@ cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
# cc_test(channel_test SRCS channel_test.cc)
cc_test(tuple_test SRCS tuple_test.cc )
if (NOT WIN32)
cc_test(rw_lock_test SRCS rw_lock_test.cc)
endif (NOT WIN32)
# disable test temporarily.
# TODO https://github.com/PaddlePaddle/Paddle/issues/11971
# cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
# channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
# conditional_block_op while_op assign_op print_op executor proto_desc)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
@ -76,15 +75,13 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::CHANNEL) {
var->GetMutable<ChannelHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
"LOD_RANK_TABLE, PLACE_LIST, READER, RAW]",
var_type);
}
}

@ -126,7 +126,6 @@ message VarType {
LOD_TENSOR_ARRAY = 13;
PLACE_LIST = 14;
READER = 15;
CHANNEL = 16;
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
@ -158,12 +157,6 @@ message VarType {
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
optional ReaderDesc reader = 5;
message ChannelDesc {
required Type data_type = 1;
required int64 capacity = 2;
}
optional ChannelDesc channel = 6;
message Tuple { repeated Type element_type = 1; }
optional Tuple tuple = 7;
}

@ -17,7 +17,6 @@ limitations under the License. */
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_desc.h"

@ -88,13 +88,7 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
}
void VarDesc::SetDataType(proto::VarType::Type data_type) {
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
mutable_channel_desc()->set_data_type(data_type);
break;
default:
mutable_tensor_desc()->set_data_type(data_type);
}
mutable_tensor_desc()->set_data_type(data_type);
}
void VarDesc::SetDataTypes(
@ -115,13 +109,7 @@ void VarDesc::SetDataTypes(
}
proto::VarType::Type VarDesc::GetDataType() const {
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return channel_desc().data_type();
break;
default:
return tensor_desc().data_type();
}
return tensor_desc().data_type();
}
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
@ -134,17 +122,6 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
return res;
}
void VarDesc::SetCapacity(int64_t capacity) {
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
break;
default:
PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
this->Name());
}
}
void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type().type()) {
case proto::VarType::LOD_TENSOR:
@ -214,19 +191,6 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
}
}
const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return desc_.type().channel();
default:
PADDLE_THROW(
"Getting 'channel_desc' is not supported by the type of var %s.",
this->Name());
}
}
const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
@ -262,20 +226,6 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
}
}
proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return desc_.mutable_type()->mutable_channel();
default:
PADDLE_THROW(
"Getting 'mutable_channel_desc' is not supported by the type of var "
"%s.",
this->Name());
}
}
proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");

@ -87,8 +87,6 @@ class VarDesc {
void SetDataTypes(
const std::vector<proto::VarType::Type> &multiple_data_type);
void SetCapacity(int64_t capacity);
proto::VarType::Type GetDataType() const;
std::vector<proto::VarType::Type> GetDataTypes() const;
@ -110,10 +108,8 @@ class VarDesc {
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
private:
const proto::VarType::ChannelDesc &channel_desc() const;
const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
proto::VarType::ChannelDesc *mutable_channel_desc();
proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();

@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
@ -41,8 +40,6 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
return proto::VarType_Type_SELECTED_ROWS;
} else if (IsType<ReaderHolder>(type)) {
return proto::VarType_Type_READER;
} else if (IsType<ChannelHolder>(type)) {
return proto::VarType_Type_CHANNEL;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
@ -66,9 +63,6 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case proto::VarType_Type_READER:
visitor(var.Get<ReaderHolder>());
return;
case proto::VarType_Type_CHANNEL:
visitor(var.Get<ChannelHolder>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
}

@ -41,12 +41,6 @@ class AnalysisPass {
// all passes have run.
virtual bool Finalize() { return false; }
// Get a Pass appropriate to print the Node this pass operates on.
virtual AnalysisPass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const {
return nullptr;
}
// Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual AnalysisPass *CreateGraphvizDebugerPass() const { return nullptr; }

@ -313,11 +313,6 @@ op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat)
# FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency
add_subdirectory(concurrency)
op_library(channel_send_op DEPS concurrency)
op_library(channel_recv_op DEPS concurrency)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})

@ -1,70 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/op_registry.h"
namespace pf = paddle::framework;
static constexpr char kChannel[] = "Channel";
namespace paddle {
namespace operators {
class ChannelCloseOp : public framework::OperatorBase {
public:
ChannelCloseOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto &inp = *scope.FindVar(Input(kChannel));
// Get the mutable version of the channel variable and closes it.
pf::ChannelHolder *ch = inp.GetMutable<framework::ChannelHolder>();
ch->close();
}
};
class ChannelCloseOpOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("Channel"),
"The input of ChannelClose op must be set");
}
};
class ChannelCloseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(kChannel,
"The Channel Variable that should be closed by"
" the ChannelClose Op.");
AddComment(R"DOC(
Channel Close Operator.
This operator closes an open channel.
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(channel_close, paddle::operators::ChannelCloseOp,
paddle::framework::EmptyGradOpMaker,
paddle::operators::ChannelCloseOpMaker);

@ -1,113 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
namespace pf = paddle::framework;
static constexpr char kOutput[] = "Out";
namespace paddle {
namespace operators {
class ChannelCreateOp : public framework::OperatorBase {
public:
ChannelCreateOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto &out = *scope.FindVar(Output(kOutput));
// Determine the datatype and capacity of the channel to be created
// from the attributes provided.
auto dtype =
static_cast<framework::proto::VarType::Type>(Attr<int>("data_type"));
auto capacity = Attr<int>("capacity");
// Based on the datatype, create a new channel holder initialized with
// the given capacity. When capacity is 0, an unbuffered channel is
// created.
pf::ChannelHolder *ch = out.GetMutable<framework::ChannelHolder>();
if (dtype == framework::proto::VarType::LOD_TENSOR) {
ch->Reset<pf::LoDTensor>(capacity);
} else if (dtype == framework::proto::VarType::SELECTED_ROWS) {
ch->Reset<pf::SelectedRows>(capacity);
} else if (dtype == framework::proto::VarType::LOD_RANK_TABLE) {
ch->Reset<pf::LoDRankTable>(capacity);
} else if (dtype == framework::proto::VarType::LOD_TENSOR_ARRAY) {
ch->Reset<pf::LoDTensorArray>(capacity);
} else if (dtype == framework::proto::VarType::READER) {
ch->Reset<pf::ReaderHolder>(capacity);
} else if (dtype == framework::proto::VarType::CHANNEL) {
ch->Reset<pf::ChannelHolder>(capacity);
} else if (dtype == framework::proto::VarType::BOOL) {
ch->Reset<bool>(capacity);
} else if (dtype == framework::proto::VarType::INT32) {
ch->Reset<int>(capacity);
} else if (dtype == framework::proto::VarType::INT64) {
ch->Reset<int64_t>(capacity);
} else if (dtype == framework::proto::VarType::FP32) {
ch->Reset<float>(capacity);
} else if (dtype == framework::proto::VarType::FP64) {
ch->Reset<double>(capacity);
} else {
PADDLE_THROW(
"Data type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, "
"READER, CHANNEL, BOOL, INT32, INT64, FP32, FP64]",
dtype);
}
}
};
class ChannelCreateOpOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasOutput(kOutput),
"The output of ChannelCreate op must be set");
context->SetOutputDim(kOutput, {1});
}
};
class ChannelCreateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput(kOutput,
"The object of a Channel type created by ChannelCreate Op.");
AddAttr<int>("capacity", "The size of the buffer of Channel.")
.SetDefault(0);
AddAttr<int>("data_type", "The data type of elements inside the Channel.");
AddComment(R"DOC(
Channel Create Operator.
This operator creates an object of the VarType Channel and returns it.
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(channel_create, paddle::operators::ChannelCreateOp,
paddle::framework::EmptyGradOpMaker,
paddle::operators::ChannelCreateOpMaker);

@ -1,98 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include <paddle/fluid/framework/lod_rank_table.h>
#include <paddle/fluid/framework/lod_tensor_array.h>
#include <paddle/fluid/framework/reader.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include "paddle/fluid/operators/math/math_function.h"
static constexpr char Channel[] = "Channel";
static constexpr char Status[] = "Status";
static constexpr char Out[] = "Out";
namespace paddle {
namespace operators {
void SetReceiveStatus(const platform::Place &dev_place,
framework::Variable *status_var, bool status) {
auto cpu = platform::CPUPlace();
auto status_tensor =
status_var->GetMutable<framework::LoDTensor>()->mutable_data<bool>({1},
cpu);
status_tensor[0] = status;
}
class ChannelRecvOp : public framework::OperatorBase {
public:
ChannelRecvOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput(Channel),
"Input(Channel) of ChannelRecvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(Out),
"Input(Channel) of ChannelRecvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(Status),
"Output(Status) of ChannelRecvOp should not be null.");
ctx->SetOutputDim("Status", {1});
}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
// Get the channel holder created by channel_create op, passed as input.
framework::ChannelHolder *ch =
scope.FindVar(Input(Channel))->GetMutable<framework::ChannelHolder>();
auto output_var = scope.FindVar(Output(Out));
// Receive the data from the channel.
bool ok = concurrency::ChannelReceive(ch, output_var);
// Set the status output of the `ChannelReceive` call.
SetReceiveStatus(dev_place, scope.FindVar(Output(Status)), ok);
}
};
class ChannelRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(Channel,
"(Channel) A variable which \"receives\" the a value sent"
"to it by a channel_send op.")
.AsDuplicable();
AddOutput(Out,
"(Variable) Output Variable that will hold the data received"
" from the Channel")
.AsDuplicable();
AddOutput(Status,
"(Tensor) An LoD Tensor that returns a boolean status of the"
"result of the receive operation.")
.AsDuplicable();
AddComment(R"DOC(
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(channel_recv, paddle::operators::ChannelRecvOp,
paddle::framework::EmptyGradOpMaker,
paddle::operators::ChannelRecvOpMaker);

@ -1,76 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/channel.h"
#include <paddle/fluid/framework/lod_rank_table.h>
#include <paddle/fluid/framework/lod_tensor_array.h>
#include <paddle/fluid/framework/reader.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include "paddle/fluid/operators/math/math_function.h"
static constexpr char Channel[] = "Channel";
static constexpr char X[] = "X";
namespace paddle {
namespace operators {
class ChannelSendOp : public framework::OperatorBase {
public:
ChannelSendOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput(Channel),
"Input(Channel) of ChannelSendOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(X),
"Input(X) of ChannelSendOp should not be null.");
}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
// Get the channel holder created by channel_create op, passed as input.
framework::ChannelHolder *ch =
scope.FindVar(Input(Channel))->GetMutable<framework::ChannelHolder>();
auto input_var = scope.FindVar(Input(X));
// Send the input data through the channel.
concurrency::ChannelSend(ch, input_var);
}
};
class ChannelSendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(Channel,
"(Channel) A variable which \"sends\" the passed in value to "
"a listening receiver.")
.AsDuplicable();
AddInput(X, "(Variable) The value which gets sent by the channel.")
.AsDuplicable();
AddComment(R"DOC(
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(channel_send, paddle::operators::ChannelSendOp,
paddle::framework::EmptyGradOpMaker,
paddle::operators::ChannelSendOpMaker);

@ -1 +0,0 @@
cc_library(concurrency SRCS channel_util.cc DEPS device_context framework_proto boost eigen3)

@ -1,111 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concurrency/channel_util.h"
#include "paddle/fluid/framework/var_type.h"
namespace poc = paddle::operators::concurrency;
void poc::ChannelSend(framework::ChannelHolder *ch, framework::Variable *var) {
auto type = framework::ToVarType(var->Type());
if (type == framework::proto::VarType_Type_LOD_TENSOR)
ch->Send(var->GetMutable<framework::LoDTensor>());
else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
ch->Send(var->GetMutable<framework::LoDRankTable>());
else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
ch->Send(var->GetMutable<framework::LoDTensorArray>());
else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
ch->Send(var->GetMutable<framework::SelectedRows>());
else if (type == framework::proto::VarType_Type_READER)
ch->Send(var->GetMutable<framework::ReaderHolder>());
else if (type == framework::proto::VarType_Type_CHANNEL)
ch->Send(var->GetMutable<framework::ChannelHolder>());
else
PADDLE_THROW("ChannelSend:Unsupported type");
}
bool poc::ChannelReceive(framework::ChannelHolder *ch,
framework::Variable *var) {
// Get type of channel and use that to call mutable data for Variable
auto type = framework::ToVarType(ch->Type());
if (type == framework::proto::VarType_Type_LOD_TENSOR)
return ch->Receive(var->GetMutable<framework::LoDTensor>());
else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE)
return ch->Receive(var->GetMutable<framework::LoDRankTable>());
else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY)
return ch->Receive(var->GetMutable<framework::LoDTensorArray>());
else if (type == framework::proto::VarType_Type_SELECTED_ROWS)
return ch->Receive(var->GetMutable<framework::SelectedRows>());
else if (type == framework::proto::VarType_Type_READER)
return ch->Receive(var->GetMutable<framework::ReaderHolder>());
else if (type == framework::proto::VarType_Type_CHANNEL)
return ch->Receive(var->GetMutable<framework::ChannelHolder>());
else
PADDLE_THROW("ChannelReceive:Unsupported type");
}
void poc::ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer,
framework::Variable *var,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(framework::ChannelAction)> cb) {
auto type = framework::ToVarType(var->Type());
if (type == framework::proto::VarType_Type_LOD_TENSOR) {
ch->AddToSendQ(referrer, var->GetMutable<framework::LoDTensor>(), cond, cb);
} else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) {
ch->AddToSendQ(referrer, var->GetMutable<framework::LoDRankTable>(), cond,
cb);
} else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) {
ch->AddToSendQ(referrer, var->GetMutable<framework::LoDTensorArray>(), cond,
cb);
} else if (type == framework::proto::VarType_Type_SELECTED_ROWS) {
ch->AddToSendQ(referrer, var->GetMutable<framework::SelectedRows>(), cond,
cb);
} else if (type == framework::proto::VarType_Type_READER) {
ch->AddToSendQ(referrer, var->GetMutable<framework::ReaderHolder>(), cond,
cb);
} else if (type == framework::proto::VarType_Type_CHANNEL) {
ch->AddToSendQ(referrer, var->GetMutable<framework::ChannelHolder>(), cond,
cb);
} else {
PADDLE_THROW("ChannelAddToSendQ:Unsupported type");
}
}
void poc::ChannelAddToReceiveQ(
framework::ChannelHolder *ch, const void *referrer,
framework::Variable *var, std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(framework::ChannelAction)> cb) {
auto type = framework::ToVarType(var->Type());
if (type == framework::proto::VarType_Type_LOD_TENSOR) {
ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDTensor>(), cond,
cb);
} else if (type == framework::proto::VarType_Type_LOD_RANK_TABLE) {
ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDRankTable>(),
cond, cb);
} else if (type == framework::proto::VarType_Type_LOD_TENSOR_ARRAY) {
ch->AddToReceiveQ(referrer, var->GetMutable<framework::LoDTensorArray>(),
cond, cb);
} else if (type == framework::proto::VarType_Type_SELECTED_ROWS) {
ch->AddToReceiveQ(referrer, var->GetMutable<framework::SelectedRows>(),
cond, cb);
} else if (type == framework::proto::VarType_Type_READER) {
ch->AddToReceiveQ(referrer, var->GetMutable<framework::ReaderHolder>(),
cond, cb);
} else if (type == framework::proto::VarType_Type_CHANNEL) {
ch->AddToReceiveQ(referrer, var->GetMutable<framework::ChannelHolder>(),
cond, cb);
} else {
PADDLE_THROW("ChannelAddToReceiveQ:Unsupported type");
}
}

@ -1,38 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace operators {
namespace concurrency {
void ChannelSend(framework::ChannelHolder *ch, framework::Variable *var);
bool ChannelReceive(framework::ChannelHolder *ch, framework::Variable *var);
void ChannelAddToSendQ(framework::ChannelHolder *ch, const void *referrer,
framework::Variable *var,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(framework::ChannelAction)> cb);
void ChannelAddToReceiveQ(framework::ChannelHolder *ch, const void *referrer,
framework::Variable *var,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(framework::ChannelAction)> cb);
} // namespace concurrency
} // namespace operators
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -214,7 +214,6 @@ void BindVarDsec(pybind11::module *m) {
.def("set_shapes", &pd::VarDesc::SetShapes)
.def("set_dtype", &pd::VarDesc::SetDataType)
.def("set_dtypes", &pd::VarDesc::SetDataTypes)
.def("set_capacity", &pd::VarDesc::SetCapacity)
.def("shape", &pd::VarDesc::GetShape,
pybind11::return_value_policy::reference)
.def("shapes", &pd::VarDesc::GetShapes,
@ -251,7 +250,6 @@ void BindVarDsec(pybind11::module *m) {
.value("STEP_SCOPES", pd::proto::VarType::STEP_SCOPES)
.value("LOD_RANK_TABLE", pd::proto::VarType::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", pd::proto::VarType::LOD_TENSOR_ARRAY)
.value("CHANNEL", pd::proto::VarType::CHANNEL)
.value("PLACE_LIST", pd::proto::VarType::PLACE_LIST)
.value("READER", pd::proto::VarType::READER)
.value("RAW", pd::proto::VarType::RAW);

@ -21,7 +21,6 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"

File diff suppressed because it is too large Load Diff

@ -541,8 +541,7 @@ class Operator(object):
'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine',
'ncclInit', 'channel_create', 'channel_close', 'channel_send',
'channel_recv', 'select', 'checkpoint_notify', 'gen_nccl_id'
'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id'
}
def __init__(self,

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save