revert-4814-Add_sequence_project_op
commit c22e7ff71e
@@ -0,0 +1,105 @@
## Optimizer Design

### The Problem

A PaddlePaddle program, or a block, is a sequence of operators operating on variables. A training program needs to do three kinds of work:

1. the forward pass, which computes intermediate results and the cost(s),
1. the backward pass, which derives gradients from intermediate results and costs, and
1. the optimization pass, which updates model parameters to optimize the cost(s).

These passes rely on three kinds of operators:

1. forward operators,
1. gradient operators, and
1. optimization operators.

It's true that users should be able to create all these operators manually by calling some low-level API, but it would be much more convenient if they only had to describe the forward pass and let PaddlePaddle create the backward and optimization operators automatically.

In this design, we propose a high-level API that automatically derives the optimization pass and operators from the forward pass.

### High-level Python API to describe the training process

1. Users write code to describe the network:

```python
images = layer.data("images")
labels = layer.data("labels")
w1 = pd.var("w1")
b1 = pd.var("b1")
hidden = layer.fc(images, w=w1, b=b1)
cost = layer.mse(hidden, labels)
```

The above code snippet will create forward operators in [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md).

2. Users create a certain kind of Optimizer with some arguments.

```python
optimizer = AdagradOptimizer(learning_rate=0.001)
```

3. Users use the optimizer to `minimize` a certain `cost` through updating parameters in parameter_list.

```python
opt_op_list = optimizer.minimize(cost, parameter_list=[w1, b1])
```

The above code snippet will create gradient and optimization operators in Block. The return value of `minimize()` is a list of optimization operators that will be run by the session.

4. Users use the Session/Executor to run this opt_op_list as the target to do training.

```python
sess.run(target=opt_op_list, ...)
```

#### Optimizer Python interface:

```python
class Optimizer(object):
    """Optimizer Base class.

    """

    def __init__(self):
        pass

    def create_backward_pass(self, loss, parameter_list=None):
        """Create and add gradient Operators in BlockDesc to compute
        gradients of `loss` for parameters in parameter_list.

        Args:
          loss: a variable generated by the cost function.
          parameter_list: parameters that need to compute gradients and be
            updated to optimize the loss.

        Returns:
          list of (parameters, gradients) pairs.
        """
        return None

    def create_optimization_pass(self, parameters_and_grads):
        """Add optimization operators to update gradients to variables.

        Args:
          parameters_and_grads: a list of (variable, gradient) pairs to
            update.

        Returns:
          optimization_op_list: a list of optimization operators that will
            update parameters using gradients.
        """
        return None

    def minimize(self, loss, parameter_list):
        """Add operations to minimize `loss` by updating `parameter_list`.

        This method combines interface `create_backward_pass()` and
        `create_optimization_pass()` into one.
        """
        params_grads = self.create_backward_pass(loss, parameter_list)
        update_ops = self.create_optimization_pass(params_grads)
        return update_ops
```

Users can inherit the Optimizer above to create their own Optimizer with some special logic, such as AdagradOptimizer.
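For illustration, here is a minimal sketch of such a subclass. It assumes a hypothetical helper `add_adagrad_op()` (not part of the interface above) that appends the per-parameter Adagrad update operator to the current Block; only `create_optimization_pass()` is overridden, while `create_backward_pass()` and `minimize()` come from the base class.

```python
class AdagradOptimizer(Optimizer):
    """A sketch of an Adagrad subclass; add_adagrad_op() is hypothetical."""

    def __init__(self, learning_rate=0.001, epsilon=1.0e-6):
        super(AdagradOptimizer, self).__init__()
        self.learning_rate = learning_rate
        self.epsilon = epsilon

    def create_optimization_pass(self, parameters_and_grads):
        opt_op_list = []
        for param, grad in parameters_and_grads:
            # One update operator (plus its moment accumulator) per parameter.
            opt_op_list.append(
                add_adagrad_op(param, grad,
                               learning_rate=self.learning_rate,
                               epsilon=self.epsilon))
        return opt_op_list
```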
@@ -0,0 +1,74 @@
# Design Doc: Selected Rows

`SelectedRows` is a kind of sparse tensor data type, which is designed to support `embedding` operators. The gradient of an embedding table is a sparse tensor: only a few rows of that tensor hold non-zero values. It is straightforward to represent such a sparse tensor with the following data structure:

```cpp
class SelectedRows {
 private:
  vector<int> rows_;
  Tensor value_;
  int height_;
};
```

The field `height_` is the first dimension of the `SelectedRows`. The `rows_` field holds the indices of the rows that are non-zero. The `value_` field is an N-dim tensor with shape `[rows_.size() /* NUM_ROWS */, ...]`, which supplies the values for each non-zero row. The shape of the `SelectedRows` as a whole is `[height_] + value_.shape[1:]`.

Suppose that a SelectedRows-typed variable `x` has many rows, but only two of them have values -- row 73 is `[1, 2]` and row 84 is `[3, 4]`. The `SelectedRows` representation would be:

```
x = SelectedRows {
  rows = [73, 84],
  value = [[1, 2], [3, 4]]
}
```

## SelectedRows in Protobuf

`SelectedRows` is a kind of `Variable`, so `VarDesc` in protobuf should describe the `SelectedRows` information. Only the tensor dimensions of a `SelectedRows` are described at compile time, since `rows_` and `value_` depend on the training data.
So we use `TensorDesc` to unify `data_type` and `dims`. A `LodTensorDesc` contains a `TensorDesc` and a `lod_level`; the description of a `SelectedRows` is just a `TensorDesc`.

```proto
message TensorDesc {
  required DataType data_type = 1;
  repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}

message LodTensorDesc {
  required TensorDesc tensor = 1;
  optional int32 lod_level = 2;
}

message VarDesc {
  required string name = 1;
  enum VarType {
    LOD_TENSOR = 0;
    SELECTED_ROWS = 1;
  }
  required VarType type = 2;
  optional LodTensorDesc lod_desc = 3;
  optional TensorDesc selected_rows_desc = 4;
  optional bool persistable = 5 [ default = false ];
}
```

## InferShape for Selected Rows

Just like `LoD` information, the `InferShape` method infers the output tensor type as well as its shape. The operator should decide whether its output is a `SelectedRows` or a dense tensor.

For example, the gradient operator of `TableLookup` will always generate a `SelectedRows`. Its `InferShape` method should look like the following:

```cpp
void TableLookupGrad::InferShape(context) {
  ...
  context.SetDataType("Embedding.Grad", kSelectedRows);
}
```

## Sparse Operators

Several operators need to be written to support `SelectedRows`:

1. Operators which generate `SelectedRows` gradients, e.g., the gradient of `TableLookupOp`.
2. Optimization operators which support `SelectedRows` gradients, e.g., `SGD` or `AdaGrad` for `SelectedRows`. However, there should be only one `SGD` operator; `OpWithKernel::Run` should select a suitable kernel for both dense tensors and `SelectedRows`.
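To make the kernel-selection point concrete, the following NumPy sketch (an illustration only, not the actual Paddle kernel; the function name `sgd_selected_rows_update` is made up) shows what the `SelectedRows` branch of an `SGD` kernel computes: only the rows listed in `rows_` are touched, so the cost scales with the number of non-zero rows rather than with `height_`.

```python
import numpy as np

def sgd_selected_rows_update(param, grad_rows, grad_value, learning_rate):
    """Apply a SelectedRows gradient to a dense parameter, row by row."""
    for i, row in enumerate(grad_rows):
        param[row] -= learning_rate * grad_value[i]

# The `x` example above: rows 73 and 84 carry values in a table of height 100.
param = np.zeros((100, 2), dtype=np.float32)
grad_rows = [73, 84]                                        # rows_
grad_value = np.array([[1, 2], [3, 4]], dtype=np.float32)   # value_
sgd_selected_rows_update(param, grad_rows, grad_value, learning_rate=0.1)
```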
@@ -0,0 +1,165 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/executor.h"

#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <vector>

#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"

#include <boost/range/adaptor/reversed.hpp>

namespace paddle {
namespace framework {

const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";

Executor::Executor(const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  device_contexts_.resize(places.size());
  for (size_t i = 0; i < places.size(); i++) {
    if (platform::is_cpu_place(places[i])) {
      device_contexts_[i] = new platform::CPUDeviceContext(
          boost::get<platform::CPUPlace>(places[i]));
    } else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_[i] = new platform::CUDADeviceContext(
          boost::get<platform::GPUPlace>(places[i]));
#else
      PADDLE_THROW(
          "'GPUPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    }
  }
}

Executor::~Executor() {
  for (auto& device_context : device_contexts_) {
    delete device_context;
  }
}

void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
  // TODO(tonyyang-svail):
  //   - only runs on the first device (i.e. no interdevice communication)
  //   - will change to use multiple blocks for RNN op and Cond Op
  PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id);
  auto& block = pdesc.blocks(block_id);
  auto& device = device_contexts_[0];

  // Instantiate all the vars in the global scope
  for (auto& var : block.vars()) {
    scope->NewVar(var.name());
  }

  Scope& local_scope = scope->NewScope();

  std::vector<bool> should_run = Prune(pdesc, block_id);
  PADDLE_ENFORCE_EQ(should_run.size(), static_cast<size_t>(block.ops_size()));
  for (size_t i = 0; i < should_run.size(); ++i) {
    if (should_run[i]) {
      // Create output variables in the local scope if they do not exist yet.
      for (auto& var : block.ops(i).outputs()) {
        for (auto& argu : var.arguments()) {
          if (local_scope.FindVar(argu) == nullptr) {
            local_scope.NewVar(argu);
          }
        }
      }
      auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i));
      op->Run(local_scope, *device);
    }
  }

  // TODO(tonyyang-svail):
  //   - Destroy local_scope
}

std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id) {
  // TODO(tonyyang-svail):
  //   - will change to use multiple blocks for RNN op and Cond Op

  auto& block = pdesc.blocks(block_id);
  auto& ops = block.ops();

  bool expect_feed = true;
  for (auto& op_desc : ops) {
    PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
                   "All FeedOps are at the beginning of the ProgramDesc");
    expect_feed = (op_desc.type() == kFeedOpType);
  }

  bool expect_fetch = true;
  for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
    auto& op_desc = *op_iter;
    PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
                   "All FetchOps must be at the end of the ProgramDesc");
    expect_fetch = (op_desc.type() == kFetchOpType);
  }

  // Walk the ops backwards: an op should run if it is a fetch op or if any of
  // its outputs is needed by an op that runs after it.
  std::set<std::string> dependent_vars;
  std::vector<bool> should_run;
  for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
    auto& op_desc = *op_iter;

    bool found_dependent_vars = false;
    for (auto& var : op_desc.outputs()) {
      for (auto& argu : var.arguments()) {
        if (dependent_vars.count(argu) != 0) {
          found_dependent_vars = true;
        }
      }
    }

    if (op_desc.type() == kFetchOpType || found_dependent_vars) {
      // erase its outputs from the dependency graph
      for (auto& var : op_desc.outputs()) {
        for (auto& argu : var.arguments()) {
          dependent_vars.erase(argu);
        }
      }

      // insert its inputs into the dependency graph
      for (auto& var : op_desc.inputs()) {
        for (auto& argu : var.arguments()) {
          dependent_vars.insert(argu);
        }
      }

      should_run.push_back(true);
    } else {
      should_run.push_back(false);
    }
  }

  // TODO(tonyyang-svail):
  //   - check this after integration of Init
  // PADDLE_ENFORCE(dependent_vars.empty());

  // since we are traversing the ProgramDesc in reverse order
  // we reverse the should_run vector
  std::reverse(should_run.begin(), should_run.end());

  return should_run;
}

}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,55 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace framework {

class Executor {
 public:
  explicit Executor(const std::vector<platform::Place>& places);
  ~Executor();

  /* @Brief
   * Runtime evaluation of the given ProgramDesc under a certain Scope
   *
   * @param
   *  ProgramDesc
   *  Scope
   */
  void Run(const ProgramDesc&, Scope*, int);

 private:
  std::vector<platform::DeviceContext*> device_contexts_;
};

/* @Brief
 * Prune the graph
 *
 * @param
 *  ProgramDesc
 *
 * @return
 *  vector<bool> Same size as ops. Indicates whether an op should be run.
 */
std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id);

}  // namespace framework
}  // namespace paddle
File diff suppressed because it is too large
@@ -0,0 +1,59 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#include "paddle/operators/feed_op.h"

namespace paddle {
namespace operators {

class FeedOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should not be null.");
    auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
    std::vector<int64_t> shape_int64(shape.size(), 0);
    std::transform(shape.begin(), shape.end(), shape_int64.begin(),
                   [](int a) { return static_cast<int64_t>(a); });
    ctx->SetOutputDim("Out", framework::make_ddim(shape_int64));
    // TODO(qijun): need to handle LodTensor later
  }

  framework::DataType IndicateDataType(
      const framework::ExecutionContext& ctx) const override {
    return static_cast<framework::DataType>(Attr<int>("dataType"));
  }
};

class FeedOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<int>("dataType", "output data type")
        .SetDefault(framework::DataType::FP32);
    AddAttr<int>("col", "The col in global feed variable").SetDefault(0);
    AddAttr<std::vector<int>>("dims", "The dimension of feed tensor.");
    AddOutput("Out", "The output of feed op.");
    AddComment(R"DOC(Feed data from global feed variable)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker);
REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel<float>);
@@ -0,0 +1,18 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/feed_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel<float>);
@@ -0,0 +1,42 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class FeedKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    framework::Variable* g_feed_variable =
        framework::GetGlobalScope().FindVar("feed_value");
    const auto& tensors =
        g_feed_variable->Get<std::vector<framework::Tensor>>();
    int col = ctx.template Attr<int>("col");
    PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
    // TODO(qijun):
    //   check tensors[col].dims() with attribute,
    //   except the first dimension.
    out->CopyFrom<T>(tensors[col], ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,52 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/fetch_op.h"

namespace paddle {
namespace operators {

class FetchOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null.");
  }

  framework::DataType IndicateDataType(
      const framework::ExecutionContext& ctx) const override {
    return static_cast<framework::DataType>(Attr<int>("dataType"));
  }
};

class FetchOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<int>("dataType", "output data type")
        .SetDefault(framework::DataType::FP32);
    AddAttr<int>("col", "The col in global fetch variable").SetDefault(0);
    AddInput("Input", "The input of fetch op.");
    AddComment(R"DOC(Fetch data to global fetch variable)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker);
REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel<float>);
@@ -0,0 +1,18 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/fetch_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel<float>);
@@ -0,0 +1,44 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class FetchKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
    framework::Variable* g_fetch_variable =
        framework::GetGlobalScope().FindVar("fetch_value");
    auto* tensors =
        g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
    int col = ctx.template Attr<int>("col");
    if (tensors->size() < static_cast<size_t>(col + 1)) {
      tensors->resize(col + 1);
    }
    PADDLE_ENFORCE_GT(tensors->size(), static_cast<size_t>(col));
    (*tensors)[col].Resize(input->dims());
    (*tensors)[col].mutable_data<T>(platform::CPUPlace());
    (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
    // TODO(qijun): need to handle LodTensor later
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,155 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/vol2col.h"

namespace paddle {
namespace operators {
namespace math {

/*
 * vol = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *    output_depth, output_height, output_width]
 */
template <class T>
class Vol2ColFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& vol, framework::Tensor& col,
                  int stride_depth, int stride_height, int stride_width,
                  int padding_depth, int padding_height,
                  int padding_width) const {
    PADDLE_ENFORCE(vol.dims().size() == 4);
    PADDLE_ENFORCE(col.dims().size() == 7);

    int input_channels = vol.dims()[0];
    int input_depth = vol.dims()[1];
    int input_height = vol.dims()[2];
    int input_width = vol.dims()[3];
    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];
    int channels_col =
        input_channels * filter_depth * filter_height * filter_width;

    const T* vol_data = vol.data<T>();
    T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int d_offset = (c / filter_width / filter_height) % filter_depth;
      int c_in = c / filter_width / filter_height / filter_depth;
      for (int d = 0; d < output_depth; ++d) {
        int d_pad = d * stride_depth - padding_depth + d_offset;
        for (int h = 0; h < output_height; ++h) {
          int h_pad = h * stride_height - padding_height + h_offset;
          for (int w = 0; w < output_width; ++w) {
            int w_pad = w * stride_width - padding_width + w_offset;

            int col_idx =
                ((c * output_depth + d) * output_height + h) * output_width + w;
            if (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
                w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) {
              col_data[col_idx] = static_cast<T>(0);
            } else {
              int vol_idx =
                  ((c_in * input_depth + d_pad) * input_height + h_pad) *
                      input_width +
                  w_pad;
              col_data[col_idx] = vol_data[vol_idx];
            }
          }
        }
      }
    }
  }
};

/*
 * vol = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *    output_depth, output_height, output_width]
 */
template <class T>
class Col2VolFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  framework::Tensor& vol, const framework::Tensor& col,
                  int stride_depth, int stride_height, int stride_width,
                  int padding_depth, int padding_height,
                  int padding_width) const {
    PADDLE_ENFORCE(vol.dims().size() == 4);
    PADDLE_ENFORCE(col.dims().size() == 7);

    int input_channels = vol.dims()[0];
    int input_depth = vol.dims()[1];
    int input_height = vol.dims()[2];
    int input_width = vol.dims()[3];
    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];
    int channels_col =
        input_channels * filter_depth * filter_height * filter_width;

    T* vol_data = vol.data<T>();
    const T* col_data = col.data<T>();

    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
      int d_offset = (c / filter_width / filter_height) % filter_depth;
      int c_in = c / filter_width / filter_height / filter_depth;
      for (int d = 0; d < output_depth; ++d) {
        int d_pad = d * stride_depth - padding_depth + d_offset;
        for (int h = 0; h < output_height; ++h) {
          int h_pad = h * stride_height - padding_height + h_offset;
          for (int w = 0; w < output_width; ++w) {
            int w_pad = w * stride_width - padding_width + w_offset;

            if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
                w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
              int vol_idx =
                  ((c_in * input_depth + d_pad) * input_height + h_pad) *
                      input_width +
                  w_pad;
              int col_idx =
                  ((c * output_depth + d) * output_height + h) * output_width +
                  w;
              vol_data[vol_idx] += col_data[col_idx];
            }
          }
        }
      }
    }
  }
};

template class Vol2ColFunctor<platform::CPUPlace, float>;
template class Vol2ColFunctor<platform::CPUPlace, double>;
template class Col2VolFunctor<platform::CPUPlace, float>;
template class Col2VolFunctor<platform::CPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,204 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/vol2col.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <class T>
__global__ void vol2col(int num_kernels, const T* data_vol, int depth,
                        int height, int width, int filter_depth,
                        int filter_height, int filter_width, int stride_depth,
                        int stride_height, int stride_width, int padding_depth,
                        int padding_height, int padding_width,
                        int output_depth, int output_height, int output_width,
                        T* data_col) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    int w_out = index % output_width;
    int h_out = (index / output_width) % output_height;
    int d_out = (index / output_width / output_height) % output_depth;
    int channel_in = index / output_width / output_height / output_depth;
    int channel_out = channel_in * filter_depth * filter_height * filter_width;
    int w_in = w_out * stride_width - padding_width;
    int h_in = h_out * stride_height - padding_height;
    int d_in = d_out * stride_depth - padding_depth;

    data_col += ((channel_out * output_depth + d_out) * output_height + h_out) *
                    output_width +
                w_out;
    data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
    for (int k = 0; k < filter_depth; ++k) {
      for (int i = 0; i < filter_height; ++i) {
        for (int j = 0; j < filter_width; ++j) {
          int d = d_in + k;
          int h = h_in + i;
          int w = w_in + j;
          *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
                       w < width)
                          ? data_vol[(k * height + i) * width + j]
                          : 0;
          data_col += output_depth * output_height * output_width;
        }
      }
    }
  }
}

/*
 * vol = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *    output_depth, output_height, output_width]
 */
template <class T>
class Vol2ColFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& vol, framework::Tensor& col,
                  int stride_depth, int stride_height, int stride_width,
                  int padding_depth, int padding_height,
                  int padding_width) const {
    PADDLE_ENFORCE(vol.dims().size() == 4);
    PADDLE_ENFORCE(col.dims().size() == 7);

    int input_channels = vol.dims()[0];
    int input_depth = vol.dims()[1];
    int input_height = vol.dims()[2];
    int input_width = vol.dims()[3];
    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];

    int num_outputs =
        input_channels * output_depth * output_height * output_width;

    const int threads = 1024;
    const int blocks = (num_outputs + 1024 - 1) / 1024;
    vol2col<T><<<blocks, threads, 0,
                 reinterpret_cast<const platform::CUDADeviceContext&>(context)
                     .stream()>>>(
        num_outputs, vol.data<T>(), input_depth, input_height, input_width,
        filter_depth, filter_height, filter_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width,
        output_depth, output_height, output_width, col.data<T>());
  }
};

template <class T>
__global__ void col2vol(int num_kernels, const T* data_col, int depth,
                        int height, int width, int filter_depth,
                        int filter_height, int filter_width, int stride_depth,
                        int stride_height, int stride_width, int padding_depth,
                        int padding_height, int padding_width,
                        int output_depth, int output_height, int output_width,
                        T* data_vol) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
       index += blockDim.x * gridDim.x) {
    T src_val = 0;
    int w = index % width + padding_width;
    int h = (index / width) % height + padding_height;
    int d = (index / width / height) % depth + padding_depth;
    int c = index / width / height / depth;
    // compute the start and end of the output
    int w_col_start =
        (w < filter_width) ? 0 : (w - filter_width) / stride_width + 1;
    int w_col_end = min(w / stride_width + 1, output_width);
    int h_col_start =
        (h < filter_height) ? 0 : (h - filter_height) / stride_height + 1;
    int h_col_end = min(h / stride_height + 1, output_height);
    int d_col_start =
        (d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1;
    int d_col_end = min(d / stride_depth + 1, output_depth);

    int offset = (c * filter_depth * filter_height * filter_width +
                  d * filter_width * filter_height + h * filter_width + w) *
                 output_depth * output_height * output_width;

    int coeff_d_col =
        (1 - stride_depth * filter_width * filter_height * output_depth) *
        output_height * output_width;
    int coeff_h_col =
        (1 - stride_height * filter_width * output_depth * output_height) *
        output_width;
    int coeff_w_col =
        (1 - stride_width * output_depth * output_height * output_width);

    for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
      for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
        for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
          src_val += data_col[offset + d_col * coeff_d_col +
                              h_col * coeff_h_col + w_col * coeff_w_col];
        }
      }
    }
    data_vol[index] = src_val;
  }
}

/*
 * vol = [input_channels, input_depth, input_height, input_width]
 * col =
 *   [input_channels, filter_depth, filter_height, filter_width,
 *    output_depth, output_height, output_width]
 */
template <class T>
class Col2VolFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  framework::Tensor& vol, const framework::Tensor& col,
                  int stride_depth, int stride_height, int stride_width,
                  int padding_depth, int padding_height,
                  int padding_width) const {
    PADDLE_ENFORCE(vol.dims().size() == 4);
    PADDLE_ENFORCE(col.dims().size() == 7);

    int input_channels = vol.dims()[0];
    int input_depth = vol.dims()[1];
    int input_height = vol.dims()[2];
    int input_width = vol.dims()[3];
    int filter_depth = col.dims()[1];
    int filter_height = col.dims()[2];
    int filter_width = col.dims()[3];
    int output_depth = col.dims()[4];
    int output_height = col.dims()[5];
    int output_width = col.dims()[6];

    int num_kernels = input_channels * input_depth * input_height * input_width;

    const int threads = 1024;
    const int blocks = (num_kernels + 1024 - 1) / 1024;

    col2vol<T><<<blocks, threads, 0,
                 reinterpret_cast<const platform::CUDADeviceContext&>(context)
                     .stream()>>>(
        num_kernels, col.data<T>(), input_depth, input_height, input_width,
        filter_depth, filter_height, filter_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width,
        output_depth, output_height, output_width, vol.data<T>());
  }
};

template class Vol2ColFunctor<platform::GPUPlace, float>;
template class Vol2ColFunctor<platform::GPUPlace, double>;
template class Col2VolFunctor<platform::GPUPlace, float>;
template class Col2VolFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,78 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"

namespace paddle {
namespace operators {
namespace math {
/*
 * \brief In the Vol2ColFunctor calculation, converts the feature data of four
 *        dimensions (CDHW) into a colData of seven dimensions; in the
 *        Col2VolFunctor calculation, the conversion is reversed.
 *
 * \param volData   Vol data.
 * \param volShape  The shape of volData,
 *                  [input_channels, input_depth, input_height, input_width].
 * \param colData   Column data.
 * \param colShape  The shape of colData.
 *
 * The shape of colData is:
 * [input_channels, filter_depth, filter_height, filter_width, output_depth,
 * output_height, output_width]
 * So, it is easy to reshape into a convolution matrix for convolution
 * calculation based on matrix multiplication.
 * The shape of the convolution matrix is [height, width], where the height is
 * equal to input_channels * filter_depth * filter_height * filter_width, and
 * the width is equal to output_depth * output_height * output_width.
 *
 * Reshape:
 *     shape of colData           shape of convolution matrix
 *     [input_channels,
 *      filter_depth,
 *      filter_height,
 *      filter_width,      ======>      [height, width]
 *      output_depth,
 *      output_height,
 *      output_width]
 *
 * \note The caller needs to ensure that volShape.inputChannels is equal to
 *       colShape.inputChannels.
 */
template <typename Place, typename T>
class Vol2ColFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& vol, framework::Tensor& col,
                  int stride_depth, int stride_height, int stride_width,
                  int padding_depth, int padding_height,
                  int padding_width) const;
};

template <typename Place, typename T>
class Col2VolFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  framework::Tensor& vol, const framework::Tensor& col,
                  int stride_depth, int stride_height, int stride_width,
                  int padding_depth, int padding_height,
                  int padding_width) const;
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,135 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/vol2col.h"
#include <gtest/gtest.h>
#include <cstring>
#include <iostream>

template <typename Place>
void testVol2col() {
  paddle::framework::Tensor input;
  paddle::framework::Tensor input_tmp;
  paddle::framework::Tensor output;
  paddle::framework::Tensor output_tmp;

  auto* place = new Place();
  paddle::platform::DeviceContext* context;
  if (paddle::platform::is_cpu_place(*place)) {
    context =
        new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
  } else {
#ifdef PADDLE_WITH_CUDA
    context =
        new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else
    PADDLE_THROW("no GPU support");
#endif  // PADDLE_WITH_CUDA
  }

  /**
   * input = [[0, 1, 2,
   *           3, 4, 5]
   *          [6, 7, 8,
   *           9, 10, 11]]
   *
   * output = [0, 1
   *           1, 2
   *           3, 4
   *           4, 5
   *           6, 7
   *           7, 8
   *           9, 10
   *           10, 11]
   *
   * col2vol = [[0, 2, 2,
   *             3, 8, 5]
   *            [6, 14, 8,
   *             9, 20, 11]]
   *
   */
  int input_depth = 2;
  int input_height = 2;
  int input_width = 3;
  int filter_size = 2;
  int stride = 1;
  int padding = 0;
  int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1;
  int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
  int output_width = (input_width - filter_size + 2 * padding) / stride + 1;

  // Vol2Col test
  float* input_ptr = input_tmp.mutable_data<float>(
      {1, input_depth, input_height, input_width},
      paddle::platform::CPUPlace());
  float arr[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  memcpy(input_ptr, arr, 12 * sizeof(float));

  if (paddle::platform::is_cpu_place(*place)) {
    input = input_tmp;
  } else {
    input.CopyFrom<float>(input_tmp, *place);
  }
  output.mutable_data<float>({1, filter_size, filter_size, filter_size,
                              output_depth, output_height, output_width},
                             *place);

  paddle::operators::math::Vol2ColFunctor<Place, float> vol2col;
  vol2col(*context, input, output, stride, stride, stride, padding, padding,
          padding);

  float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
  float* out_cfo_ptr;
  if (paddle::platform::is_cpu_place(*place)) {
    out_cfo_ptr = output.data<float>();
  } else {
    output_tmp.CopyFrom<float>(output, paddle::platform::CPUPlace());
    out_cfo_ptr = output_tmp.data<float>();
  }

  for (int i = 0; i < 16; ++i) {
    EXPECT_EQ(out_cfo_ptr[i], vol_2_col[i]);
  }

  // Col2Vol test
  float col_2_vol[] = {0, 2, 2, 3, 8, 5, 6, 14, 8, 9, 20, 11};
  memset(input_ptr, 0, 12 * sizeof(float));
  if (paddle::platform::is_cpu_place(*place)) {
    input = input_tmp;
  } else {
    input.CopyFrom<float>(input_tmp, *place);
  }

  paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
  col2vol(*context, input, output, stride, stride, stride, padding, padding,
          padding);

  float* in_ptr;
  if (paddle::platform::is_cpu_place(*place)) {
    in_ptr = input.data<float>();
  } else {
    input_tmp.CopyFrom<float>(input, paddle::platform::CPUPlace());
    in_ptr = input_tmp.data<float>();
  }

  for (int i = 0; i < 12; ++i) {
    EXPECT_EQ(in_ptr[i], col_2_vol[i]);
  }
}

TEST(math, vol2col) {
  testVol2col<paddle::platform::CPUPlace>();
#ifdef PADDLE_WITH_CUDA
  testVol2col<paddle::platform::GPUPlace>();
#endif  // PADDLE_WITH_CUDA
}