commit
089f8e2d37
@ -0,0 +1,51 @@
|
|||||||
|
set -e
|
||||||
|
|
||||||
|
unset OMP_NUM_THREADS MKL_NUM_THREADS
|
||||||
|
export OMP_DYNAMIC="FALSE"
|
||||||
|
export KMP_AFFINITY="granularity=fine,compact,0,0"
|
||||||
|
|
||||||
|
function train() {
|
||||||
|
topology=$1
|
||||||
|
bs=$2
|
||||||
|
use_mkldnn=$3
|
||||||
|
if [ $3 == "True" ]; then
|
||||||
|
use_mkldnn=$3
|
||||||
|
thread=1
|
||||||
|
log="logs/${topology}-mkldnn-${bs}.log"
|
||||||
|
elif [ $3 == "False" ]; then
|
||||||
|
use_mkldnn=$3
|
||||||
|
thread=`nproc`
|
||||||
|
log="logs/${topology}-${thread}mklml-${bs}.log"
|
||||||
|
else
|
||||||
|
echo "Wrong input $3, use True or False."
|
||||||
|
fi
|
||||||
|
args="batch_size=${bs}"
|
||||||
|
config="${topology}.py"
|
||||||
|
paddle train --job=time \
|
||||||
|
--config=$config \
|
||||||
|
--use_mkldnn=$use_mkldnn \
|
||||||
|
--use_gpu=False \
|
||||||
|
--trainer_count=$thread \
|
||||||
|
--log_period=10 \
|
||||||
|
--test_period=100 \
|
||||||
|
--config_args=$args \
|
||||||
|
2>&1 | tee ${log}
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ ! -d "train.list" ]; then
|
||||||
|
echo " " > train.list
|
||||||
|
fi
|
||||||
|
if [ ! -d "logs" ]; then
|
||||||
|
mkdir logs
|
||||||
|
fi
|
||||||
|
|
||||||
|
#========= mkldnn =========#
|
||||||
|
# vgg
|
||||||
|
train vgg 64 True
|
||||||
|
train vgg 128 True
|
||||||
|
train vgg 256 True
|
||||||
|
|
||||||
|
#========== mklml ===========#
|
||||||
|
train vgg 64 False
|
||||||
|
train vgg 128 False
|
||||||
|
train vgg 256 False
|
@ -0,0 +1,103 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
from paddle.trainer_config_helpers import *
|
||||||
|
|
||||||
|
height = 224
|
||||||
|
width = 224
|
||||||
|
num_class = 1000
|
||||||
|
batch_size = get_config_arg('batch_size', int, 64)
|
||||||
|
layer_num = get_config_arg('layer_num', int, 19)
|
||||||
|
|
||||||
|
args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
|
||||||
|
define_py_data_sources2(
|
||||||
|
"train.list", None, module="provider", obj="process", args=args)
|
||||||
|
|
||||||
|
settings(
|
||||||
|
batch_size=batch_size,
|
||||||
|
learning_rate=0.01 / batch_size,
|
||||||
|
learning_method=MomentumOptimizer(0.9),
|
||||||
|
regularization=L2Regularization(0.0005 * batch_size))
|
||||||
|
|
||||||
|
img = data_layer(name='image', size=height * width * 3)
|
||||||
|
|
||||||
|
|
||||||
|
def vgg_network(vgg_num=3):
|
||||||
|
tmp = img_conv_group(
|
||||||
|
input=img,
|
||||||
|
num_channels=3,
|
||||||
|
conv_padding=1,
|
||||||
|
conv_num_filter=[64, 64],
|
||||||
|
conv_filter_size=3,
|
||||||
|
conv_act=ReluActivation(),
|
||||||
|
pool_size=2,
|
||||||
|
pool_stride=2,
|
||||||
|
pool_type=MaxPooling())
|
||||||
|
|
||||||
|
tmp = img_conv_group(
|
||||||
|
input=tmp,
|
||||||
|
conv_num_filter=[128, 128],
|
||||||
|
conv_padding=1,
|
||||||
|
conv_filter_size=3,
|
||||||
|
conv_act=ReluActivation(),
|
||||||
|
pool_stride=2,
|
||||||
|
pool_type=MaxPooling(),
|
||||||
|
pool_size=2)
|
||||||
|
|
||||||
|
channels = []
|
||||||
|
for i in range(vgg_num):
|
||||||
|
channels.append(256)
|
||||||
|
tmp = img_conv_group(
|
||||||
|
input=tmp,
|
||||||
|
conv_num_filter=channels,
|
||||||
|
conv_padding=1,
|
||||||
|
conv_filter_size=3,
|
||||||
|
conv_act=ReluActivation(),
|
||||||
|
pool_stride=2,
|
||||||
|
pool_type=MaxPooling(),
|
||||||
|
pool_size=2)
|
||||||
|
channels = []
|
||||||
|
for i in range(vgg_num):
|
||||||
|
channels.append(512)
|
||||||
|
tmp = img_conv_group(
|
||||||
|
input=tmp,
|
||||||
|
conv_num_filter=channels,
|
||||||
|
conv_padding=1,
|
||||||
|
conv_filter_size=3,
|
||||||
|
conv_act=ReluActivation(),
|
||||||
|
pool_stride=2,
|
||||||
|
pool_type=MaxPooling(),
|
||||||
|
pool_size=2)
|
||||||
|
tmp = img_conv_group(
|
||||||
|
input=tmp,
|
||||||
|
conv_num_filter=channels,
|
||||||
|
conv_padding=1,
|
||||||
|
conv_filter_size=3,
|
||||||
|
conv_act=ReluActivation(),
|
||||||
|
pool_stride=2,
|
||||||
|
pool_type=MaxPooling(),
|
||||||
|
pool_size=2)
|
||||||
|
|
||||||
|
tmp = fc_layer(
|
||||||
|
input=tmp,
|
||||||
|
size=4096,
|
||||||
|
act=ReluActivation(),
|
||||||
|
layer_attr=ExtraAttr(drop_rate=0.5))
|
||||||
|
|
||||||
|
tmp = fc_layer(
|
||||||
|
input=tmp,
|
||||||
|
size=4096,
|
||||||
|
act=ReluActivation(),
|
||||||
|
layer_attr=ExtraAttr(drop_rate=0.5))
|
||||||
|
|
||||||
|
return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation())
|
||||||
|
|
||||||
|
|
||||||
|
if layer_num == 16:
|
||||||
|
vgg = vgg_network(3)
|
||||||
|
elif layer_num == 19:
|
||||||
|
vgg = vgg_network(4)
|
||||||
|
else:
|
||||||
|
print("Wrong layer number.")
|
||||||
|
|
||||||
|
lab = data_layer('label', num_class)
|
||||||
|
loss = cross_entropy(input=vgg, label=lab)
|
||||||
|
outputs(loss)
|
@ -0,0 +1,103 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/operators/lstm_unit_op.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
class LstmUnitOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
|
||||||
|
"Input(X) of LSTM should not be null.");
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"),
|
||||||
|
"Input(C_prev) of LSTM should not be null.");
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"),
|
||||||
|
"Output(C) of LSTM should not be null.");
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"),
|
||||||
|
"Output(H) of LSTM should not be null.");
|
||||||
|
|
||||||
|
auto *x = ctx.Input<framework::Tensor>("X");
|
||||||
|
auto *c_prev = ctx.Input<framework::Tensor>("C_prev");
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
|
||||||
|
PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0],
|
||||||
|
"Batch size of inputs and states must be equal");
|
||||||
|
PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4,
|
||||||
|
"Dimension of FC should equal to prev state * 4");
|
||||||
|
|
||||||
|
int b_size = c_prev->dims()[0]; // batch size
|
||||||
|
int s_dim = c_prev->dims()[1]; // state dim
|
||||||
|
ctx.Output<framework::LoDTensor>("C")->Resize({b_size, s_dim});
|
||||||
|
ctx.Output<framework::LoDTensor>("H")->Resize({b_size, s_dim});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename AttrType>
|
||||||
|
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
LstmUnitOpMaker(framework::OpProto *proto,
|
||||||
|
framework::OpAttrChecker *op_checker)
|
||||||
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||||
|
AddInput("X", "FC input before the non-linear activation.");
|
||||||
|
AddInput(
|
||||||
|
"C_prev",
|
||||||
|
"The cell state tensor of last time-step in the Lstm Unit operator.");
|
||||||
|
AddOutput("C", "The cell tensor of Lstm Unit operator.");
|
||||||
|
AddOutput("H", "The hidden state tensor of Lstm Unit operator.");
|
||||||
|
|
||||||
|
AddComment(R"DOC(Lstm-Unit Operator
|
||||||
|
|
||||||
|
Equation:
|
||||||
|
i, f, o, j = split(X)
|
||||||
|
C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j)
|
||||||
|
H = C * sigm(o)
|
||||||
|
|
||||||
|
)DOC");
|
||||||
|
AddAttr<AttrType>("forget_bias", "The forget bias of Lstm Unit.")
|
||||||
|
.SetDefault(0.0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class LstmUnitGradOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")),
|
||||||
|
"Input(C@GRAD) should not be null");
|
||||||
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")),
|
||||||
|
"Input(H@GRAD) should not be null");
|
||||||
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
|
||||||
|
->Resize(ctx.Input<Tensor>("X")->dims());
|
||||||
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("C_prev"))
|
||||||
|
->Resize(ctx.Input<Tensor>("C_prev")->dims());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker<float>,
|
||||||
|
lstm_unit_grad, ops::LstmUnitGradOp);
|
||||||
|
REGISTER_OP_CPU_KERNEL(lstm_unit,
|
||||||
|
ops::LstmUnitKernel<paddle::platform::CPUPlace, float>);
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
lstm_unit_grad, ops::LstmUnitGradKernel<paddle::platform::CPUPlace, float>);
|
@ -0,0 +1,173 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/framework/op_registry.h"
|
||||||
|
#include "paddle/operators/cross_entropy_op.h"
|
||||||
|
#include "paddle/platform/assert.h"
|
||||||
|
#include "paddle/platform/hostdevice.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||||
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||||
|
i += blockDim.x * gridDim.x)
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
__device__ Dtype cuda_sigmoid(const Dtype x) {
|
||||||
|
return Dtype(1) / (Dtype(1) + exp(-x));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
__device__ Dtype cuda_tanh(const Dtype x) {
|
||||||
|
return Dtype(1 - exp(-2. * x)) / (Dtype(1) + exp(-2. * x));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void LSTMUnitKernel(const int nthreads, const int dim,
|
||||||
|
const T* C_prev, const T* X, T* C, T* H,
|
||||||
|
const T forget_bias) {
|
||||||
|
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||||
|
const int n = index / dim;
|
||||||
|
const int d = index % dim;
|
||||||
|
|
||||||
|
const T* X_offset = X + 4 * dim * n;
|
||||||
|
const T i = cuda_sigmoid(X_offset[d]);
|
||||||
|
const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
|
||||||
|
const T o = cuda_sigmoid(X_offset[2 * dim + d]);
|
||||||
|
const T g = cuda_tanh(X_offset[3 * dim + d]);
|
||||||
|
const T c_prev = C_prev[index];
|
||||||
|
const T c = f * c_prev + i * g;
|
||||||
|
C[index] = c;
|
||||||
|
const T tanh_c = cuda_tanh(c);
|
||||||
|
H[index] = o * tanh_c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
|
||||||
|
const T* C_prev, const T* X, const T* C,
|
||||||
|
const T* H, const T* C_diff,
|
||||||
|
const T* H_diff, T* C_prev_diff,
|
||||||
|
T* X_diff, const T forget_bias) {
|
||||||
|
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||||
|
const int n = index / dim;
|
||||||
|
const int d = index % dim;
|
||||||
|
const T* X_offset = X + 4 * dim * n;
|
||||||
|
T* c_prev_diff = C_prev_diff + index;
|
||||||
|
T* X_diff_offset = X_diff + 4 * dim * n;
|
||||||
|
T* i_diff = X_diff_offset + d;
|
||||||
|
T* f_diff = X_diff_offset + 1 * dim + d;
|
||||||
|
T* o_diff = X_diff_offset + 2 * dim + d;
|
||||||
|
T* g_diff = X_diff_offset + 3 * dim + d;
|
||||||
|
|
||||||
|
const T i = cuda_sigmoid(X_offset[d]);
|
||||||
|
const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
|
||||||
|
const T o = cuda_sigmoid(X_offset[2 * dim + d]);
|
||||||
|
const T g = cuda_tanh(X_offset[3 * dim + d]);
|
||||||
|
const T c_prev = C_prev[index];
|
||||||
|
const T c = C[index];
|
||||||
|
const T tanh_c = cuda_tanh(c);
|
||||||
|
const T c_term_diff =
|
||||||
|
C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c);
|
||||||
|
*c_prev_diff = c_term_diff * f;
|
||||||
|
*i_diff = c_term_diff * g * i * (1 - i);
|
||||||
|
*f_diff = c_term_diff * c_prev * f * (1 - f);
|
||||||
|
*o_diff = H_diff[index] * tanh_c * o * (1 - o);
|
||||||
|
*g_diff = c_term_diff * i * (1 - g * g);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename AttrType = T>
|
||||||
|
class LstmUnitOpCUDAKernel : public framework::OpKernel {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||||
|
"It must use GPUPlace.");
|
||||||
|
|
||||||
|
auto* x_tensor = ctx.Input<framework::Tensor>("X");
|
||||||
|
auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
|
||||||
|
auto* c_tensor = ctx.Output<framework::Tensor>("C");
|
||||||
|
auto* h_tensor = ctx.Output<framework::Tensor>("H");
|
||||||
|
|
||||||
|
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
||||||
|
|
||||||
|
int b_size = c_tensor->dims()[0];
|
||||||
|
int D = c_tensor->dims()[1];
|
||||||
|
|
||||||
|
const T* X = x_tensor->data<T>();
|
||||||
|
const T* C_prev = c_prev_tensor->data<T>();
|
||||||
|
|
||||||
|
T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
|
||||||
|
T* H = h_tensor->mutable_data<T>(ctx.GetPlace());
|
||||||
|
|
||||||
|
int block = 512;
|
||||||
|
int n = b_size * D;
|
||||||
|
int grid = (n + block - 1) / block;
|
||||||
|
|
||||||
|
LSTMUnitKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, forget_bias);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename AttrType = T>
|
||||||
|
class LstmUnitGradOpCUDAKernel : public framework::OpKernel {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||||
|
"It must use GPUPlace.");
|
||||||
|
|
||||||
|
auto x_tensor = ctx.Input<Tensor>("X");
|
||||||
|
auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
|
||||||
|
auto c_tensor = ctx.Input<Tensor>("C");
|
||||||
|
auto h_tensor = ctx.Input<Tensor>("H");
|
||||||
|
|
||||||
|
auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
|
||||||
|
auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));
|
||||||
|
|
||||||
|
auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||||
|
auto c_prev_diff_tensor =
|
||||||
|
ctx.Output<Tensor>(framework::GradVarName("C_prev"));
|
||||||
|
|
||||||
|
auto* X = x_tensor->data<T>();
|
||||||
|
auto* C_prev = c_prev_tensor->data<T>();
|
||||||
|
auto* C = c_tensor->data<T>();
|
||||||
|
auto* H = h_tensor->data<T>();
|
||||||
|
|
||||||
|
auto* H_diff = hdiff_tensor->data<T>();
|
||||||
|
auto* C_diff = cdiff_tensor->data<T>();
|
||||||
|
|
||||||
|
auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
|
||||||
|
auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());
|
||||||
|
|
||||||
|
int N = c_tensor->dims()[0];
|
||||||
|
int D = c_tensor->dims()[1];
|
||||||
|
|
||||||
|
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
||||||
|
|
||||||
|
int block = 512;
|
||||||
|
int n = N * D;
|
||||||
|
int grid = (n + block - 1) / block;
|
||||||
|
|
||||||
|
LSTMUnitGradientKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, C_diff,
|
||||||
|
H_diff, C_prev_diff, X_diff,
|
||||||
|
forget_bias);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel<float>);
|
||||||
|
REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel<float>);
|
@ -0,0 +1,148 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "glog/logging.h"
|
||||||
|
#include "paddle/framework/op_registry.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using framework::LoDTensor;
|
||||||
|
using framework::Tensor;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T sigmoid(T x) {
|
||||||
|
return 1. / (1. + exp(-x));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T tanh(T x) {
|
||||||
|
return 2. * sigmoid(2. * x) - 1.;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Place, typename T, typename AttrType = T>
|
||||||
|
class LstmUnitKernel : public framework::OpKernel {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
|
||||||
|
"It must use CPUPlace.");
|
||||||
|
|
||||||
|
auto* x_tensor = ctx.Input<framework::Tensor>("X");
|
||||||
|
auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
|
||||||
|
auto* c_tensor = ctx.Output<framework::Tensor>("C");
|
||||||
|
auto* h_tensor = ctx.Output<framework::Tensor>("H");
|
||||||
|
|
||||||
|
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
||||||
|
|
||||||
|
int b_size = c_tensor->dims()[0];
|
||||||
|
int D = c_tensor->dims()[1];
|
||||||
|
|
||||||
|
T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
|
||||||
|
T* H = h_tensor->mutable_data<T>(ctx.GetPlace());
|
||||||
|
|
||||||
|
const T* X = x_tensor->data<T>();
|
||||||
|
const T* C_prev = c_prev_tensor->data<T>();
|
||||||
|
|
||||||
|
for (int n = 0; n < b_size; ++n) {
|
||||||
|
for (int d = 0; d < D; ++d) {
|
||||||
|
const T i = sigmoid(X[d]);
|
||||||
|
const T f = sigmoid(X[1 * D + d] + forget_bias);
|
||||||
|
const T o = sigmoid(X[2 * D + d]);
|
||||||
|
const T g = tanh(X[3 * D + d]);
|
||||||
|
const T c_prev = C_prev[d];
|
||||||
|
const T c = f * c_prev + i * g;
|
||||||
|
C[d] = c;
|
||||||
|
const T tanh_c = tanh(c);
|
||||||
|
H[d] = o * tanh_c;
|
||||||
|
}
|
||||||
|
C_prev += D;
|
||||||
|
X += 4 * D;
|
||||||
|
C += D;
|
||||||
|
H += D;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Place, typename T, typename AttrType = T>
|
||||||
|
class LstmUnitGradKernel : public framework::OpKernel {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
|
||||||
|
"It must use CPUPlace.");
|
||||||
|
|
||||||
|
auto x_tensor = ctx.Input<Tensor>("X");
|
||||||
|
auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
|
||||||
|
auto c_tensor = ctx.Input<Tensor>("C");
|
||||||
|
auto h_tensor = ctx.Input<Tensor>("H");
|
||||||
|
|
||||||
|
auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
|
||||||
|
auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));
|
||||||
|
|
||||||
|
auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||||
|
auto c_prev_diff_tensor =
|
||||||
|
ctx.Output<Tensor>(framework::GradVarName("C_prev"));
|
||||||
|
|
||||||
|
auto* X = x_tensor->data<T>();
|
||||||
|
auto* C_prev = c_prev_tensor->data<T>();
|
||||||
|
auto* C = c_tensor->data<T>();
|
||||||
|
auto* H = h_tensor->data<T>();
|
||||||
|
|
||||||
|
auto* H_diff = hdiff_tensor->data<T>();
|
||||||
|
auto* C_diff = cdiff_tensor->data<T>();
|
||||||
|
|
||||||
|
auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
|
||||||
|
auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());
|
||||||
|
|
||||||
|
int N = c_tensor->dims()[0];
|
||||||
|
int D = c_tensor->dims()[1];
|
||||||
|
|
||||||
|
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
||||||
|
|
||||||
|
for (int n = 0; n < N; ++n) {
|
||||||
|
for (int d = 0; d < D; ++d) {
|
||||||
|
T* c_prev_diff = C_prev_diff + d;
|
||||||
|
T* i_diff = X_diff + d;
|
||||||
|
T* f_diff = X_diff + 1 * D + d;
|
||||||
|
T* o_diff = X_diff + 2 * D + d;
|
||||||
|
T* g_diff = X_diff + 3 * D + d;
|
||||||
|
|
||||||
|
const T i = sigmoid(X[d]);
|
||||||
|
const T f = sigmoid(X[1 * D + d] + forget_bias);
|
||||||
|
const T o = sigmoid(X[2 * D + d]);
|
||||||
|
const T g = tanh(X[3 * D + d]);
|
||||||
|
const T c_prev = C_prev[d];
|
||||||
|
const T c = C[d];
|
||||||
|
const T tanh_c = tanh(c);
|
||||||
|
const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c);
|
||||||
|
*c_prev_diff = c_term_diff * f;
|
||||||
|
*i_diff = c_term_diff * g * i * (1 - i);
|
||||||
|
*f_diff = c_term_diff * c_prev * f * (1 - f);
|
||||||
|
*o_diff = H_diff[d] * tanh_c * o * (1 - o);
|
||||||
|
*g_diff = c_term_diff * i * (1 - g * g);
|
||||||
|
}
|
||||||
|
C_prev += D;
|
||||||
|
X += 4 * D;
|
||||||
|
C += D;
|
||||||
|
H += D;
|
||||||
|
C_diff += D;
|
||||||
|
H_diff += D;
|
||||||
|
X_diff += 4 * D;
|
||||||
|
C_prev_diff += D;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,63 @@
|
|||||||
|
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from paddle.trainer_config_helpers import *
|
||||||
|
|
||||||
|
################################### Data Configuration ###################################
|
||||||
|
TrainData(ProtoData(files = "trainer/tests/mnist.list"))
|
||||||
|
################################### Algorithm Configuration ###################################
|
||||||
|
settings(batch_size = 1000,
|
||||||
|
learning_method = MomentumOptimizer(momentum=0.5, sparse=False))
|
||||||
|
################################### Network Configuration ###################################
|
||||||
|
data = data_layer(name ="input", size=784)
|
||||||
|
|
||||||
|
tmp = img_conv_layer(input=data,
|
||||||
|
num_channels=1,
|
||||||
|
filter_size=3,
|
||||||
|
num_filters=32,
|
||||||
|
padding=1,
|
||||||
|
shared_biases=True,
|
||||||
|
act=ReluActivation())
|
||||||
|
|
||||||
|
tmp = img_pool_layer(input=tmp,
|
||||||
|
pool_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
pool_type=AvgPooling())
|
||||||
|
|
||||||
|
tmp = img_conv_layer(input=tmp,
|
||||||
|
filter_size=3,
|
||||||
|
num_filters=64,
|
||||||
|
padding=1,
|
||||||
|
shared_biases=True,
|
||||||
|
act=ReluActivation())
|
||||||
|
|
||||||
|
tmp = img_pool_layer(input=tmp,
|
||||||
|
pool_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
pool_type=MaxPooling())
|
||||||
|
|
||||||
|
tmp = fc_layer(input=tmp, size=64,
|
||||||
|
bias_attr=True,
|
||||||
|
act=ReluActivation())
|
||||||
|
|
||||||
|
output = fc_layer(input=tmp, size=10,
|
||||||
|
bias_attr=True,
|
||||||
|
act=SoftmaxActivation())
|
||||||
|
|
||||||
|
lbl = data_layer(name ="label", size=10)
|
||||||
|
|
||||||
|
cost = classification_cost(input=output, label=lbl)
|
||||||
|
outputs(cost)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue