update-doc-pybind (commit 1cd2014007)
@@ -0,0 +1,51 @@
set -e

unset OMP_NUM_THREADS MKL_NUM_THREADS
export OMP_DYNAMIC="FALSE"
export KMP_AFFINITY="granularity=fine,compact,0,0"

function train() {
  topology=$1
  bs=$2
  use_mkldnn=$3
  if [ "$3" == "True" ]; then
    thread=1
    log="logs/${topology}-mkldnn-${bs}.log"
  elif [ "$3" == "False" ]; then
    thread=`nproc`
    log="logs/${topology}-${thread}mklml-${bs}.log"
  else
    echo "Wrong input $3, use True or False."
    exit 1
  fi
  args="batch_size=${bs}"
  config="${topology}.py"
  paddle train --job=time \
    --config=$config \
    --use_mkldnn=$use_mkldnn \
    --use_gpu=False \
    --trainer_count=$thread \
    --log_period=10 \
    --test_period=100 \
    --config_args=$args \
    2>&1 | tee ${log}
}

# The benchmark data provider only needs a placeholder file list.
if [ ! -f "train.list" ]; then
  echo " " > train.list
fi
if [ ! -d "logs" ]; then
  mkdir logs
fi

#========== mkldnn ==========#
# vgg
train vgg 64 True
train vgg 128 True
train vgg 256 True

#========== mklml ==========#
train vgg 64 False
train vgg 128 False
train vgg 256 False
@@ -0,0 +1,103 @@
#!/usr/bin/env python
from paddle.trainer_config_helpers import *

height = 224
width = 224
num_class = 1000
batch_size = get_config_arg('batch_size', int, 64)
layer_num = get_config_arg('layer_num', int, 19)

args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
define_py_data_sources2(
    "train.list", None, module="provider", obj="process", args=args)

settings(
    batch_size=batch_size,
    learning_rate=0.01 / batch_size,
    learning_method=MomentumOptimizer(0.9),
    regularization=L2Regularization(0.0005 * batch_size))

img = data_layer(name='image', size=height * width * 3)


def vgg_network(vgg_num=3):
    tmp = img_conv_group(
        input=img,
        num_channels=3,
        conv_padding=1,
        conv_num_filter=[64, 64],
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_size=2,
        pool_stride=2,
        pool_type=MaxPooling())

    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=[128, 128],
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)

    channels = []
    for i in range(vgg_num):
        channels.append(256)
    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=channels,
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)
    channels = []
    for i in range(vgg_num):
        channels.append(512)
    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=channels,
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)
    tmp = img_conv_group(
        input=tmp,
        conv_num_filter=channels,
        conv_padding=1,
        conv_filter_size=3,
        conv_act=ReluActivation(),
        pool_stride=2,
        pool_type=MaxPooling(),
        pool_size=2)

    tmp = fc_layer(
        input=tmp,
        size=4096,
        act=ReluActivation(),
        layer_attr=ExtraAttr(drop_rate=0.5))

    tmp = fc_layer(
        input=tmp,
        size=4096,
        act=ReluActivation(),
        layer_attr=ExtraAttr(drop_rate=0.5))

    return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation())


if layer_num == 16:
    vgg = vgg_network(3)
elif layer_num == 19:
    vgg = vgg_network(4)
else:
    print("Wrong layer number.")

lab = data_layer('label', num_class)
loss = cross_entropy(input=vgg, label=lab)
outputs(loss)
@@ -0,0 +1,71 @@
# Cluster bootstrapping tool survey
## Abstract
In order to bring up a cluster from bare metal machines to a fully functional Kubernetes cluster for PaddlePaddle to run on, we need some tooling. Here we compare [Sextant](https://github.com/k8sp/sextant) and [Tectonic installer](https://github.com/coreos/tectonic-installer).

## Basic assumptions
Here are some basic assumptions before we move on to the details:
1. You are an administrator of a bare metal machine cluster, which means:
   * you have full control of each of the machines;
   * you have full control of the network the machines are connected to.
2. Machines can be booted from the network with PXE or iPXE.
3. You understand the [general procedure to bring up a cluster](#appendix-general-procedure-to-bring-up-a-cluster).

If your cluster can check off all of the items above, keep reading.

## Comparing Sextant and Tectonic installer
### Sextant
Sextant is an end-to-end solution for bringing a bare metal cluster up to a fully functional Kubernetes cluster; it integrates DHCP, name service, PXE, a cloud-config service, and a Docker registry.

#### Pros
1. End-to-end: basically all the admin needs to do is configure cluster.yaml and power on the cluster.
2. Offline cluster configuration: working with Sextant has two phases, config time and deploy time. At config time the admin's machine needs internet connectivity to download images and other dependencies. At deploy time it is completely fine to be offline, since all dependencies were fetched at config time.
3. Docker registry integrated.
4. GPU machines are taken care of.

#### Cons
1. The k8s API server is not deployed with high availability by default.
2. No grouping support.
3. No API interface; it is a one-off service.


### Tectonic installer
First of all, Tectonic is not free. It requires a coreos.com account as an installation step, and a free user can only create fewer than 10 nodes.

Tectonic is a suite of software that wraps around k8s and provides more dev-ops utilities. Tectonic installer, as the name suggests, installs Tectonic onto a bare metal cluster, so it is not exactly an equivalent of Sextant. For the "booting a cluster" part it mostly relies on [Matchbox](https://github.com/coreos/matchbox), which is a general cluster bootstrapper.

Matchbox's approach is similar to Sextant's.

#### Pros
1. Supports grouping machines.
2. Supports running the provisioning service in rkt (not a big deal though).
3. Supports an HTTP/gRPC API interface.
4. Supports multiple templates.

#### Cons
1. Not an end-to-end solution for bringing up a cluster; it needs a lot of extra work and other software.
2. [Does not yet fully support](https://github.com/coreos/matchbox/issues/550) CentOS deployment.

## Conclusion
Sextant is the better overall solution for deploying Paddle Cloud to a bare metal cluster. It would be even better if Sextant could also 1) deploy the k8s API server with high availability by default, and 2) not be designed as a one-off service.


## Appendix: General procedure to bring up a cluster
It is physically impractical for a cluster admin to manually install the OS and applications on cluster nodes one by one. Here is what an admin typically does in the cloud industry:
1. Set up a bootstrap machine with a static IP in the cluster, which runs the following services:
   * DHCP: assigns an IP address to each of the other nodes.
   * Name service: maps node names to IPs.
   * PXE-related services: booting-related info is delivered to newly booted machines as their IPs are assigned by the DHCP service; the PXE service then provides further booting and installation info and images over TFTP and HTTP.
   * Cluster config service: provides cluster nodes with OS config over HTTP.
   * Optional Docker registry: a built-in Docker registry makes the whole cluster independent of internet connectivity and speeds up software distribution.
2. A new node powers on, and it will:
   * broadcast a request for an IP address;
   * receive an IP address from the DHCP server along with the PXE booting info;
   * request the config files named in the booting info via the TFTP service; in most cases the config file points to an HTTP service for the boot image;
   * since PXE is configured with an initrd, use the cloud config service to do further installation, such as CoreOS or k8s installation;
   * then restart.

A minimal sketch of the DHCP/TFTP part of such a bootstrap machine is shown below.
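The sketch is illustrative only and is not part of either tool: it shows how the DHCP, TFTP, and PXE pieces described above could be wired together with dnsmasq on the bootstrap machine. The subnet, file paths, and boot loader name are hypothetical placeholders.

```bash
#!/bin/bash
# Minimal PXE bootstrap sketch (hypothetical addresses and paths).
# One dnsmasq process serves DHCP leases and the PXE boot loader over TFTP.
set -e

TFTP_ROOT=/var/lib/tftpboot            # assumed to hold pxelinux.0 plus kernel/initrd
DHCP_RANGE=192.168.100.50,192.168.100.150,12h

mkdir -p "${TFTP_ROOT}"

# --dhcp-boot names the PXE loader that newly powered-on nodes fetch and run.
dnsmasq --no-daemon \
  --dhcp-range="${DHCP_RANGE}" \
  --enable-tftp \
  --tftp-root="${TFTP_ROOT}" \
  --dhcp-boot=pxelinux.0
```

After a node runs the PXE loader, its boot config would normally point at an HTTP endpoint (Sextant's cloud-config service or Matchbox) for the OS image and cloud-config, as described in step 2 above.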

For further understanding, the following two links from Matchbox are good reading:
* [Machine lifecycle](https://github.com/coreos/matchbox/blob/master/Documentation/machine-lifecycle.md)
* [PXE booting](https://github.com/coreos/matchbox/blob/master/Documentation/network-booting.md)
@@ -0,0 +1,103 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/lstm_unit_op.h"

namespace paddle {
namespace operators {

class LstmUnitOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                            "Input(X) of LSTM should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"),
                            "Input(C_prev) of LSTM should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"),
                            "Output(C) of LSTM should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"),
                            "Output(H) of LSTM should not be null.");

    auto *x = ctx.Input<framework::Tensor>("X");
    auto *c_prev = ctx.Input<framework::Tensor>("C_prev");

    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
    PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0],
                   "Batch size of inputs and states must be equal");
    PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4,
                   "Dimension of FC should equal prev state * 4");

    int b_size = c_prev->dims()[0];  // batch size
    int s_dim = c_prev->dims()[1];   // state dim
    ctx.Output<framework::LoDTensor>("C")->Resize({b_size, s_dim});
    ctx.Output<framework::LoDTensor>("H")->Resize({b_size, s_dim});
  }
};

template <typename AttrType>
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LstmUnitOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "FC input before the non-linear activation.");
    AddInput(
        "C_prev",
        "The cell state tensor of the last time-step in the Lstm Unit operator.");
    AddOutput("C", "The cell tensor of the Lstm Unit operator.");
    AddOutput("H", "The hidden state tensor of the Lstm Unit operator.");

    AddComment(R"DOC(Lstm-Unit Operator

Equation:
  i, f, o, j = split(X)
  C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j)
  H = tanh(C) * sigm(o)

)DOC");
    AddAttr<AttrType>("forget_bias", "The forget bias of Lstm Unit.")
        .SetDefault(0.0);
  }
};

class LstmUnitGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")),
                            "Input(C@GRAD) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")),
                            "Input(H@GRAD) should not be null");
    ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
        ->Resize(ctx.Input<Tensor>("X")->dims());
    ctx.Output<framework::LoDTensor>(framework::GradVarName("C_prev"))
        ->Resize(ctx.Input<Tensor>("C_prev")->dims());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker<float>,
            lstm_unit_grad, ops::LstmUnitGradOp);
REGISTER_OP_CPU_KERNEL(lstm_unit,
                       ops::LstmUnitKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    lstm_unit_grad, ops::LstmUnitGradKernel<paddle::platform::CPUPlace, float>);
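Taken together, the DOC comment above and the CPU and GPU kernels in this change compute the following element-wise relations. This is only a restatement of the kernel code for readability; $x_i, x_f, x_o, x_g$ denote the four size-$D$ slices of a row of $X$, $b$ is the `forget_bias` attribute, and $\sigma$ is the logistic sigmoid:

$$
\begin{aligned}
i &= \sigma(x_i), \quad f = \sigma(x_f + b), \quad o = \sigma(x_o), \quad g = \tanh(x_g),\\
c &= f \cdot c_{prev} + i \cdot g, \qquad h = o \cdot \tanh(c).
\end{aligned}
$$

The gradient kernels form $\delta = \partial L/\partial c + \partial L/\partial h \cdot o\,(1 - \tanh^2 c)$ and then set $\partial L/\partial c_{prev} = \delta f$, $\partial L/\partial x_i = \delta\, g\, i(1-i)$, $\partial L/\partial x_f = \delta\, c_{prev}\, f(1-f)$, $\partial L/\partial x_o = \partial L/\partial h \cdot \tanh(c)\, o(1-o)$, and $\partial L/\partial x_g = \delta\, i\,(1-g^2)$, which matches the element-wise formulas in lstm_unit_op.h and the CUDA kernel below.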
@@ -0,0 +1,173 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {

#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

template <typename Dtype>
__device__ Dtype cuda_sigmoid(const Dtype x) {
  return Dtype(1) / (Dtype(1) + exp(-x));
}

template <typename Dtype>
__device__ Dtype cuda_tanh(const Dtype x) {
  return Dtype(1 - exp(-2. * x)) / (Dtype(1) + exp(-2. * x));
}

template <typename T>
__global__ void LSTMUnitKernel(const int nthreads, const int dim,
                               const T* C_prev, const T* X, T* C, T* H,
                               const T forget_bias) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int n = index / dim;
    const int d = index % dim;

    const T* X_offset = X + 4 * dim * n;
    const T i = cuda_sigmoid(X_offset[d]);
    const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
    const T o = cuda_sigmoid(X_offset[2 * dim + d]);
    const T g = cuda_tanh(X_offset[3 * dim + d]);
    const T c_prev = C_prev[index];
    const T c = f * c_prev + i * g;
    C[index] = c;
    const T tanh_c = cuda_tanh(c);
    H[index] = o * tanh_c;
  }
}

template <typename T>
__global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
                                       const T* C_prev, const T* X, const T* C,
                                       const T* H, const T* C_diff,
                                       const T* H_diff, T* C_prev_diff,
                                       T* X_diff, const T forget_bias) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int n = index / dim;
    const int d = index % dim;
    const T* X_offset = X + 4 * dim * n;
    T* c_prev_diff = C_prev_diff + index;
    T* X_diff_offset = X_diff + 4 * dim * n;
    T* i_diff = X_diff_offset + d;
    T* f_diff = X_diff_offset + 1 * dim + d;
    T* o_diff = X_diff_offset + 2 * dim + d;
    T* g_diff = X_diff_offset + 3 * dim + d;

    const T i = cuda_sigmoid(X_offset[d]);
    const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
    const T o = cuda_sigmoid(X_offset[2 * dim + d]);
    const T g = cuda_tanh(X_offset[3 * dim + d]);
    const T c_prev = C_prev[index];
    const T c = C[index];
    const T tanh_c = cuda_tanh(c);
    const T c_term_diff =
        C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c);
    *c_prev_diff = c_term_diff * f;
    *i_diff = c_term_diff * g * i * (1 - i);
    *f_diff = c_term_diff * c_prev * f * (1 - f);
    *o_diff = H_diff[index] * tanh_c * o * (1 - o);
    *g_diff = c_term_diff * i * (1 - g * g);
  }
}

template <typename T, typename AttrType = T>
class LstmUnitOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");

    auto* x_tensor = ctx.Input<framework::Tensor>("X");
    auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
    auto* c_tensor = ctx.Output<framework::Tensor>("C");
    auto* h_tensor = ctx.Output<framework::Tensor>("H");

    auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));

    int b_size = c_tensor->dims()[0];
    int D = c_tensor->dims()[1];

    const T* X = x_tensor->data<T>();
    const T* C_prev = c_prev_tensor->data<T>();

    T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
    T* H = h_tensor->mutable_data<T>(ctx.GetPlace());

    int block = 512;
    int n = b_size * D;
    int grid = (n + block - 1) / block;

    LSTMUnitKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, forget_bias);
  }
};

template <typename T, typename AttrType = T>
class LstmUnitGradOpCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use GPUPlace.");

    auto x_tensor = ctx.Input<Tensor>("X");
    auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
    auto c_tensor = ctx.Input<Tensor>("C");
    auto h_tensor = ctx.Input<Tensor>("H");

    auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
    auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));

    auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto c_prev_diff_tensor =
        ctx.Output<Tensor>(framework::GradVarName("C_prev"));

    auto* X = x_tensor->data<T>();
    auto* C_prev = c_prev_tensor->data<T>();
    auto* C = c_tensor->data<T>();
    auto* H = h_tensor->data<T>();

    auto* H_diff = hdiff_tensor->data<T>();
    auto* C_diff = cdiff_tensor->data<T>();

    auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
    auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());

    int N = c_tensor->dims()[0];
    int D = c_tensor->dims()[1];

    auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));

    int block = 512;
    int n = N * D;
    int grid = (n + block - 1) / block;

    LSTMUnitGradientKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, C_diff,
                                               H_diff, C_prev_diff, X_diff,
                                               forget_bias);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel<float>);
@@ -0,0 +1,148 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::LoDTensor;
using framework::Tensor;

template <typename T>
inline T sigmoid(T x) {
  return 1. / (1. + exp(-x));
}

template <typename T>
inline T tanh(T x) {
  return 2. * sigmoid(2. * x) - 1.;
}

template <typename Place, typename T, typename AttrType = T>
class LstmUnitKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    auto* x_tensor = ctx.Input<framework::Tensor>("X");
    auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
    auto* c_tensor = ctx.Output<framework::Tensor>("C");
    auto* h_tensor = ctx.Output<framework::Tensor>("H");

    auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));

    int b_size = c_tensor->dims()[0];
    int D = c_tensor->dims()[1];

    T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
    T* H = h_tensor->mutable_data<T>(ctx.GetPlace());

    const T* X = x_tensor->data<T>();
    const T* C_prev = c_prev_tensor->data<T>();

    for (int n = 0; n < b_size; ++n) {
      for (int d = 0; d < D; ++d) {
        const T i = sigmoid(X[d]);
        const T f = sigmoid(X[1 * D + d] + forget_bias);
        const T o = sigmoid(X[2 * D + d]);
        const T g = tanh(X[3 * D + d]);
        const T c_prev = C_prev[d];
        const T c = f * c_prev + i * g;
        C[d] = c;
        const T tanh_c = tanh(c);
        H[d] = o * tanh_c;
      }
      // Advance the per-sample pointers to the next row of the batch.
      C_prev += D;
      X += 4 * D;
      C += D;
      H += D;
    }
  }
};

template <typename Place, typename T, typename AttrType = T>
class LstmUnitGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    auto x_tensor = ctx.Input<Tensor>("X");
    auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
    auto c_tensor = ctx.Input<Tensor>("C");
    auto h_tensor = ctx.Input<Tensor>("H");

    auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
    auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));

    auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto c_prev_diff_tensor =
        ctx.Output<Tensor>(framework::GradVarName("C_prev"));

    auto* X = x_tensor->data<T>();
    auto* C_prev = c_prev_tensor->data<T>();
    auto* C = c_tensor->data<T>();
    auto* H = h_tensor->data<T>();

    auto* H_diff = hdiff_tensor->data<T>();
    auto* C_diff = cdiff_tensor->data<T>();

    auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
    auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());

    int N = c_tensor->dims()[0];
    int D = c_tensor->dims()[1];

    auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));

    for (int n = 0; n < N; ++n) {
      for (int d = 0; d < D; ++d) {
        T* c_prev_diff = C_prev_diff + d;
        T* i_diff = X_diff + d;
        T* f_diff = X_diff + 1 * D + d;
        T* o_diff = X_diff + 2 * D + d;
        T* g_diff = X_diff + 3 * D + d;

        const T i = sigmoid(X[d]);
        const T f = sigmoid(X[1 * D + d] + forget_bias);
        const T o = sigmoid(X[2 * D + d]);
        const T g = tanh(X[3 * D + d]);
        const T c_prev = C_prev[d];
        const T c = C[d];
        const T tanh_c = tanh(c);
        const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c);
        *c_prev_diff = c_term_diff * f;
        *i_diff = c_term_diff * g * i * (1 - i);
        *f_diff = c_term_diff * c_prev * f * (1 - f);
        *o_diff = H_diff[d] * tanh_c * o * (1 - o);
        *g_diff = c_term_diff * i * (1 - g * g);
      }
      // Advance the per-sample pointers to the next row of the batch.
      C_prev += D;
      X += 4 * D;
      C += D;
      H += D;
      C_diff += D;
      H_diff += D;
      X_diff += 4 * D;
      C_prev_diff += D;
    }
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,113 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/multiplex_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

class MultiplexOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
                   "Input(X) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                            "Output(Out) shouldn't be null.");
    auto ins = ctx.MultiInput<Tensor>("X");
    auto *out = ctx.Output<LoDTensor>("Out");
    auto num_ins = ins.size();
    PADDLE_ENFORCE(num_ins > 2,
                   "multiplex operator should have more than 2 inputs.");
    PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
                      "The first input must be an index vector.");
    auto in_dim = ins[1]->dims();

    for (size_t i = 2; i < num_ins; i++) {
      auto dim = ins[i]->dims();
      PADDLE_ENFORCE(
          in_dim == dim,
          "All the input tensors except the first one must have the same size");
    }
    out->Resize(in_dim);
  }
};

class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MultiplexOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input tensors of multiplex operator.").AsDuplicable();
    AddOutput("Out", "The output tensor of multiplex operator.");
    AddComment(R"DOC(Multiplex operator

Multiplex multiple tensors according to the index provided by the first
input tensor.

ins[0]: the index tensor.
ins[1:N]: the candidate output tensors.
For each index i from 0 to batchSize - 1, the output is the i-th row of the
(index[i] + 1)-th tensor.

For the i-th row of the output tensor:

y[i][j] = x_{k}[i][j], j = 0, 1, ..., (x_{1}.width - 1)

where y is the output tensor, `x_{k}` is the k-th input tensor,
and `k = x_{0}[i] + 1`.

)DOC");
  }
};

class MultiplexGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
                   "Input(X) should not be null");
    PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
                   "Output(X@Grad) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) shouldn't be null.");
    auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
    auto ins = ctx.MultiInput<Tensor>("X");
    // Don't compute a gradient for the index input (ins[0]).
    for (size_t i = 1; i < ins.size(); i++) {
      if (d_ins[i]) {
        d_ins[i]->Resize(ins[i]->dims());
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;

REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
            ops::MultiplexGradOp);
REGISTER_OP_CPU_KERNEL(
    multiplex, ops::MultiplexCPUKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    multiplex_grad,
    ops::MultiplexGradCPUKernel<paddle::platform::CPUPlace, float>);
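To make the indexing rule above concrete, here is a small made-up example (the values are chosen for illustration only, not taken from the code): with an index tensor $x_{0} = (1, 0, 2)$ and candidate tensors $x_{1}, x_{2}, x_{3}$, the operator picks $k = x_{0}[i] + 1$ for each row, so

$$
y[0][:] = x_{2}[0][:], \qquad y[1][:] = x_{1}[1][:], \qquad y[2][:] = x_{3}[2][:].
$$

The gradient op routes each row of Out@GRAD back to the input that was selected for that row and leaves the other inputs' gradients at zero, which is what the GPU kernels below implement.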
@@ -0,0 +1,95 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/op_registry.h"
#include "paddle/operators/multiplex_op.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class MultiplexGPUKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto* out = ctx.Output<framework::LoDTensor>("Out");

    out->mutable_data<T>(ctx.GetPlace());

    auto rows = ins[1]->dims()[0];
    auto cols = ins[1]->dims()[1];
    // copy index to cpu
    framework::Tensor index_t_cpu;
    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
    auto* index = index_t_cpu.data<T>();
    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
      int k = (int)index[i] + 1;
      PADDLE_ENFORCE_LT(k, ins.size(),
                        "index exceeds the number of candidate tensors.");
      memory::Copy(place, out->data<T>() + i * cols, place,
                   ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
    }
  }
};

template <typename Place, typename T>
class MultiplexGradGPUKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto d_ins =
        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
    for (size_t i = 1; i < d_ins.size(); i++) {
      if (d_ins[i]) {
        d_ins[i]->mutable_data<T>(ctx.GetPlace());
        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
        t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
      }
    }

    auto rows = ins[1]->dims()[0];
    auto cols = ins[1]->dims()[1];
    // copy index to cpu
    framework::Tensor index_t_cpu;
    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
    auto* index = index_t_cpu.data<T>();

    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
      int k = (int)index[i] + 1;
      if (d_ins[k]) {
        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                     d_out->data<T>() + i * cols, cols * sizeof(T), stream);
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(
    multiplex, ops::MultiplexGPUKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    multiplex_grad,
    ops::MultiplexGradGPUKernel<paddle::platform::GPUPlace, float>);