Merge pull request #4537 from QiJune/executor_impl

Executor interface design and implementation
revert-4814-Add_sequence_project_op
Yang Yang(Tony) 8 years ago committed by GitHub
commit c3bf332666

@ -42,5 +42,12 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward ${GLOB_OP_LIB})
if(WITH_GPU)
nv_test(executor_test SRCS executor_test.cc DEPS executor)
else()
cc_test(executor_test SRCS executor_test.cc DEPS executor)
endif()
cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor)
cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place)

@ -0,0 +1,165 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/executor.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <vector>
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"
#include <boost/range/adaptor/reversed.hpp>
namespace paddle {
namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
Executor::Executor(const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
device_contexts_.resize(places.size());
for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) {
device_contexts_[i] = new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i]));
} else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_CUDA
device_contexts_[i] = new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i]));
#else
PADDLE_THROW(
"'GPUPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
}
}
Executor::~Executor() {
for (auto& device_context : device_contexts_) {
delete device_context;
}
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op
PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id);
auto& block = pdesc.blocks(block_id);
auto& device = device_contexts_[0];
// Instantiate all the vars in the global scope
for (auto& var : block.vars()) {
scope->NewVar(var.name());
}
Scope& local_scope = scope->NewScope();
std::vector<bool> should_run = Prune(pdesc, block_id);
PADDLE_ENFORCE_EQ(should_run.size(), static_cast<size_t>(block.ops_size()));
for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) {
for (auto& var : block.ops(i).outputs()) {
for (auto& argu : var.arguments()) {
if (local_scope.FindVar(argu) == nullptr) {
local_scope.NewVar(argu);
}
}
}
auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i));
op->Run(local_scope, *device);
}
}
// TODO(tonyyang-svail):
// - Destroy local_scope
}
std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id) {
// TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op
auto& block = pdesc.blocks(block_id);
auto& ops = block.ops();
bool expect_feed = true;
for (auto& op_desc : ops) {
PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
"All FeedOps are at the beginning of the ProgramDesc");
expect_feed = (op_desc.type() == kFeedOpType);
}
bool expect_fetch = true;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
"All FetchOps must at the end of the ProgramDesc");
expect_fetch = (op_desc.type() == kFetchOpType);
}
std::set<std::string> dependent_vars;
std::vector<bool> should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
bool found_dependent_vars = false;
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if (dependent_vars.count(argu) != 0) {
found_dependent_vars = true;
}
}
}
if (op_desc.type() == kFetchOpType || found_dependent_vars) {
// erase its output to the dependency graph
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
dependent_vars.erase(argu);
}
}
// insert its input to the dependency graph
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
dependent_vars.insert(argu);
}
}
should_run.push_back(true);
} else {
should_run.push_back(false);
}
}
// TODO(tonyyang-svail):
// - check this after integration of Init
// PADDLE_ENFORCE(dependent_vars.empty());
// since we are traversing the ProgramDesc in reverse order
// we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end());
return should_run;
}
} // namespace framework
} // namespace paddle

@ -0,0 +1,55 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace framework {
class Executor {
public:
explicit Executor(const std::vector<platform::Place>& places);
~Executor();
/* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope
*
* @param
* ProgramDesc
* Scope
*/
void Run(const ProgramDesc&, Scope*, int);
private:
std::vector<platform::DeviceContext*> device_contexts_;
};
/* @Brief
* Pruning the graph
*
* @param
* ProgramDesc
*
* @return
* vector<bool> Same size as ops. Indicates whether an op should be run.
*/
std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id);
} // namespace framework
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/scope.h"
#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include "paddle/string/printf.h"
namespace paddle {
@ -62,5 +65,17 @@ void Scope::DropKids() {
kids_.clear();
}
std::once_flag feed_variable_flag;
framework::Scope& GetGlobalScope() {
static std::unique_ptr<framework::Scope> g_scope{nullptr};
std::call_once(feed_variable_flag, [&]() {
g_scope.reset(new framework::Scope());
g_scope->NewVar("feed_value");
g_scope->NewVar("fetch_value");
});
return *(g_scope.get());
}
} // namespace framework
} // namespace paddle

@ -73,5 +73,7 @@ class Scope {
DISABLE_COPY_AND_ASSIGN(Scope);
};
framework::Scope& GetGlobalScope();
} // namespace framework
} // namespace paddle

@ -0,0 +1,59 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/feed_op.h"
namespace paddle {
namespace operators {
class FeedOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
ctx->SetOutputDim("Out", framework::make_ddim(shape_int64));
// TODO(qijun): need to handle LodTensor later
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("dataType"));
}
};
class FeedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("dataType", "output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<int>("col", "The col in global feed variable").SetDefault(0);
AddAttr<std::vector<int>>("dims", "The dimension of feed tensor.");
AddOutput("Out", "The output of feed op.");
AddComment(R"DOC(Feed data from global feed variable)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker);
REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel<float>);

@ -0,0 +1,18 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/feed_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel<float>);

@ -0,0 +1,42 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class FeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
framework::Variable* g_feed_variable =
framework::GetGlobalScope().FindVar("feed_value");
const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col");
PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
// TODO(qijun):
// check tensors[col].dims() with attribute,
// except the first dimenson.
out->CopyFrom<T>(tensors[col], ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle

@ -0,0 +1,52 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fetch_op.h"
namespace paddle {
namespace operators {
class FetchOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("dataType"));
}
};
class FetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("dataType", "output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<int>("col", "The col in global fetch variable").SetDefault(0);
AddInput("Input", "The output of fetch op.");
AddComment(R"DOC(Fetch data to global fetch variable)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker);
REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel<float>);

@ -0,0 +1,18 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fetch_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel<float>);

@ -0,0 +1,44 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class FetchKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
framework::Variable* g_fetch_variable =
framework::GetGlobalScope().FindVar("fetch_value");
auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col");
if (tensors->size() < static_cast<size_t>(col + 1)) {
tensors->resize(col + 1);
}
PADDLE_ENFORCE_GT(tensors->size(), static_cast<size_t>(col));
(*tensors)[col].Resize(input->dims());
(*tensors)[col].mutable_data<T>(platform::CPUPlace());
(*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
// TODO(qijun): need to handle LodTensor later
}
};
} // namespace operators
} // namespace paddle

@ -43,6 +43,8 @@ int GetCurrentDeviceId() {
}
void SetDeviceId(int id) {
// TODO(qijun): find a better way to cache the cuda device count
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
PADDLE_ENFORCE(cudaSetDevice(id),
"cudaSetDevice failed in paddle::platform::SetDeviceId");
}

Loading…
Cancel
Save