Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_add_axis
commit f6e72c93c7

@ -0,0 +1,124 @@
## Background

PaddlePaddle divides the description of a neural network computation graph into two stages: compile time and runtime.

PaddlePaddle uses proto messages to describe the compile-time graph because:

1. The computation graph should be able to be saved to a file.
1. In distributed training, the graph will be serialized and sent to multiple workers.

The computation graph is constructed from Data Nodes and Operation Nodes. The concepts that represent them at the two stages are listed in the table below.

| |compile time|runtime|
|---|---|---|
|Data|VarDesc(proto)|Variable(cpp)|
|Operation|OpDesc(proto)|Operator(cpp)|

## Definition of VarDesc

A VarDesc should have a name and a value; in PaddlePaddle, the value is always a tensor. Since we use LoDTensor most of the time, we add a LoDTensorDesc to describe it.

```proto
message VarDesc {
    required string name = 1;
    optional LoDTensorDesc lod_tensor = 2;
}
```

## Definition of LoDTensorDesc

```proto
enum DataType {
    BOOL = 0;
    INT16 = 1;
    INT32 = 2;
    INT64 = 3;
    FP16 = 4;
    FP32 = 5;
    FP64 = 6;
}

message LoDTensorDesc {
    required DataType data_type = 1;
    repeated int32 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
    optional int32 lod_level = 3 [default=0];
}
```
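
For illustration, here is a minimal sketch of how these messages could be built from Python. It assumes the .proto file above is compiled by protoc into a module named `framework_pb2`; the module name is an assumption, not part of this design.

```python
# A sketch, assuming protoc generated `framework_pb2` from the messages above.
import framework_pb2

tensor = framework_pb2.LoDTensorDesc()
tensor.data_type = framework_pb2.FP32
tensor.dims.extend([-1, 640, 480])  # -1 stands for the UNK (batch) dimension
tensor.lod_level = 0                # a plain tensor carries no LoD information

var = framework_pb2.VarDesc(name="image", lod_tensor=tensor)

# The serialized bytes can be saved to a file or sent to remote workers.
serialized = var.SerializeToString()
```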

## Definition of Variable in Python

In the Python API, a layer takes Variables as input and returns Variables as output. There should be a class `Variable` in Python to help create and manage Variables.

```python
image = Variable(dims=[-1, 640, 480])
# fc1 and fc2 are both Variable
fc1 = layer.fc(input=image, output_size=10)
fc2 = layer.fc(input=fc1, output_size=20)
```

### What should class `Variable` have

1. `name`. A name of string type is used to identify the Variable.
1. `initializer`. Since our Tensor does not hold a value until runtime, we always use some Operator to fill it when the program runs. So we should have an initialize method that helps add the init operator.
1. `operator`. A Variable should record which operator produces it. The reason is:
    - We use `pd.eval(targets=[var1, var2])` to run the related ops to get the values of `var1` and `var2`; `var.op` is used to trace the dependencies of the current variable (a sketch of this tracing appears at the end of this section).

In PaddlePaddle, we use a Block to describe the computation graph, so in the code we use Block rather than Graph.
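
The `Variable` sketch below calls `get_default_block()` and `self._block.add_var(...)`. A minimal, hypothetical sketch of that Block bookkeeping (only for illustration, not part of this design) might look like:

```python
# Hypothetical minimal Block: it only records the vars and ops added to it.
class Block(object):
    def __init__(self):
        self.vars = {}  # variable name -> Variable
        self.ops = []   # operators in insertion order

    def add_var(self, var):
        self.vars[var._name] = var

    def add_op(self, op):
        self.ops.append(op)

_default_block = Block()

def get_default_block():
    return _default_block
```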

```python
import VarDesc
import LoDTensorDesc
import framework

def AddInitialOperator(variable, initializer):
    # add an initialize Operator to the block to init this Variable
    pass

class Variable(object):
    def __init__(self, name=None, dims=None, type=None, initializer=None):
        self._block = get_default_block()
        self._name = name
        self.op = None

        tensor_desc = LoDTensorDesc(data_type=type, dims=dims)
        _var_desc = VarDesc(name=name, lod_tensor=tensor_desc)
        self._var = framework.CreateVar(_var_desc)
        self._block.add_var(self)

        # add the initial op according to the initializer
        if initializer is not None:
            AddInitialOperator(self, initializer)

    def dims(self):
        return self._var.dims()

    def data_type(self):
        return self._var.data_type()

    def to_proto(self):
        pass
```

Then we can use this Variable to create an fc layer in Python.

```python
import paddle as pd

def flatten_size(X, num_flatten_dims):
    prod = 1  # product of the last num_flatten_dims dimensions
    for i in xrange(num_flatten_dims):
        prod = prod * X.dims[-i - 1]
    return prod

# exposed as layer.fc
def fc(X, output_size, num_flatten_dims=1):
    W = Variable(initializer=pd.random_uniform(), type=FP32,
                 dims=[flatten_size(X, num_flatten_dims), output_size])
    b = Variable(initializer=pd.random_uniform(), type=FP32, dims=[output_size])
    out = Variable(type=FP32)
    y = operator.fc(X, W, b, output=out)  # the fc op writes its result into out
    pd.InferShape(y)
    return out

x = Variable(dims=[-1, 640, 480])
y = layer.fc(x, output_size=100)
z = layer.fc(y, output_size=200)

pd.eval(targets=[z], ...)
print(z)
```
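
As noted above, `var.op` lets `pd.eval` trace which operators must run. A minimal sketch of that tracing follows; the attribute `op.inputs` is a hypothetical name assumed for illustration, it is not defined in this document.

```python
def trace_dependencies(targets):
    """Collect, in topological order, every op needed to compute `targets`."""
    ordered_ops = []
    visited = set()

    def visit(var):
        op = var.op
        if op is None or id(op) in visited:
            return
        visited.add(id(op))
        for inp in op.inputs:  # assumed: each op records its input Variables
            visit(inp)
        ordered_ops.append(op)  # post-order append yields topological order

    for var in targets:
        visit(var)
    return ordered_ops
```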

@ -0,0 +1,159 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "GemmFunctor.h"
#include "hl_cpu_gru.cuh"

namespace paddle {

template <DeviceType Device, class T>
struct GruFunctor {
  template <class OpResetOutput, class OpFinalOutput>
  static void compute(OpResetOutput opResetOutput,
                      OpFinalOutput opFinalOutput,
                      hl_gru_value value,
                      int frameSize,
                      int batchSize,
                      hl_activation_mode_t active_node,
                      hl_activation_mode_t active_gate) {
#ifndef __NVCC__
    // Project the previous output onto the update/reset gates:
    // gateValue[:, 0:2*frameSize] += prevOutValue * gateWeight
    if (value.prevOutValue) {
      BlasGemm<Device, T>::compute(false,
                                   false,
                                   batchSize,
                                   2 * frameSize,
                                   frameSize,
                                   1,
                                   value.prevOutValue,
                                   frameSize,
                                   value.gateWeight,
                                   frameSize * 2,
                                   1,
                                   value.gateValue,
                                   frameSize * 3);
    }

    forward_reset_output(
        opResetOutput, value, frameSize, batchSize, active_gate);

    // Project the reset output onto the candidate state:
    // gateValue[:, 2*frameSize:3*frameSize] += resetOutputValue * stateWeight
    if (value.prevOutValue) {
      BlasGemm<Device, T>::compute(false,
                                   false,
                                   batchSize,
                                   frameSize,
                                   frameSize,
                                   1,
                                   value.resetOutputValue,
                                   frameSize,
                                   value.stateWeight,
                                   frameSize,
                                   1,
                                   value.gateValue + frameSize * 2,
                                   frameSize * 3);
    }

    forward_final_output(
        opFinalOutput, value, frameSize, batchSize, active_node);
#endif
  }
};

template <DeviceType Device, class T>
struct GruGradFunctor {
  template <class OpStateGrad, class OpResetGrad>
  static void compute(OpStateGrad opStateGrad,
                      OpResetGrad opResetGrad,
                      hl_gru_value value,
                      hl_gru_grad grad,
                      int frameSize,
                      int batchSize,
                      hl_activation_mode_t active_node,
                      hl_activation_mode_t active_gate) {
#ifndef __NVCC__
    backward_state_grad(
        opStateGrad, value, grad, frameSize, batchSize, active_node);

    if (value.prevOutValue && grad.prevOutGrad) {
      // resetOutputGrad = candidateGrad * stateWeight^T
      BlasGemm<Device, T>::compute(false,
                                   true,
                                   batchSize,
                                   frameSize,
                                   frameSize,
                                   1,
                                   grad.gateGrad + frameSize * 2,
                                   frameSize * 3,
                                   value.stateWeight,
                                   frameSize,
                                   0,
                                   grad.resetOutputGrad,
                                   frameSize);

      if (grad.stateWeightGrad) {
        // stateWeightGrad += resetOutputValue^T * candidateGrad
        BlasGemm<Device, T>::compute(true,
                                     false,
                                     frameSize,
                                     frameSize,
                                     batchSize,
                                     1,
                                     value.resetOutputValue,
                                     frameSize,
                                     grad.gateGrad + frameSize * 2,
                                     frameSize * 3,
                                     1,
                                     grad.stateWeightGrad,
                                     frameSize);
      }
    }

    backward_reset_grad(
        opResetGrad, value, grad, frameSize, batchSize, active_gate);

    if (grad.prevOutGrad && value.prevOutValue) {
      // prevOutGrad += gateGrad[:, 0:2*frameSize] * gateWeight^T
      BlasGemm<Device, T>::compute(false,
                                   true,
                                   batchSize,
                                   frameSize,
                                   frameSize * 2,
                                   1,
                                   grad.gateGrad,
                                   frameSize * 3,
                                   value.gateWeight,
                                   frameSize * 2,
                                   1,
                                   grad.prevOutGrad,
                                   frameSize);

      if (grad.gateWeightGrad) {
        // gateWeightGrad += prevOutValue^T * gateGrad[:, 0:2*frameSize]
        BlasGemm<Device, T>::compute(true,
                                     false,
                                     frameSize,
                                     frameSize * 2,
                                     batchSize,
                                     1,
                                     value.prevOutValue,
                                     frameSize,
                                     grad.gateGrad,
                                     frameSize * 3,
                                     1,
                                     grad.gateWeightGrad,
                                     frameSize * 2);
      }
    }
#endif
  }
};

} // namespace paddle

@ -0,0 +1,140 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "SwitchOp.h"
#include "paddle/math/Vector.h"

namespace paddle {

template <>
void NCHW2NHWC<DEVICE_TYPE_CPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inC,
                                const int inH,
                                const int inW,
                                const int argType) {
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < inC; ++c) {
      for (int h = 0; h < inH; ++h) {
        for (int w = 0; w < inW; ++w) {
          if (argType == ADD_TO) {
            outputs[((n * inH + h) * inW + w) * inC + c] += *(inputs++);
          } else {
            outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
          }
        }
      }
    }
  }
}

template <>
void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inH,
                                const int inW,
                                const int inC,
                                const int argType) {
  for (int n = 0; n < num; ++n) {
    for (int h = 0; h < inH; ++h) {
      for (int w = 0; w < inW; ++w) {
        for (int c = 0; c < inC; ++c) {
          if (argType == ADD_TO) {
            outputs[((n * inC + c) * inH + h) * inW + w] += *(inputs++);
          } else {
            outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
          }
        }
      }
    }
  }
}

/**
 * \brief Switch the dimension order of an image input.
 *        The input and output are 4D tensors. Switch the order
 *        'batch_size, channels, height, width' to
 *        'batch_size, height, width, channels'.
 *
 * Arguments in this Function:
 * \param inputs  input data with order 'batch_size, channels, height, width'.
 * \param outputs output data with order 'batch_size, height, width, channels'.
 */
template <DeviceType Device>
class NCHW2NHWCFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {}

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());

    size_t num = inputs[0].shape()[0];
    size_t inC = inputs[0].shape()[1];
    size_t inH = inputs[0].shape()[2];
    size_t inW = inputs[0].shape()[3];
    NCHW2NHWC<Device>(outputs[0].data<real>(),
                      inputs[0].data<real>(),
                      num,
                      inC,
                      inH,
                      inW,
                      outputs[0].getArgType());
  }
};

/**
 * \brief Switch the dimension order of an image input.
 *        The input and output are 4D tensors. Switch the order
 *        'batch_size, height, width, channels' to
 *        'batch_size, channels, height, width'.
 *
 * Arguments in this Function:
 * \param inputs  input data with order 'batch_size, height, width, channels'.
 * \param outputs output data with order 'batch_size, channels, height, width'.
 */
template <DeviceType Device>
class NHWC2NCHWFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {}

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());

    size_t num = inputs[0].shape()[0];
    size_t inH = inputs[0].shape()[1];
    size_t inW = inputs[0].shape()[2];
    size_t inC = inputs[0].shape()[3];

    NHWC2NCHW<Device>(outputs[0].data<real>(),
                      inputs[0].data<real>(),
                      num,
                      inH,
                      inW,
                      inC,
                      outputs[0].getArgType());
  }
};

REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
#endif

} // namespace paddle

@ -0,0 +1,66 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Function.h"

namespace paddle {

/**
 * \brief This function switches the dimension order of an image input.
 *        The input and output are 4D tensors. Switch the order
 *        'batch_size, channels, height, width' to
 *        'batch_size, height, width, channels'.
 *
 * \param[out] outputs save results.
 * \param[in]  inputs  input data.
 * \param[in]  num     batch size of input data.
 * \param[in]  inC     channel number of input data.
 * \param[in]  inH     height of input data.
 * \param[in]  inW     width of input data.
 * \param[in]  argType type of output argument.
 */
template <DeviceType Device>
void NCHW2NHWC(real* outputs,
               const real* inputs,
               const int num,
               const int inC,
               const int inH,
               const int inW,
               const int argType);

/**
 * \brief This function switches the dimension order of an image input.
 *        The input and output are 4D tensors. Switch the order
 *        'batch_size, height, width, channels' to
 *        'batch_size, channels, height, width'.
 *
 * \param[out] inGrad  gradients of the previous layer.
 * \param[in]  outGrad output gradients.
 * \param[in]  num     batch size of input data.
 * \param[in]  inH     height of input data.
 * \param[in]  inW     width of input data.
 * \param[in]  inC     channel number of input data.
 * \param[in]  argType type of output argument.
 */
template <DeviceType Device>
void NHWC2NCHW(real* inGrad,
               const real* outGrad,
               const int num,
               const int inH,
               const int inW,
               const int inC,
               const int argType);
} // namespace paddle
@ -0,0 +1,98 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "SwitchOp.h"
#include "hl_base.h"

namespace paddle {

__global__ void KeNCHW2NHWC(real* outputs,
                            const real* inputs,
                            int inC,
                            int inH,
                            int inW,
                            int nthreads,
                            int argType) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < nthreads) {
    // decompose the NCHW linear index idx into (n, c, h, w)
    const int w = idx % inW;
    const int h = (idx / inW) % inH;
    const int c = (idx / inW / inH) % inC;
    const int n = idx / inW / inH / inC;

    // re-linearize as NHWC
    const int off = ((n * inH + h) * inW + w) * inC + c;
    if (argType == ADD_TO) {
      outputs[off] += inputs[idx];
    } else {
      outputs[off] = inputs[idx];
    }
  }
}

template <>
void NCHW2NHWC<DEVICE_TYPE_GPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inC,
                                const int inH,
                                const int inW,
                                const int argType) {
  size_t nth = num * inC * inH * inW;
  int blockSize = 1024;
  int gridSize = (nth + 1024 - 1) / 1024;
  KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
      outputs, inputs, inC, inH, inW, nth, argType);
  CHECK_SYNC("NCHW2NHWC");
}

__global__ void KeNHWC2NCHW(real* outputs,
                            const real* inputs,
                            int inH,
                            int inW,
                            int inC,
                            int nthreads,
                            int argType) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < nthreads) {
    // decompose the NHWC linear index idx into (n, h, w, c)
    const int c = idx % inC;
    const int w = (idx / inC) % inW;
    const int h = (idx / inC / inW) % inH;
    const int n = idx / inW / inH / inC;

    // re-linearize as NCHW
    const int off = ((n * inC + c) * inH + h) * inW + w;
    if (argType == ADD_TO) {
      outputs[off] += inputs[idx];
    } else {
      outputs[off] = inputs[idx];
    }
  }
}

template <>
void NHWC2NCHW<DEVICE_TYPE_GPU>(real* outputs,
                                const real* inputs,
                                const int num,
                                const int inH,
                                const int inW,
                                const int inC,
                                const int argType) {
  int nth = num * inC * inH * inW;
  int blockSize = 1024;
  int gridSize = (nth + 1024 - 1) / 1024;
  KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
      outputs, inputs, inH, inW, inC, nth, argType);
  CHECK_SYNC("NHWC2NCHW");
}

} // namespace paddle
@ -0,0 +1,44 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "FunctionTest.h"

namespace paddle {

TEST(SwitchOp, real) {
  for (size_t numSamples : {1, 4, 8, 16}) {
    for (size_t channels : {1, 4, 8, 16}) {
      for (size_t imgSizeH : {1, 4, 8, 16}) {
        for (size_t imgSizeW : {1, 4, 8, 16}) {
          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
          // test_grad == true exercises NHWC2NCHW (the backward direction),
          // test_grad == false exercises NCHW2NHWC (the forward direction).
          for (bool test_grad : {true, false}) {
            CpuGpuFuncCompare compare(test_grad ? "NHWC2NCHW" : "NCHW2NHWC",
                                      FuncConfig());
            TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
            TensorShape outDims{numSamples, imgSizeH, imgSizeW, channels};
            compare.addInputs(
                BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
            compare.addOutputs(BufferArg(
                VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
            compare.run();
          }
        }
      }
    }
  }
}

} // namespace paddle
@ -0,0 +1,136 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "NeonDepthwiseConv.h"
#include "paddle/function/ConvOp.h"

namespace paddle {

#if defined(__ARM_NEON__) || defined(__ARM_NEON)

template <DeviceType Device>
class NeonDepthwiseConvTransposeFunction : public ConvFunctionBase {
public:
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();
    checkShape(input, filter, output);
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
    check(inputs, outputs);

    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();

    int batchSize = input[0];
    int inputChannels = input[1];
    int inputHeight = input[2];
    int inputWidth = input[3];
    int filterHeight = getFilterHeight(filter);
    int filterWidth = getFilterWidth(filter);
    int outputChannels = output[1];
    int outputHeight = output[2];
    int outputWidth = output[3];
    int filterMultiplier = outputChannels / groups_;
    CHECK_EQ(inputChannels, groups_);

    // only support strideH() == strideW() and filterHeight == filterWidth.
    CHECK_EQ(strideH(), strideW());
    CHECK_EQ(paddingH(), paddingW());
    CHECK_EQ(filterHeight, filterWidth);

    float* inputData = inputs[0].data<float>();
    float* filterData = inputs[1].data<float>();
    float* outputData = outputs[0].data<float>();

    // padding the input, input -> inputPadding
    float* inputPadding = inputData;
    int padInputHeight =
        (inputHeight - 1) * strideH() + 2 * filterHeight - 1 - 2 * paddingH();
    int padInputWidth =
        (inputWidth - 1) * strideW() + 2 * filterWidth - 1 - 2 * paddingW();

    if (padInputHeight > inputHeight || padInputWidth > inputWidth) {
      int newSize = batchSize * inputChannels * padInputHeight * padInputWidth;
      resizeBuffer<Device>(newSize);
      inputPadding = reinterpret_cast<float*>(memory_->getBuf());
      if (strideH() == 1) {
        neon::Padding<float>::run(inputData,
                                  inputPadding,
                                  batchSize * inputChannels,
                                  inputHeight,
                                  inputWidth,
                                  padInputHeight,
                                  padInputWidth);
      } else if (strideH() == 2) {
        neon::StridePadding::run(inputData,
                                 inputPadding,
                                 batchSize * inputChannels,
                                 inputHeight,
                                 inputWidth,
                                 padInputHeight,
                                 padInputWidth);
      } else {
        LOG(FATAL) << "Not supported";
      }
    }

    std::function<void(
        const float*, const float*, int, int, int, int, int, int, float*)>
        DepthWiseConv;

    if (filterWidth == 3) {
      DepthWiseConv = neon::DepthwiseConvKernel<3, 1>::run;
    } else if (filterWidth == 4) {
      DepthWiseConv = neon::DepthwiseConvKernel<4, 1>::run;
    } else {
      LOG(FATAL) << "Not supported";
    }

    for (int i = 0; i < batchSize; i++) {
      DepthWiseConv(inputPadding,
                    filterData,
                    padInputHeight,
                    padInputWidth,
                    outputChannels,
                    outputHeight,
                    outputWidth,
                    filterMultiplier,
                    outputData);
      inputPadding += inputChannels * padInputHeight * padInputWidth;
      outputData += outputChannels * outputHeight * outputWidth;
    }
  }
};

#ifndef PADDLE_TYPE_DOUBLE

REGISTER_TYPED_FUNC(NeonDepthwiseConvTranspose,
                    CPU,
                    NeonDepthwiseConvTransposeFunction);

#endif

#endif

} // namespace paddle
@ -0,0 +1,107 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "SwitchOrderLayer.h"
#include "paddle/utils/Stat.h"

namespace paddle {

REGISTER_LAYER(switch_order, SwitchOrderLayer);

bool SwitchOrderLayer::init(const LayerMap& layerMap,
                            const ParameterMap& parameterMap) {
  /* Initialize the basic parent class */
  Layer::init(layerMap, parameterMap);
  auto& img_conf = config_.inputs(0).image_conf();
  size_t inH =
      img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
  size_t inW = img_conf.img_size();
  size_t inC = img_conf.channels();
  inDims_ = TensorShape({0, inC, inH, inW});
  outDims_ = TensorShape(4);

  auto& reshape_conf = config_.reshape_conf();
  for (int i = 0; i < reshape_conf.height_axis_size(); i++) {
    heightAxis_.push_back(reshape_conf.height_axis(i));
  }
  for (int i = 0; i < reshape_conf.width_axis_size(); i++) {
    widthAxis_.push_back(reshape_conf.width_axis(i));
  }
  createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
  createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig());
  return true;
}

void SwitchOrderLayer::setOutDims() {
  // output order is NHWC: (batch, height, width, channels)
  outDims_.setDim(0, inDims_[0]);
  outDims_.setDim(1, inDims_[2]);
  outDims_.setDim(2, inDims_[3]);
  outDims_.setDim(3, inDims_[1]);
  reshapeHeight_ = 1;
  for (size_t i = 0; i < heightAxis_.size(); i++) {
    reshapeHeight_ *= outDims_[heightAxis_[i]];
  }
  output_.setFrameHeight(reshapeHeight_);
  reshapeWidth_ = 1;
  for (size_t i = 0; i < widthAxis_.size(); i++) {
    reshapeWidth_ *= outDims_[widthAxis_[i]];
  }
  output_.setFrameWidth(reshapeWidth_);
}

void SwitchOrderLayer::setInDims() {
  MatrixPtr input = inputLayers_[0]->getOutputValue();
  size_t batchSize = input->getHeight();
  inDims_.setDim(0, batchSize);

  int h = inputLayers_[0]->getOutput().getFrameHeight();
  if (h != 0) inDims_.setDim(2, h);
  int w = inputLayers_[0]->getOutput().getFrameWidth();
  if (w != 0) inDims_.setDim(3, w);
  int totalCount = input->getElementCnt();
  int channels = totalCount / (inDims_[0] * inDims_[2] * inDims_[3]);
  if (channels != 0) inDims_.setDim(1, channels);
}

void SwitchOrderLayer::forward(PassType passType) {
  Layer::forward(passType);
  setInDims();
  setOutDims();
  resetOutput(outDims_[0], outDims_[1] * outDims_[2] * outDims_[3]);
  if (heightAxis_.size() > 0) {
    getOutputValue()->reshape(reshapeHeight_, reshapeWidth_);
    getOutputGrad()->reshape(reshapeHeight_, reshapeWidth_);
  }

  // switch NCHW to NHWC
  BufferArgs inputs;
  BufferArgs outputs;
  inputs.addArg(*getInputValue(0), inDims_);
  outputs.addArg(*getOutputValue(), outDims_);
  nchw2nhwc_[0]->calc(inputs, outputs);
  forwardActivation();
}

void SwitchOrderLayer::backward(const UpdateCallback& callback) {
  (void)callback;
  backwardActivation();

  // switch NHWC to NCHW
  BufferArgs inputs;
  BufferArgs outputs;
  inputs.addArg(*getOutputGrad(), outDims_);
  outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
  nhwc2nchw_[0]->calc(inputs, outputs);
}
} // namespace paddle
@ -0,0 +1,47 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"

namespace paddle {

/**
 * \brief This layer switches the dimension order of an image input from
 *        'batch_size, channels, height, width' to
 *        'batch_size, height, width, channels'.
 */
class SwitchOrderLayer : public Layer {
public:
  explicit SwitchOrderLayer(const LayerConfig& config) : Layer(config) {}

  ~SwitchOrderLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;
  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback = nullptr) override;
  void setInDims();
  void setOutDims();

protected:
  std::vector<std::shared_ptr<FunctionBase>> nchw2nhwc_;
  std::vector<std::shared_ptr<FunctionBase>> nhwc2nchw_;
  TensorShape inDims_;
  TensorShape outDims_;
  std::vector<int> heightAxis_;
  std::vector<int> widthAxis_;
  size_t reshapeHeight_;
  size_t reshapeWidth_;
};
} // namespace paddle