Merge pull request #4537 from QiJune/executor_impl
Executor interface design and implementation
commit c3bf332666
@@ -0,0 +1,165 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/executor.h"

#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <vector>

#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"

#include <boost/range/adaptor/reversed.hpp>

namespace paddle {
namespace framework {

const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";

Executor::Executor(const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  device_contexts_.resize(places.size());
  for (size_t i = 0; i < places.size(); i++) {
    if (platform::is_cpu_place(places[i])) {
      device_contexts_[i] = new platform::CPUDeviceContext(
          boost::get<platform::CPUPlace>(places[i]));
    } else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_[i] = new platform::CUDADeviceContext(
          boost::get<platform::GPUPlace>(places[i]));
#else
      PADDLE_THROW(
          "'GPUPlace' is not supported, please re-compile with WITH_GPU "
          "option");
#endif
    }
  }
}

Executor::~Executor() {
  for (auto& device_context : device_contexts_) {
    delete device_context;
  }
}

void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
  // TODO(tonyyang-svail):
  //   - only runs on the first device (i.e. no interdevice communication)
  //   - will change to use multiple blocks for RNN op and Cond Op
  PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id);
  auto& block = pdesc.blocks(block_id);
  auto& device = device_contexts_[0];

  // Instantiate all the vars in the global scope
  for (auto& var : block.vars()) {
    scope->NewVar(var.name());
  }

  Scope& local_scope = scope->NewScope();

  std::vector<bool> should_run = Prune(pdesc, block_id);
  PADDLE_ENFORCE_EQ(should_run.size(), static_cast<size_t>(block.ops_size()));
  for (size_t i = 0; i < should_run.size(); ++i) {
    if (should_run[i]) {
      // Create the op's output variables in the local scope if they do not
      // exist yet.
      for (auto& var : block.ops(i).outputs()) {
        for (auto& argu : var.arguments()) {
          if (local_scope.FindVar(argu) == nullptr) {
            local_scope.NewVar(argu);
          }
        }
      }
      auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i));
      op->Run(local_scope, *device);
    }
  }

  // TODO(tonyyang-svail):
  //   - Destroy local_scope
}

std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id) {
  // TODO(tonyyang-svail):
  //   - will change to use multiple blocks for RNN op and Cond Op

  auto& block = pdesc.blocks(block_id);
  auto& ops = block.ops();

  bool expect_feed = true;
  for (auto& op_desc : ops) {
    PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
                   "All FeedOps are at the beginning of the ProgramDesc");
    expect_feed = (op_desc.type() == kFeedOpType);
  }

  bool expect_fetch = true;
  for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
    auto& op_desc = *op_iter;
    PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
                   "All FetchOps must be at the end of the ProgramDesc");
    expect_fetch = (op_desc.type() == kFetchOpType);
  }

  // Walk the ops in reverse order; an op is kept if it is a fetch op or if
  // one of its outputs is needed by an op that has already been kept.
  std::set<std::string> dependent_vars;
  std::vector<bool> should_run;
  for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
    auto& op_desc = *op_iter;

    bool found_dependent_vars = false;
    for (auto& var : op_desc.outputs()) {
      for (auto& argu : var.arguments()) {
        if (dependent_vars.count(argu) != 0) {
          found_dependent_vars = true;
        }
      }
    }

    if (op_desc.type() == kFetchOpType || found_dependent_vars) {
      // erase its outputs from the dependency graph
      for (auto& var : op_desc.outputs()) {
        for (auto& argu : var.arguments()) {
          dependent_vars.erase(argu);
        }
      }

      // insert its inputs into the dependency graph
      for (auto& var : op_desc.inputs()) {
        for (auto& argu : var.arguments()) {
          dependent_vars.insert(argu);
        }
      }

      should_run.push_back(true);
    } else {
      should_run.push_back(false);
    }
  }

  // TODO(tonyyang-svail):
  //   - check this after integration of Init
  // PADDLE_ENFORCE(dependent_vars.empty());

  // since we are traversing the ProgramDesc in reverse order
  // we reverse the should_run vector
  std::reverse(should_run.begin(), should_run.end());

  return should_run;
}

}  // namespace framework
}  // namespace paddle
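The backward pruning pass is easiest to see on a toy example. The following standalone sketch reproduces the same dependency-marking idea with plain structs instead of the real ProgramDesc protobuf; the Op struct, the example program, and main() are made up for illustration and are not part of this PR.

// Simplified illustration of the reverse-order pruning done by Prune().
// Everything here is hypothetical scaffolding; the real code walks OpDesc
// protobuf messages inside a BlockDesc.
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Op {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

std::vector<bool> PruneToy(const std::vector<Op>& ops) {
  const std::string kFetchOpType = "fetch";
  std::set<std::string> dependent_vars;
  std::vector<bool> should_run;
  // Traverse in reverse: keep an op if it is a fetch op or if one of its
  // outputs is an input of an op that has already been kept.
  for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
    bool needed = (it->type == kFetchOpType);
    for (const auto& out : it->outputs) {
      if (dependent_vars.count(out) != 0) needed = true;
    }
    if (needed) {
      for (const auto& out : it->outputs) dependent_vars.erase(out);
      for (const auto& in : it->inputs) dependent_vars.insert(in);
    }
    should_run.push_back(needed);
  }
  std::reverse(should_run.begin(), should_run.end());
  return should_run;
}

int main() {
  // c = a + b is fetched; d = a * a is dead code and gets pruned.
  std::vector<Op> ops = {
      {"feed", {}, {"a"}},
      {"feed", {}, {"b"}},
      {"add", {"a", "b"}, {"c"}},
      {"mul", {"a", "a"}, {"d"}},  // never used downstream
      {"fetch", {"c"}, {}},
  };
  for (bool run : PruneToy(ops)) std::cout << run << " ";  // prints: 1 1 1 0 1
  std::cout << "\n";
  return 0;
}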
@@ -0,0 +1,55 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace framework {

class Executor {
 public:
  explicit Executor(const std::vector<platform::Place>& places);
  ~Executor();

  /* @Brief
   * Runtime evaluation of the given ProgramDesc under a certain Scope
   *
   * @param
   *  ProgramDesc  the program to evaluate
   *  Scope        the scope that holds the variables
   *  block_id     index of the block to run
   */
  void Run(const ProgramDesc&, Scope*, int);

 private:
  std::vector<platform::DeviceContext*> device_contexts_;
};

/* @Brief
 * Prune the graph: mark which ops are needed to produce the fetch targets
 *
 * @param
 *  ProgramDesc
 *  block_id
 *
 * @return
 *  vector<bool> Same size as ops. Indicates whether an op should be run.
 */
std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id);

}  // namespace framework
}  // namespace paddle
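For reference, a minimal caller of this interface might look like the sketch below. It assumes a ProgramDesc has already been built elsewhere (e.g. by the Python side) and that the feed/fetch data lives in the global scope's feed_value/fetch_value variables as used by the feed and fetch kernels in this PR; the function RunProgramOnCPU is hypothetical and not part of the change.

// Hypothetical caller of the new Executor interface (illustrative only).
#include <vector>

#include "paddle/framework/executor.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/place.h"

void RunProgramOnCPU(const paddle::framework::ProgramDesc& program) {
  namespace framework = paddle::framework;
  namespace platform = paddle::platform;

  // One place per device; here a single CPU device.
  std::vector<platform::Place> places;
  places.emplace_back(platform::CPUPlace());

  framework::Executor executor(places);
  framework::Scope scope;

  // Run block 0 of the program; Prune() decides which ops actually execute.
  executor.Run(program, &scope, 0);
}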
File diff suppressed because it is too large.
@@ -0,0 +1,59 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/feed_op.h"

namespace paddle {
namespace operators {

class FeedOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should not be null.");
    auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
    std::vector<int64_t> shape_int64(shape.size(), 0);
    std::transform(shape.begin(), shape.end(), shape_int64.begin(),
                   [](int a) { return static_cast<int64_t>(a); });
    ctx->SetOutputDim("Out", framework::make_ddim(shape_int64));
    // TODO(qijun): need to handle LoDTensor later
  }

  framework::DataType IndicateDataType(
      const framework::ExecutionContext& ctx) const override {
    return static_cast<framework::DataType>(Attr<int>("dataType"));
  }
};

class FeedOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<int>("dataType", "output data type")
        .SetDefault(framework::DataType::FP32);
    AddAttr<int>("col", "The column in the global feed variable").SetDefault(0);
    AddAttr<std::vector<int>>("dims", "The dimension of the feed tensor.");
    AddOutput("Out", "The output of the feed op.");
    AddComment(R"DOC(Feed data from the global feed variable)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker);
REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel<float>);
@@ -0,0 +1,18 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/feed_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel<float>);
@@ -0,0 +1,42 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class FeedKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    framework::Variable* g_feed_variable =
        framework::GetGlobalScope().FindVar("feed_value");
    const auto& tensors =
        g_feed_variable->Get<std::vector<framework::Tensor>>();
    int col = ctx.template Attr<int>("col");
    PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
    // TODO(qijun):
    //   check tensors[col].dims() against the "dims" attribute, except the
    //   first dimension.
    out->CopyFrom<T>(tensors[col], ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle
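To make the data flow concrete, a caller could stage input data for the feed kernel roughly as sketched below: the global scope's "feed_value" variable is filled with one tensor per column before Executor::Run is invoked. The SetFeedTensor helper is hypothetical, and the sketch assumes GetGlobalScope() is declared in scope.h and behaves as FeedKernel uses it above.

// Hypothetical helper that stages one float tensor into the global
// "feed_value" vector read by FeedKernel (illustrative, not part of this PR).
#include <vector>

#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"

void SetFeedTensor(const paddle::framework::Tensor& src, int col) {
  using paddle::framework::GetGlobalScope;
  using paddle::framework::Tensor;
  auto* var = GetGlobalScope().NewVar("feed_value");
  auto* feeds = var->GetMutable<std::vector<Tensor>>();
  // Grow the vector so that column `col` exists.
  if (feeds->size() < static_cast<size_t>(col + 1)) {
    feeds->resize(col + 1);
  }
  // FeedKernel later copies (*feeds)[col] into the op's "Out" tensor.
  (*feeds)[col].Resize(src.dims());
  (*feeds)[col].mutable_data<float>(paddle::platform::CPUPlace());
  (*feeds)[col].CopyFrom<float>(src, paddle::platform::CPUPlace());
}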
@@ -0,0 +1,52 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/fetch_op.h"

namespace paddle {
namespace operators {

class FetchOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null.");
  }

  framework::DataType IndicateDataType(
      const framework::ExecutionContext& ctx) const override {
    return static_cast<framework::DataType>(Attr<int>("dataType"));
  }
};

class FetchOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<int>("dataType", "output data type")
        .SetDefault(framework::DataType::FP32);
    AddAttr<int>("col", "The column in the global fetch variable").SetDefault(0);
    AddInput("Input", "The input of the fetch op.");
    AddComment(R"DOC(Fetch data to the global fetch variable)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker);
REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel<float>);
@@ -0,0 +1,18 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/fetch_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel<float>);
@@ -0,0 +1,44 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class FetchKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
    framework::Variable* g_fetch_variable =
        framework::GetGlobalScope().FindVar("fetch_value");
    auto* tensors =
        g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
    int col = ctx.template Attr<int>("col");
    // Grow the global fetch vector if this column has not been written yet.
    if (tensors->size() < static_cast<size_t>(col + 1)) {
      tensors->resize(col + 1);
    }
    PADDLE_ENFORCE_GT(tensors->size(), static_cast<size_t>(col));
    (*tensors)[col].Resize(input->dims());
    (*tensors)[col].mutable_data<T>(platform::CPUPlace());
    (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
    // TODO(qijun): need to handle LoDTensor later
  }
};

}  // namespace operators
}  // namespace paddle
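Symmetrically, once Executor::Run returns, a caller could read results out of the global "fetch_value" variable that FetchKernel populates. The GetFetchTensor helper below is hypothetical and only illustrates that access pattern under the same GetGlobalScope() assumption as above.

// Hypothetical helper that reads one result tensor from the global
// "fetch_value" vector written by FetchKernel (illustrative only).
#include <vector>

#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/enforce.h"

const paddle::framework::Tensor& GetFetchTensor(int col) {
  using paddle::framework::GetGlobalScope;
  using paddle::framework::Tensor;
  auto* var = GetGlobalScope().FindVar("fetch_value");
  const auto& fetches = var->Get<std::vector<Tensor>>();
  PADDLE_ENFORCE_GT(fetches.size(), static_cast<size_t>(col));
  // FetchKernel always copies results to CPU memory, so the returned tensor
  // can be read directly on the host.
  return fetches[col];
}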