Merge pull request #10 from PaddlePaddle/develop

merge to local
lujun 6 years ago committed by GitHub
commit a3f7ebd6d3

@ -3,8 +3,8 @@
English | [简体中文](./README_cn.md)
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/index.html)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
@ -18,7 +18,7 @@ learning to many products at Baidu.
Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Latest PaddlePaddle Release: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
### Latest PaddlePaddle Release: [Fluid 1.3.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.3)
### Install Latest Stable Release:
```
# Linux CPU
@ -26,9 +26,9 @@ pip install paddlepaddle
# Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==1.2.0.post87
pip install paddlepaddle-gpu==1.3.0.post87
# Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==1.2.0.post85
pip install paddlepaddle-gpu==1.3.0.post85
# For installation on other platform, refer to http://paddlepaddle.org/
```
@ -75,26 +75,26 @@ pip install paddlepaddle-gpu==1.2.0.post85
## Installation
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website.
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html) on our website.
## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) documentation.
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/index.html) documentation.
- [Deep Learning 101](https://github.com/PaddlePaddle/book)
You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html)
- [Distributed Training](http://paddlepaddle.org/documentation/docs/en/1.3/user_guides/howto/training/multi_node_en.html)
You can run distributed training jobs on MPI clusters.
- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html)
- [Python API](http://paddlepaddle.org/documentation/docs/en/1.3/api/index_en.html)
Our new API enables much shorter programs.
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html)
- [How to Contribute](http://paddlepaddle.org/documentation/docs/en/1.3/advanced_usage/development/contribute_to_paddle/index_en.html)
We appreciate your contributions!

@ -3,8 +3,8 @@
[English](./README.md) | 简体中文
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/index.html)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
@ -16,7 +16,7 @@ PaddlePaddle (PArallel Distributed Deep LEarning) 是一个简单易用、高效
跟进PaddlePaddle最新特性请参考我们的[版本说明](https://github.com/PaddlePaddle/Paddle/releases)
### PaddlePaddle最新版本: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
### PaddlePaddle最新版本: [Fluid 1.3.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.3)
### 安装最新稳定版本:
```
# Linux CPU
@ -24,9 +24,9 @@ pip install paddlepaddle
# Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==1.2.0.post87
pip install paddlepaddle-gpu==1.3.0.post87
# Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==1.2.0.post85
pip install paddlepaddle-gpu==1.3.0.post85
# 其他平台上的安装指引请参考 http://paddlepaddle.org/
```
@ -57,26 +57,26 @@ pip install paddlepaddle-gpu==1.2.0.post85
## 安装
推荐阅读官网上的[安装说明](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html)
推荐阅读官网上的[安装说明](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/install/index_cn.html)
## 文档
我们提供[英文](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html)和
[中文](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) 文档
我们提供[英文](http://paddlepaddle.org/documentation/docs/en/1.3/beginners_guide/index_en.html)和
[中文](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/index.html) 文档
- [深度学习101](https://github.com/PaddlePaddle/book)
或许您想从这个在线交互式书籍开始可以在Jupyter Notebook中运行
- [分布式训练](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html)
- [分布式训练](http://paddlepaddle.org/documentation/docs/zh/1.3/user_guides/howto/training/multi_node.html)
可以在MPI集群上运行分布式训练任务
- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html)
- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.3/api_cn/index_cn.html)
新的API支持代码更少更简洁的程序
- [贡献方式](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html)
- [贡献方式](http://paddlepaddle.org/documentation/docs/zh/1.3/advanced_usage/development/contribute_to_paddle/index_cn.html)
欢迎您的贡献!

@ -39,10 +39,8 @@ IF(WIN32)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib)
SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll)
SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll)
ELSE()
#TODO(intel-huying):
# Enable the Erf function in the mklml library temporarily; it will be updated to the official version later.
SET(MKLML_VER "VsErf_mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
ELSE()
SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)

@ -71,7 +71,7 @@ paddle.fluid.initializer.NumpyArrayInitializer.__init__ ArgSpec(args=['self', 'v
paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None))
paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32'))
paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name', 'h_0', 'c_0', 'cell_clip', 'proj_clip'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None, None, None, None, None))
paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False))
paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False))
paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
@ -121,6 +121,7 @@ paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
paddle.fluid.layers.sampled_softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'num_samples', 'num_true', 'remove_accidental_hits', 'use_customized_samples', 'customized_samples', 'customized_probabilities', 'seed'], varargs=None, keywords=None, defaults=(1, True, False, None, None, 0))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
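Per the updated ArgSpec above, `paddle.fluid.layers.dynamic_lstmp` now also accepts `h_0`, `c_0`, `cell_clip`, and `proj_clip`. A minimal usage sketch of the two clip attributes, assuming the fluid 1.3 Python API (names and sizes here are illustrative, not from the diff):
```
import paddle.fluid as fluid

dict_size, emb_dim, hidden_dim, proj_dim = 10000, 128, 512, 256

words = fluid.layers.data(name='words', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.embedding(input=words, size=[dict_size, emb_dim])
fc = fluid.layers.fc(input=emb, size=hidden_dim * 4)  # LSTMP gates need 4x hidden
proj, cell = fluid.layers.dynamic_lstmp(
    input=fc, size=hidden_dim * 4, proj_size=proj_dim,
    cell_clip=3.0,   # new in 1.3: clamp the cell state to [-3, 3]
    proj_clip=3.0)   # new in 1.3: clamp the projection output to [-3, 3]
```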

@ -135,12 +135,15 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass;
if (strategy_.is_distribution_) {
VLOG(3) << "multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
VLOG(3) << "multi devices collective mode with allreduce";
multi_devices_pass =
AppendPass("allreduce_mode_multi_devices_pass").get();
} else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
VLOG(3) << "multi deivces collective mode with reduce";
multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get();
} else {
PADDLE_THROW("Unknown reduce strategy.");

@ -937,9 +937,21 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
}
void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
if (need_broadcast_var_ ||
(UseGPU() &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce)) {
// broadcast received parameters when training in parameter server mode.
if (need_broadcast_var_) {
// There are 4 conditions:
// 1. GPU && Reduce: Reduce gradients, then broadcast gradients to other GPUs.
// Need to broadcast received parameters to the other GPUs.
// 2. GPU && AllReduce: AllReduce all gradients to each GPU. Need to
// broadcast received parameters to the other GPUs.
// 3. CPU && AllReduce: AllReduce all gradients to each thread. Need to
// broadcast received parameters to the other scopes.
// 4. CPU && Reduce: all parameters share the same memory, so there is no
// need to broadcast received parameters.
if (!UseGPU() &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
return;
}
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else {

@ -66,7 +66,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)

@ -11,7 +11,6 @@ limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
@ -25,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_MKLDNN
@ -303,28 +301,8 @@ template <typename T>
struct GeluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
// Because the executor or device context cannot be delivered here, keep the
// macro for NVCC.
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
auto x_data = x.data();
auto out_data = out.data();
int n = std::min(x.size(), out.size());
std::memset(out_data, 0, n * sizeof(T));
math::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
for (int i = 0; i < n; i++) {
out_data[i] += static_cast<T>(1);
}
math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
for (int i = 0; i < n; i++) {
out_data[i] *= static_cast<T>(0.5);
}
#else
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
#endif
}
};
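The Eigen branch kept above computes the exact GELU, `0.5 * x * (1 + erf(x / sqrt(2)))`. A NumPy sketch of the same formula, for reference only (not part of the diff):
```
import numpy as np
from scipy.special import erf

def gelu(x):
    # exact GELU, matching the retained Eigen path
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

print(gelu(np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)))
```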

@ -122,7 +122,7 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
auto cpu_place = std::unique_ptr<paddle::platform::CPUPlace>(
new paddle::platform::CPUPlace());
paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place.get());
paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place);
framework::LoD lod;
lod.push_back(source_level_lod);

@ -144,34 +144,40 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
"The ignore threshold to ignore confidence loss.")
.SetDefault(0.7);
AddComment(R"DOC(
This operator generate yolov3 loss by given predict result and ground
This operator generates yolov3 loss based on the given prediction results and ground
truth boxes.
The output of previous network is in shape [N, C, H, W], while H and W
should be the same, specify the grid size, each grid point predict given
number boxes, this given number is specified by anchors, it should be
half anchors length, which following will be represented as S. In the
second dimention(the channel dimention), C should be S * (class_num + 5),
class_num is the box categoriy number of source dataset(such as coco),
so in the second dimention, stores 4 box location coordinates x, y, w, h
and confidence score of the box and class one-hot key of each anchor box.
should be the same, H and W specify the grid size, and each grid point predicts
a given number of boxes. This given number, hereafter denoted S, is specified
by the number of anchors. In the second dimension (the channel dimension), C
should be equal to S * (class_num + 5), where class_num is the number of object
categories in the source dataset (such as 80 in the coco dataset). So, apart
from the 4 box location coordinates x, y, w, h, the second (channel) dimension
also includes the confidence score of the box and the class one-hot key of each anchor box.
While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions
correspnd to:
Assuming the 4 location coordinates are :math:`t_x, t_y, t_w, t_h`, the box predictions
are as follows:
$$
b_x = \sigma(t_x) + c_x
b_y = \sigma(t_y) + c_y
b_x = \\sigma(t_x) + c_x
$$
$$
b_y = \\sigma(t_y) + c_y
$$
$$
b_w = p_w e^{t_w}
$$
$$
b_h = p_h e^{t_h}
$$
While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$
is specified by anchors.
In the equations above, :math:`c_x, c_y` is the top-left corner of the current grid
cell and :math:`p_w, p_h` is specified by the anchors.
As for confidence score, it is the logistic regression value of IoU between
anchor boxes and ground truth boxes, the score of the anchor box which has
the max IoU should be 1, and if the anchor box has IoU bigger then ignore
the max IoU should be 1, and if the anchor box has IoU bigger than ignore
thresh, the confidence score loss of this anchor box will be ignored.
Therefore, the yolov3 loss consists of three major parts: box location loss,
@ -186,13 +192,13 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
In order to trade off box coordinate losses between big boxes and small
boxes, box coordinate losses will be multiplied by a scale weight, which is
calculated as follow.
calculated as follows.
$$
weight_{box} = 2.0 - t_w * t_h
$$
Final loss will be represented as follow.
Final loss will be represented as follows.
$$
loss = (loss_{xy} + loss_{wh}) * weight_{box}
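To make the box decoding above concrete, here is a NumPy sketch of the four prediction equations (values are made up for illustration):
```
import numpy as np

def decode_box(t, grid_xy, anchor_wh):
    # t = (tx, ty, tw, th) raw outputs; grid_xy = (cx, cy) grid cell corner;
    # anchor_wh = (pw, ph) anchor size -- see the equations above
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    tx, ty, tw, th = t
    cx, cy = grid_xy
    pw, ph = anchor_wh
    return (sigmoid(tx) + cx, sigmoid(ty) + cy,
            pw * np.exp(tw), ph * np.exp(th))

print(decode_box((0.2, -0.1, 0.3, 0.0), (3, 4), (10.0, 13.0)))
```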

@ -151,9 +151,10 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value.output_value = out_t.data<T>();
lstm_value.state_value = cell_t.data<T>();
lstm_value.state_active_value = cell_pre_act_t.data<T>();
T cell_clip = 0.0;
math::LstmUnitFunctor<DeviceContext, T>::compute(
device_ctx, lstm_value, frame_size, cur_batch_size, gate_act,
cell_act, cand_act);
device_ctx, lstm_value, frame_size, cur_batch_size, cell_clip,
gate_act, cell_act, cand_act);
lstm_value.prev_state_value = lstm_value.state_value;
}
@ -316,9 +317,10 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_value.output_value = nullptr;
lstm_grad.state_active_grad = nullptr;
int cur_batch_size = bend - bstart;
T cell_clip = 0.0;
math::LstmUnitGradFunctor<DeviceContext, T>::compute(
device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
cell_clip, gate_act, cell_act, cand_act);
if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);

@ -73,12 +73,6 @@ class LSTMPOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of LSTMP operator should not be null after "
"Input(H0) provided.");
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
ctx->SetOutputDim("OrderedP0", {h_dims[0], proj_dims[1]});
}
auto b_dims = ctx->GetInputDim("Bias");
@ -180,11 +174,6 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"This LoDTensor is obtained in the forward and used in the "
"backward.")
.AsIntermediate();
AddOutput("OrderedP0",
"(Tensor) the projection of the initial hidden state "
"H0. This is a tensor with shape (N x P), where N is the "
"batch size and P is the hidden size.")
.AsIntermediate();
AddAttr<bool>("use_peepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
@ -193,6 +182,16 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: False) "
"whether to compute reversed LSTMP.")
.SetDefault(false);
AddAttr<float>("cell_clip",
"(float, defalut: 0.0) "
"Clip for Tensor for cell state tensor when clip value is "
"greater than 0.0")
.SetDefault(0.0);
AddAttr<float>("proj_clip",
"(float, defalut: 0.0) "
"Clip for Tensor for projection tensor when clip value is "
"greater than 0.0")
.SetDefault(0.0);
AddAttr<std::string>(
"gate_activation",
"(string, default: sigmoid)"

@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
@ -21,17 +22,50 @@ limitations under the License. */
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
using platform::Transform;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
class _ClipFunctor {
public:
explicit _ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x) const {
if (x < min_)
return min_;
else if (x > max_)
return max_;
else
return x;
}
private:
T min_;
T max_;
};
template <typename T>
class _ClipGradFunctor {
public:
explicit _ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x, const T& y) const {
return (y > min_ && y < max_) ? x : 0;
}
private:
T min_;
T max_;
};
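These two functors implement a clamp and its straight-through gradient. A NumPy sketch of their behavior (illustrative only):
```
import numpy as np

def clip(x, c):
    # forward: clamp to [-c, c], as in _ClipFunctor(-c, c)
    return np.minimum(np.maximum(x, -c), c)

def clip_grad(dy, y, c):
    # backward: pass the gradient only where the reference value y lies
    # strictly inside (-c, c), as in _ClipGradFunctor
    return np.where((y > -c) & (y < c), dy, 0.0)

y = np.array([-5.0, -1.0, 0.5, 7.0])
print(clip(y, 3.0))                    # [-3.  -1.   0.5  3. ]
print(clip_grad(np.ones(4), y, 3.0))   # [0. 1. 1. 0.]
```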
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src,
@ -67,9 +101,11 @@ class LSTMPKernel : public framework::OpKernel<T> {
auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_t0 = ctx.Input<Tensor>("H0");
auto* ordered_proj0 = ctx.Output<Tensor>("OrderedP0");
auto* cell_t0 = ctx.Input<Tensor>("C0");
auto proj_clip = static_cast<T>(ctx.Attr<float>("proj_clip"));
auto cell_clip = static_cast<T>(ctx.Attr<float>("cell_clip"));
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace());
auto* proj_out = ctx.Output<LoDTensor>("Projection");
@ -110,6 +146,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
}
lstmp_value.prev_state_value = nullptr;
Tensor ordered_c0;
Tensor ordered_h0;
framework::Vector<size_t> order(batch_gate->lod()[2]);
@ -169,18 +206,9 @@ class LSTMPKernel : public framework::OpKernel<T> {
// Since the batch computing for LSTMP reorders the input sequence
// according to their length. The initialized hidden state also needs
// to reorder.
Tensor ordered_h0;
ordered_proj0->mutable_data<T>(ctx.GetPlace());
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
&ordered_h0, true);
blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast<T>(1.0),
ordered_proj0, static_cast<T>(0.0));
if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
ActCompute(cell_act, place, proj0_dev, proj0_dev);
}
blas.MatMul(*ordered_proj0, false, *weight, false, static_cast<T>(1.0),
blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
}
@ -189,8 +217,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
lstmp_value.state_value = cell_t.data<T>();
lstmp_value.state_active_value = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<DeviceContext, T>::compute(
device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act,
cell_act, cand_act);
device_ctx, lstmp_value, frame_size, cur_batch_size, cell_clip,
gate_act, cell_act, cand_act);
lstmp_value.prev_state_value = lstmp_value.state_value;
blas.MatMul(hidden_t, false, *proj_weight, false, static_cast<T>(1.0),
&proj_t, static_cast<T>(0.0));
@ -198,6 +226,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
auto proj_t_dev = EigenMatrix<T>::From(proj_t);
ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
}
if (proj_clip && proj_clip > 0.0) {
T* x_data = proj_t.data<T>();
int64_t numel = proj_t.numel();
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), x_data,
x_data + numel, x_data,
_ClipFunctor<T>(-1.0 * proj_clip, proj_clip));
}
}
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
@ -239,6 +275,9 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto* proj_out = ctx.Input<LoDTensor>("Projection");
auto* cell_out = ctx.Input<LoDTensor>("Cell");
auto proj_clip = static_cast<T>(ctx.Attr<float>("proj_clip"));
auto cell_clip = static_cast<T>(ctx.Attr<float>("cell_clip"));
auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
auto* batch_hidden = ctx.Input<LoDTensor>("BatchHidden");
@ -253,7 +292,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto* h0 = ctx.Input<Tensor>("H0");
auto* ordered_proj0 = ctx.Input<Tensor>("OrderedP0");
auto* c0 = ctx.Input<Tensor>("C0");
auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
@ -363,6 +401,17 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
Tensor cur_proj = batch_proj.Slice(bstart, bend);
Tensor proj_g = batch_proj_g.Slice(bstart, bend);
if (proj_clip && proj_clip > 0.0) {
T* dx_data = proj_g.data<T>();
T* x_data = cur_proj.data<T>();
int64_t numel = proj_g.numel();
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), dx_data,
dx_data + numel, x_data, dx_data,
_ClipGradFunctor<T>(-1.0 * proj_clip, proj_clip));
}
if (proj_act != math::detail::ActivationType::kIdentity) {
auto cur_proj_dev = EigenMatrix<T>::From(cur_proj);
auto proj_g_dev = EigenMatrix<T>::From(proj_g);
@ -412,7 +461,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
math::LstmUnitGradFunctor<DeviceContext, T>::compute(
device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
cell_clip, gate_act, cell_act, cand_act);
if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
@ -431,31 +480,14 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
&ordered_h0, true);
if (weight_g) {
blas.MatMul(*ordered_proj0, true, gate_g, false,
static_cast<T>(1.0), weight_g, static_cast<T>(1.0));
blas.MatMul(ordered_h0, true, gate_g, false, static_cast<T>(1.0),
weight_g, static_cast<T>(1.0));
}
}
if (h0 && (h0_g || proj_weight_g)) {
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
Tensor proj0_g;
proj0_g.Resize({in_dims[0], proj_weight->dims()[1]});
proj0_g.mutable_data<T>(ctx.GetPlace());
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
&proj0_g, static_cast<T>(0.0));
if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
proj0_g_dev);
}
if (h0_g) {
blas.MatMul(proj0_g, false, *proj_weight, true, static_cast<T>(1.0),
&ordered_h0_g, static_cast<T>(0.0));
}
if (proj_weight_g) {
blas.MatMul(ordered_h0, true, proj0_g, false, static_cast<T>(1.0),
proj_weight_g, static_cast<T>(1.0));
}
&ordered_h0_g, static_cast<T>(0.0));
}
}
}

@ -39,6 +39,7 @@ math_library(cross_entropy)
math_library(cos_sim_functor)
math_library(depthwise_conv DEPS cub)
math_library(im2col)
math_library(sample_prob)
math_library(sampler)
math_library(gru_compute DEPS activation_functions math_function)

@ -184,9 +184,6 @@ class Blas {
template <typename T>
void VINV(int n, const T* a, T* y) const;
template <typename T>
void VMERF(int n, const T* a, T* y, int64_t mode) const;
private:
const DeviceContext& context_;
};
@ -293,11 +290,6 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VINV<T>(args...);
}
template <typename... ARGS>
void VMERF(ARGS... args) const {
Base()->template VMERF<T>(args...);
}
private:
const Blas<DeviceContext>* Base() const {
return static_cast<const Blas<DeviceContext>*>(this);

@ -123,11 +123,6 @@ struct CBlas<float> {
static void VINV(ARGS... args) {
platform::dynload::vsInv(args...);
}
template <typename... ARGS>
static void VMERF(ARGS... args) {
platform::dynload::vmsErf(args...);
}
};
template <>
@ -228,11 +223,6 @@ struct CBlas<double> {
static void VINV(ARGS... args) {
platform::dynload::vdInv(args...);
}
template <typename... ARGS>
static void VMERF(ARGS... args) {
platform::dynload::vmdErf(args...);
}
};
#else
@ -635,19 +625,6 @@ void Blas<DeviceContext>::VINV(int n, const T *a, T *y) const {
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VMERF(int n, const T *a, T *y,
int64_t mode) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VMERF(n, a, y, mode);
#else
for (int i = 0; i < n; ++i) {
y[i] = std::erf(a[i]);
}
#endif
}
} // namespace math
} // namespace operators
} // namespace paddle

@ -32,7 +32,8 @@ namespace detail {
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size, ActivationType active_node,
int frame_size, T cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
T r_value_in;
@ -67,7 +68,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
active_node, active_gate, active_state);
&cell_clip, active_node, active_gate, active_state);
value_in[i] = r_value_in;
value_ig[i] = r_value_ig;
@ -82,7 +83,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size,
ActivationType active_node,
T cell_clip, ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
T r_value_in;
@ -135,7 +136,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
&r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
&r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
&r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
active_node, active_gate, active_state);
&cell_clip, active_node, active_gate, active_state);
grad_in[i] = r_grad_in;
grad_ig[i] = r_grad_ig;
@ -154,7 +155,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size, ActivationType active_node,
int frame_size, T cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
#ifdef __AVX__
@ -194,7 +196,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
active_node, active_gate, active_state);
&cell_clip, active_node, active_gate, active_state);
value_in[i] = r_value_in;
value_ig[i] = r_value_ig;
@ -210,7 +212,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size,
ActivationType active_node,
T cell_clip, ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
#ifdef __AVX__
@ -268,7 +270,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
&r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
&r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
&r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
active_node, active_gate, active_state);
&cell_clip, active_node, active_gate, active_state);
grad_in[i] = r_grad_in;
grad_ig[i] = r_grad_ig;
@ -292,27 +294,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
ActivationType active_node, ActivationType active_gate,
ActivationType active_state) {
T cell_clip, ActivationType active_node,
ActivationType active_gate, ActivationType active_state) {
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
active_gate, active_state);
avx_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
active_node, active_gate, active_state);
} else {
naive_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
active_gate, active_state);
naive_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
active_node, active_gate, active_state);
}
}
template <class T, class Op>
void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, ActivationType active_node,
int frame_size, T cell_clip, ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node,
active_gate, active_state);
avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip,
active_node, active_gate, active_state);
} else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size,
naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip,
active_node, active_gate, active_state);
}
}

@ -31,7 +31,8 @@ namespace detail {
*/
template <class T, class Op, bool is_batch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node,
int batch_size, T cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
@ -72,7 +73,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
active_node, active_gate, active_state);
&cell_clip, active_node, active_gate, active_state);
value.gate_value[frame_idx] = r_value_in;
value.gate_value[frame_idx + frame_size] = r_value_ig;
@ -91,7 +92,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
template <class T, class Op, bool is_batch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size,
int batch_size, ActivationType active_node,
int batch_size, T cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
@ -148,8 +150,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig,
&r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state,
&r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF,
&r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node,
active_gate, active_state);
&r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, &cell_clip,
active_node, active_gate, active_state);
grad.gate_grad[frame_idx] = r_grad_in;
grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
@ -185,8 +187,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, int frame_size, int batch_size,
ActivationType active_node, ActivationType active_gate,
ActivationType active_state) {
T cell_clip, ActivationType active_node,
ActivationType active_gate, ActivationType active_state) {
dim3 threads;
dim3 grid;
if (batch_size == 1) {
@ -205,12 +207,12 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if (batch_size == 1) {
KeLstmForward<T, Op,
/* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, frame_size, batch_size, active_node, active_gate,
op, value, frame_size, batch_size, cell_clip, active_node, active_gate,
active_state);
} else {
KeLstmForward<T, Op,
/* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, frame_size, batch_size, active_node, active_gate,
op, value, frame_size, batch_size, cell_clip, active_node, active_gate,
active_state);
}
}
@ -218,7 +220,7 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
int frame_size, int batch_size, T cell_clip,
ActivationType active_node, ActivationType active_gate,
ActivationType active_state) {
dim3 threads;
@ -239,13 +241,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
if (batch_size == 1) {
KeLstmBackward<T, Op,
/* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frame_size, batch_size, active_node, active_gate,
active_state);
op, value, grad, frame_size, batch_size, cell_clip, active_node,
active_gate, active_state);
} else {
KeLstmBackward<T, Op,
/* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frame_size, batch_size, active_node, active_gate,
active_state);
op, value, grad, frame_size, batch_size, cell_clip, active_node,
active_gate, active_state);
}
}

@ -29,7 +29,7 @@ class lstm {
public:
HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
T *prev_state, T *state, T *state_atv, T *output,
T *checkI, T *checkF, T *checkO,
T *checkI, T *checkF, T *checkO, T *cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
@ -37,6 +37,15 @@ class lstm {
*value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate);
*value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate);
*state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
if (*cell_clip > 0.0) {
if (*state < -1.0 * (*cell_clip)) {
*state = -1.0 * (*cell_clip);
}
if (*state > *cell_clip) {
*state = *cell_clip;
}
}
*value_og = activation(*value_og + (*state) * (*checkO), active_gate);
*state_atv = activation(*state, active_state);
*output = (*value_og) * (*state_atv);
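The scalar operator above computes the cell state and then applies the optional clamp. The same step in NumPy (a sketch, not from the diff):
```
import numpy as np

def lstm_cell_state(value_in, value_ig, value_fg, prev_state, cell_clip):
    # candidate * input_gate + prev_state * forget_gate, then optional clamp
    state = value_in * value_ig + prev_state * value_fg
    if cell_clip > 0.0:
        state = np.clip(state, -cell_clip, cell_clip)
    return state

print(lstm_cell_state(0.9, 0.8, 0.7, 5.0, cell_clip=3.0))  # 4.22 clamped to 3.0
```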
@ -52,7 +61,7 @@ class lstm {
__m256 *value_fg, __m256 *value_og,
__m256 *prev_state, __m256 *state,
__m256 *state_atv, __m256 *output, __m256 *checkI,
__m256 *checkF, __m256 *checkO,
__m256 *checkF, __m256 *checkO, T *cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
@ -65,6 +74,13 @@ class lstm {
active_gate);
*state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig),
_mm256_mul_ps(*prev_state, *value_fg));
if (*cell_clip > 0.0f) {
__m256 min = _mm256_set1_ps(0.0f - *cell_clip);
__m256 max = _mm256_set1_ps(*cell_clip);
*state = _mm256_min_ps(max, *state);
*state = _mm256_max_ps(min, *state);
}
*value_og = activation(
_mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate);
*state_atv = activation(*state, active_state);
@ -86,15 +102,26 @@ class lstm {
T *prev_state, T *prev_state_grad, T *state,
T *state_grad, T *state_atv, T *output_grad,
T *checkI, T *checkF, T *checkO, T *checkIGrad,
T *checkFGrad, T *checkOGrad,
T *checkFGrad, T *checkOGrad, T *cell_clip,
ActivationType active_node,
ActivationType active_gate,
ActivationType active_state) {
*grad_og =
activation((*output_grad) * (*state_atv), *value_og, active_gate);
*state_grad +=
activation((*output_grad) * (*value_og), *state_atv, active_state) +
(*grad_og) * (*checkO);
if (*cell_clip > 0.0f) {
if (*state >= (*cell_clip) || *state <= (0.0f - (*cell_clip))) {
*state_grad = 0.0f;
} else {
*state_grad +=
activation((*output_grad) * (*value_og), *state_atv, active_state) +
(*grad_og) * (*checkO);
}
} else {
*state_grad +=
activation((*output_grad) * (*value_og), *state_atv, active_state) +
(*grad_og) * (*checkO);
}
*grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node);
*grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate);
*grad_fg =
@ -117,15 +144,24 @@ class lstm {
__m256 *prev_state, __m256 *prev_state_grad, __m256 *state,
__m256 *state_grad, __m256 *state_atv, __m256 *output_grad,
__m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad,
__m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node,
ActivationType active_gate, ActivationType active_state) {
__m256 *checkFGrad, __m256 *checkOGrad, T *cell_clip,
ActivationType active_node, ActivationType active_gate,
ActivationType active_state) {
*grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og,
active_gate);
*state_grad =
_mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
*state_atv, active_state),
*state_grad);
*state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
if (*cell_clip > 0.0f) {
T *state_ = reinterpret_cast<T *>(state);
if (*state_ >= (*cell_clip) || *state_ <= (0.0f - (*cell_clip))) {
*state_grad = _mm256_set1_ps(0.0f);
} else {
*state_grad =
_mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
*state_atv, active_state),
*state_grad);
*state_grad =
_mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
}
}
*grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in,
active_node);
*grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig,
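In words, the backward branches added above zero the cell-state gradient once the forward clamp saturated; compactly (notation mine, not from the diff):
$$
\delta_c \leftarrow
\begin{cases}
0, & \text{cell\_clip} > 0 \text{ and } |c| \ge \text{cell\_clip}, \\
\delta_c + \phi'(c)\,\delta_h\,o + \delta_o\,w_{co}, & \text{otherwise,}
\end{cases}
$$
where $\phi$ is the state activation, $o$ the output-gate value, $\delta_h$ the output gradient, $\delta_o$ the output-gate gradient, and $w_{co}$ the output-gate peephole weight (checkO in the code).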

@ -24,12 +24,12 @@ template <class T>
struct LstmUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType& gate_act,
T cell_clip, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
cand_act, gate_act, cell_act);
cell_clip, cand_act, gate_act, cell_act);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
@ -45,13 +45,14 @@ template <class T>
struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
int frame_size, int batch_size, T cell_clip,
const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frame_size, cand_act, gate_act, cell_act);
frame_size, cell_clip, cand_act, gate_act,
cell_act);
value.gate_value += frame_size * 4;
value.state_value += frame_size;

@ -24,12 +24,12 @@ template <class T>
struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType& gate_act,
T cell_clip, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, cand_act, gate_act,
cell_act);
frame_size, batch_size, cell_clip, cand_act,
gate_act, cell_act);
}
};
@ -37,13 +37,13 @@ template <class T>
struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size,
int frame_size, int batch_size, T cell_clip,
const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, cand_act, gate_act,
cell_act);
frame_size, batch_size, cell_clip, cand_act,
gate_act, cell_act);
}
};

@ -50,7 +50,7 @@ template <typename DeviceContext, typename T>
class LstmUnitFunctor {
public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value,
int frame_size, int batch_size,
int frame_size, int batch_size, T cell_clip,
const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act);
@ -61,7 +61,7 @@ class LstmUnitGradFunctor {
public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType &gate_act,
T cell_clip, const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act);
};

@ -0,0 +1,26 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sample_prob.h"
namespace paddle {
namespace operators {
namespace math {
template class SampleWithProb<platform::CPUDeviceContext, float>;
template class SampleWithProb<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle

@ -0,0 +1,161 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thrust/random.h>
#include <thrust/sort.h>
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sample_prob.h"
#include "paddle/fluid/operators/math/sampler.h"
namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
template <typename T>
__device__ T gpu_adjust_prob(const T prob, const int num_samples,
const int num_tries) {
if (num_samples == num_tries) {
return prob * num_samples;
} else {
return -expm1(num_tries * log1p(-prob));
}
}
class GPULogUniformSampler {
public:
__device__ int64_t Sample(float random, const int range,
const float log_range) const;
__device__ float Probability(int64_t value, const float log_range) const;
};
__device__ int64_t GPULogUniformSampler::Sample(float random, const int range,
const float log_range) const {
// Get the log-uniform distribution from the uniform distribution via the
// inverse transform sampling method
const int64_t value = static_cast<int64_t>(exp(random * log_range)) - 1;
// Mathematically, value should be <= range_, but might not be due to some
// floating point roundoff, so we mod by range_.
return value % range;
}
__device__ float GPULogUniformSampler::Probability(
int64_t value, const float log_range) const {
// Given f(x) = 1/[(x+1) * log_range_]
// The value's probability is integral of f(x) from value to (value + 1)
return (log((value + 2.0) / (value + 1.0))) / log_range;
}
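The comment's integral checks out (a quick derivation, not from the diff): with density $f(x) = 1/[(x+1)\log(\text{range}+1)]$,
$$
P(v) = \int_{v}^{v+1} \frac{dx}{(x+1)\log(\text{range}+1)}
     = \frac{\log\frac{v+2}{v+1}}{\log(\text{range}+1)},
$$
which is exactly the value returned, with log_range = log(range + 1) as set in the caller below.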
template <typename T>
__global__ void SamplingCondidate(
const size_t n, const int num_tries, const int range, const float log_range,
const int num_true, const std::size_t num_samples,
const int64_t* label_data, int64_t* samples_data, T* probabilities_data) {
const int num_sampled_classes = num_true + num_samples;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
GPULogUniformSampler sampler;
for (; idx < n; idx += blockDim.x * gridDim.x) {
int col_idx = idx % num_sampled_classes;
int row_idx = idx / num_sampled_classes;
if (col_idx < num_true) {
samples_data[idx] = label_data[row_idx * num_true + col_idx];
} else {
samples_data[idx] = samples_data[col_idx];
}
probabilities_data[idx] = sampler.Probability(samples_data[idx], log_range);
probabilities_data[idx] =
gpu_adjust_prob(probabilities_data[idx], num_samples, num_tries);
}
}
template <typename T>
int UniqSampler(const Sampler& sampler, const std::size_t num_samples,
int64_t* samples_data) {
// sample num_samples unique samples for an example; note that they are not
// all negative samples
std::unordered_set<int64_t> tmp_samples;
tmp_samples.clear();
int num_tries = 0;
int j = 0;
while (j < num_samples) {
++num_tries;
auto v = sampler.Sample();
auto insert_ok = tmp_samples.insert(v).second;
if (!insert_ok) {
continue;
}
samples_data[j] = v;
++j;
}
return num_tries;
}
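UniqSampler is plain rejection sampling that also counts every draw, since num_tries feeds the probability adjustment. A Python mirror of the loop (illustrative):
```
import random

def uniq_sample(sample_fn, num_samples):
    # draw until num_samples distinct values are collected,
    # counting every attempt in num_tries
    seen, out, num_tries = set(), [], 0
    while len(out) < num_samples:
        num_tries += 1
        v = sample_fn()
        if v in seen:
            continue
        seen.add(v)
        out.append(v)
    return out, num_tries

print(uniq_sample(lambda: random.randint(0, 9), 5))
```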
template <typename T>
void GPUSampleWithProb<T>::operator()(
const platform::CUDADeviceContext& context, const int seed,
const int dict_size, const bool uniq, const std::size_t num_samples,
const Tensor* L, Tensor* S, Tensor* P) {
// UNDERSTAND: dimension issues
const auto lbl_dim = L->dims();
const int batch_size = lbl_dim[0];
const int num_true = lbl_dim[1];
const int num_sampled_classes = num_true + num_samples;
framework::DDim ret_dim{batch_size, num_sampled_classes};
// UNDERSTAND: raw data view
const int64_t* label_data = L->data<int64_t>();
int64_t* samples_data = S->data<int64_t>();
T* probabilities_data = P->data<T>();
int s_size = num_samples;
framework::DDim s_dim{s_size};
Tensor s;
int64_t* s_data = s.mutable_data<int64_t>(s_dim, platform::CPUPlace());
math::LogUniformSampler sampler(dict_size, seed);
int range = dict_size;
float log_range = log(range + 1);
int num_tries = UniqSampler<T>(sampler, num_samples, s_data);
VLOG(1) << "num_tries: " << num_tries;
PADDLE_ENFORCE(cudaMemcpy(samples_data + num_true, s_data,
sizeof(int64_t) * num_samples,
cudaMemcpyHostToDevice));
int threads = 512;
const size_t size = batch_size * num_sampled_classes;
int grid = (batch_size * num_sampled_classes + threads - 1) / threads;
SamplingCondidate<T><<<grid, threads, 0, context.stream()>>>(
size, num_tries, range, log_range, num_true, num_samples, label_data,
samples_data, probabilities_data);
}
template class GPUSampleWithProb<float>;
template class GPUSampleWithProb<double>;
} // namespace math
} // namespace operators
} // namespace paddle

@ -0,0 +1,118 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/sampler.h"
namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
/* UNDERSTAND: utility function to adjust probability for unique sampling;
returns the probability unchanged if not using unique sampling */
template <typename T>
static T adjust_prob(const T prob, const int num_samples, const int num_tries) {
if (num_samples == num_tries) {
return prob * num_samples;
} else {
return -expm1(num_tries * log1p(-prob));
}
}
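The non-trivial branch relies on a numerically stable identity: `-expm1(n * log1p(-p))` equals `1 - (1 - p)**n`, the probability that a class with single-draw probability p is picked at least once in n tries. A quick NumPy check (not from the diff):
```
import numpy as np

p, n = 1e-6, 1000
stable = -np.expm1(n * np.log1p(-p))   # as in adjust_prob
naive = 1.0 - (1.0 - p) ** n           # same value, less stable for tiny p
print(stable, naive)                   # both ~9.995e-04
```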
template <typename DeviceContext, typename T>
class SampleWithProb {
public:
void operator()(const DeviceContext& context, const Sampler& sampler,
const std::size_t num_samples, const Tensor* L, Tensor* S,
Tensor* P) {
// UNDERSTAND: dimension issues
const auto lbl_dim = L->dims();
const int batch_size = lbl_dim[0];
const int num_true = lbl_dim[1];
const int num_sampled_classes = num_true + num_samples;
framework::DDim ret_dim{batch_size, num_sampled_classes};
// UNDERSTAND: raw data view
const int64_t* label_data = L->data<int64_t>();
int64_t* samples_data =
S->mutable_data<int64_t>(ret_dim, context.GetPlace());
T* probabilities_data = P->mutable_data<T>(ret_dim, context.GetPlace());
// temp sets for unique sampling
std::unordered_set<int64_t> tmp_samples;
int j = 0; // column index
// add true labels, not that efficient
while (j < num_true) {
for (int i = 0; i < batch_size; ++i) {
auto samples_index = i * num_sampled_classes + j;
auto v = label_data[i * num_true + j];
samples_data[samples_index] = v;
probabilities_data[samples_index] = sampler.Probability(v);
}
++j;
}
// sample num_samples unique samples for an example; note that they are not
// all negative samples
tmp_samples.clear();
int num_tries = 0;
while (j < num_sampled_classes) {
++num_tries;
auto v = sampler.Sample();
auto insert_ok = tmp_samples.insert(v).second;
if (!insert_ok) {
continue;
}
auto p = sampler.Probability(v);
for (int i = 0; i < batch_size; ++i) {
auto samples_index = i * num_sampled_classes + j;
samples_data[samples_index] = v;
probabilities_data[samples_index] = p;
}
++j;
}
// compute Q(y|x), because of unique sampling, probabilities need to be
// adjusted
for (int k = 0; k < num_sampled_classes; ++k) {
for (int i = 0; i < batch_size; ++i) {
auto samples_index = i * num_sampled_classes + k;
probabilities_data[samples_index] = adjust_prob(
probabilities_data[samples_index], num_samples, num_tries);
}
}
}
};
#ifdef PADDLE_WITH_CUDA
template <typename T>
class GPUSampleWithProb {
public:
void operator()(const platform::CUDADeviceContext& context, const int seed,
const int dict_size, const bool uniq,
const std::size_t num_samples, const Tensor* L, Tensor* S,
Tensor* P);
};
#endif
} // namespace math
} // namespace operators
} // namespace paddle
