Merge branch 'develop' of upstream into fix_docs

wangkuiyi-patch-1
Yibing Liu 7 years ago
commit fbddb8e4bf

@ -22,6 +22,7 @@
| jczaja | Jacek Czaja | | jczaja | Jacek Czaja |
| JiayiFeng | Jia-Yi Feng | | JiayiFeng | Jia-Yi Feng |
| kbinias | Krzysztof Binias | | kbinias | Krzysztof Binias |
| kexinzhao | Ke-Xin Zhao |
| kuke | Yi-Bing Liu | | kuke | Yi-Bing Liu |
| lcy-seso | Ying Cao | | lcy-seso | Ying Cao |
| lipeng-unisound | Peng Li | | lipeng-unisound | Peng Li |

@ -155,6 +155,15 @@ copy(inference_lib DEPS paddle_fluid_shared paddle_fluid
DSTS ${dst_dir}/${module} ${dst_dir}/${module} DSTS ${dst_dir}/${module} ${dst_dir}/${module}
) )
if(WITH_CONTRIB)
set(contrib_dst_dir "${FLUID_INSTALL_DIR}/contrib/inference")
copy(contrib_inference_lib DEPS paddle_inference_api
SRCS ${PADDLE_SOURCE_DIR}/paddle/contrib/inference/paddle_inference_api.h
${PADDLE_BINARY_DIR}/paddle/contrib/inference/libpaddle_inference_api.*
DSTS ${contrib_dst_dir} ${contrib_dst_dir}
)
endif()
set(module "platform") set(module "platform")
copy(platform_lib DEPS profiler_py_proto copy(platform_lib DEPS profiler_py_proto
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/dynload/*.h ${src_dir}/${module}/details/*.h SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/dynload/*.h ${src_dir}/${module}/details/*.h

@ -342,6 +342,12 @@ conv2d
.. autofunction:: paddle.fluid.layers.conv2d .. autofunction:: paddle.fluid.layers.conv2d
:noindex: :noindex:
conv3d
------
.. autofunction:: paddle.fluid.layers.conv3d
:noindex:
sequence_pool sequence_pool
------------- -------------
@ -366,6 +372,12 @@ pool2d
.. autofunction:: paddle.fluid.layers.pool2d .. autofunction:: paddle.fluid.layers.pool2d
:noindex: :noindex:
pool3d
------
.. autofunction:: paddle.fluid.layers.pool3d
:noindex:
batch_norm batch_norm
---------- ----------
@ -384,6 +396,13 @@ conv2d_transpose
.. autofunction:: paddle.fluid.layers.conv2d_transpose .. autofunction:: paddle.fluid.layers.conv2d_transpose
:noindex: :noindex:
conv3d_transpose
----------------
.. autofunction:: paddle.fluid.layers.conv2d_transpose
:noindex:
sequence_expand sequence_expand
--------------- ---------------

@ -104,7 +104,7 @@ no changes added to commit (use "git add" and/or "git commit -a")
➜ docker run -it -v $(pwd):/paddle paddle:latest-dev bash -c "cd /paddle/build && ctest" ➜ docker run -it -v $(pwd):/paddle paddle:latest-dev bash -c "cd /paddle/build && ctest"
``` ```
关于构建和测试的更多信息,请参见[这篇文档](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/getstarted/build_and_install/docker_install_cn.rst)。 关于构建和测试的更多信息,请参见[使用Docker安装运行](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/v2/build_and_install/docker_install_cn.rst)。
## 提交commit ## 提交commit

@ -50,7 +50,7 @@ cc_test(test_paddle_inference_api
inference_api_test(test_paddle_inference_api_impl inference_api_test(test_paddle_inference_api_impl
ARGS test_word2vec test_image_classification) ARGS test_word2vec test_image_classification)
if (WITH_ANAKIN) if (WITH_ANAKIN AND WITH_TESTING) # only needed in CI
# Due to Anakin do not have official library releases and the versions of protobuf and cuda do not match Paddle's, # Due to Anakin do not have official library releases and the versions of protobuf and cuda do not match Paddle's,
# so anakin library will not be merged to our official inference library. To use anakin prediction API, one need to # so anakin library will not be merged to our official inference library. To use anakin prediction API, one need to
# compile the libinference_anakin_api.a and compile with anakin.so. # compile the libinference_anakin_api.a and compile with anakin.so.

@ -17,7 +17,7 @@ if(APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pessimizing-move")
endif(APPLE) endif(APPLE)
cc_library(tape_variable SRCS variable.cc DEPS ${FLUID_CORE_MODULES} device_context) cc_library(tape_variable SRCS variable.cc DEPS ${FLUID_CORE_MODULES} device_context framework_proto proto_desc operator)
cc_library(tape SRCS tape.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} tape_variable) cc_library(tape SRCS tape.cc DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} tape_variable)
cc_test(test_tape cc_test(test_tape

@ -271,18 +271,18 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", "Input of HardShrink operator"); AddInput("X", "Input of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator"); AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink") AddAttr<float>("threshold",
"The value of threshold for HardShrink. [default: 0.5]")
.SetDefault(0.5f); .SetDefault(0.5f);
AddComment(R"DOC( AddComment(R"DOC(
HardShrink Activation Operator. :strong:`HardShrink activation operator`
$$ .. math::
out = \begin{cases} out = \begin{cases}
x, \text{if } x > \lambda \\ x, \text{if } x > \lambda \\
x, \text{if } x < -\lambda \\ x, \text{if } x < -\lambda \\
0, \text{otherwise} 0, \text{otherwise}
\end{cases} \end{cases}
$$
)DOC"); )DOC");
} }
@ -394,18 +394,18 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", "Input of ThresholdedRelu operator"); AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator"); AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation") AddAttr<float>("threshold",
"The threshold location of activation. [default 1.0].")
.SetDefault(1.0f); .SetDefault(1.0f);
AddComment(R"DOC( AddComment(R"DOC(
ThresholdedRelu Activation Operator. :strong:`ThresholdedRelu activation operator`
.. math::
$$
out = \begin{cases} out = \begin{cases}
x, \text{if } x > threshold \\ x, \text{if } x > threshold \\
0, \text{otherwise} 0, \text{otherwise}
\end{cases} \end{cases}
$$
)DOC"); )DOC");
} }
}; };

@ -23,30 +23,26 @@ class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
OpComment comment; OpComment comment;
AddInput("X", AddInput("X", string::Sprintf("the left hand operand of %s operator",
string::Sprintf("(LoDTensor) the left hand operand of %s operator",
comment.type)); comment.type));
AddInput("Y", string::Sprintf( AddInput("Y", string::Sprintf("the right hand operand of %s operator",
"(LoDTensor) the right hand operand of %s operator",
comment.type)); comment.type));
AddAttr<bool>("force_cpu", AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu " "Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running " "memory. Otherwise, fill output variable to the running "
"device") "device [default true].")
.SetDefault(false); .SetDefault(true);
AddOutput("Out", string::Sprintf( AddOutput("Out", string::Sprintf("n-dim bool tensor. Each element is %s",
"(LoDTensor) n-dim bool tensor. Each element is %s",
comment.equation)); comment.equation));
AddComment(string::Sprintf(R"DOC(%s Operator AddComment(string::Sprintf(R"DOC(
It operates element-wise on X and Y, and returns the Out. Each of them is a It operates element-wise on X and Y, and returns the Out. Each of them is a
N-dim tensor. X and Y could be any type. The each element of the Out tensor is N-dim tensor. X and Y could be any type. The each element of the Out tensor is
calculated by %s calculated by $%s$
)DOC", )DOC",
comment.type, comment.equation)); comment.equation));
AddAttr<int>("axis", AddAttr<int>(
"(int, default -1). The start dimension index " "axis",
"for broadcasting Y onto X.") "The start dimension index for broadcasting Y onto X. [default -1]")
.SetDefault(-1) .SetDefault(-1)
.EqualGreaterThan(-1); .EqualGreaterThan(-1);
} }

@ -107,7 +107,13 @@ REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
false> /* set false to disable empty grad */); false> /* set false to disable empty grad */);
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad); REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>); concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
concat_grad, concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>); ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>);

@ -15,7 +15,13 @@ limitations under the License. */
#include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/operators/concat_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>); concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, double>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
concat_grad, concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>); ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>);

@ -30,19 +30,19 @@ class CumOp : public framework::OperatorWithKernel {
class CumsumOpMaker : public framework::OpProtoAndCheckerMaker { class CumsumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "Input of Cumsum operator"); AddInput("X", "Input of cumsum operator");
AddOutput("Out", "Output of Cumsum operator"); AddOutput("Out", "Output of cumsum operator");
AddAttr<int>("axis", AddAttr<int>("axis",
"(int, default -1). The dimenstion to accumulate along. " "The dimenstion to accumulate along. -1 means the last "
"-1 means the last dimenstion") "dimenstion [default -1].")
.SetDefault(-1) .SetDefault(-1)
.EqualGreaterThan(-1); .EqualGreaterThan(-1);
AddAttr<bool>("exclusive", AddAttr<bool>("exclusive",
"bool, default false). Whether to perform exclusive cumsum") "Whether to perform exclusive cumsum. [default false].")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("reverse", AddAttr<bool>("reverse",
"bool, default false). If true, the cumsum is performed in " "If true, the cumsum is performed in the reversed direction. "
"the reversed direction") "[default false].")
.SetDefault(false); .SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
The cumulative sum of the elements along a given axis. The cumulative sum of the elements along a given axis.

@ -85,7 +85,7 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
.InEnum({"CUDA", "CPU", "AUTO"}) .InEnum({"CUDA", "CPU", "AUTO"})
.SetDefault("AUTO"); .SetDefault("AUTO");
AddComment(R"DOC( AddComment(R"DOC(
Returns a list of places based on flags. The list will be used for parallel Returns a list of places based on arguments. The list will be used for parallel
execution. execution.
)DOC"); )DOC");
} }

@ -62,36 +62,33 @@ class LayerNormOp : public framework::OperatorWithKernel {
class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(LoDTensor) The input tensor."); AddInput("X", "The input tensor.");
AddInput("Scale", AddInput("Scale",
"(Tensor, optional) Scale is a 1-dimensional tensor of size " "(optional) Scale is a 1-dimensional tensor of size "
"H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
"It is applied to the output.") "It is applied to the output.")
.AsDispensable(); .AsDispensable();
AddInput("Bias", AddInput("Bias",
"(Tensor, optional) Bias is a 1-dimensional tensor of size " "(optional) Bias is a 1-dimensional tensor of size "
"H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])." "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
"It is applied to the output.") "It is applied to the output.")
.AsDispensable(); .AsDispensable();
AddOutput("Y", "(LoDTensor) Result after normalization."); AddOutput("Y", "Result after normalization.");
AddOutput("Mean", "(Tensor) Mean of the current mini batch.") AddOutput("Mean", "Mean of the current mini batch.").AsIntermediate();
.AsIntermediate(); AddOutput("Variance", "Variance of the current mini batch.")
AddOutput("Variance", "(Tensor) Variance of the current mini batch.")
.AsIntermediate(); .AsIntermediate();
AddAttr<float>("epsilon", AddAttr<float>("epsilon",
"(float, default 1e-5) Constant for " "Constant for numerical stability [default 1e-5].")
"numerical stability")
.SetDefault(1e-5) .SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) { .AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
"'epsilon' should be between 0.0 and 0.001."); "'epsilon' should be between 0.0 and 0.001.");
}); });
AddAttr<int>("begin_norm_axis", AddAttr<int>("begin_norm_axis",
"(int default:1), the " "the axis of `begin_norm_axis ... Rank(X) - 1` will be "
"axis of `begin_norm_axis ... Rank(X) - 1` will be "
"normalized. `begin_norm_axis` splits the tensor(`X`) to a " "normalized. `begin_norm_axis` splits the tensor(`X`) to a "
"matrix [N,H].") "matrix [N,H]. [default 1].")
.SetDefault(1) .SetDefault(1)
.AddCustomChecker([](const int &begin_norm_axis) { .AddCustomChecker([](const int &begin_norm_axis) {
PADDLE_ENFORCE_GT(begin_norm_axis, 0, PADDLE_ENFORCE_GT(begin_norm_axis, 0,
@ -99,10 +96,14 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
}); });
AddComment(R"DOC( AddComment(R"DOC(
Layer Normalization. Assume feature vectors exist on dimensions
Layer Norm has been implemented as discussed in the paper: :attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics
https://arxiv.org/abs/1607.06450 along these dimensions for each feature vector :math:`a` with size
... :math:`H`, then normalize each feature vector using the corresponding
statistics. After that, apply learnable gain and bias on the normalized
tensor to scale and shift if :attr:`scale` and :attr:`shift` are set.
Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_
)DOC"); )DOC");
} }
}; };

@ -62,26 +62,46 @@ class MultiplexOp : public framework::OperatorWithKernel {
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Ids", "The index tensor of multiplex operator."); AddInput("Ids",
AddInput("X", "The candidate tensors of multiplex operator.") "Tensor<int32>, index variable which is a 2-D tensor with shape "
"[M, 1] where M is the batch size.");
AddInput("X",
"A list of variables to gather from. All variables have the same "
"shape and the rank is at least 2.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "The output tensor of multiplex operator."); AddOutput("Out", "The output tensor of multiplex operator.");
AddComment(R"DOC( AddComment(R"DOC(
Multiplex Operator. Referring to the given index variable, this layer selects rows from the
input variables to construct a multiplex variable. Assuming that there are
Multiplex multiple tensors according to the index provided by the index tensor. :math:`m` input variables and :math:`I_i` represents the i-th input
variable and :math:`i` is in [0, :math:`m`). All input variables are
Ids: the index tensor. tensors with same shape [:math:`d_0`, :math:`d_1`, ..., :math:`d_R`].
X[0 : N - 1]: the candidate tensors for output (N >= 2). Please note that rank of the input tensor should be at least 2. Each input
For each index i from 0 to batchSize - 1, the output is the i-th row of the variable will be treated as a 2-D matrix with shape [:math:`M`, :math:`N`]
where :math:`M` for :math:`d_0` and :math:`N` for :math:`d_1` * :math:`d_2`
* ... * :math:`d_R`. Let :math:`I_i[j]` be the j-th row of the i-th input
variable. The given index variable should be a 2-D tensor with shape
[:math:`M`, 1]. Let `ID[i]` be the i-th index value of the index variable.
Then the output variable will be a tensor with shape [:math:`d_0`,
:math:`d_1`, ..., :math:`d_R`]. If we treat the output tensor as a 2-D
matrix with shape [:math:`M`, :math:`N`] and let :math:`O[i]` be the i-th
row of the matrix, then `O[i]` is equal to :math:`I_{ID[i]}[i]`.
* Ids: the index tensor.
* X[0 : N - 1]: the candidate tensors for output (N >= 2).
* For each index i from 0 to batchSize - 1, the output is the i-th row of the
the (Ids[i])-th tensor. the (Ids[i])-th tensor.
For i-th row of the output tensor: For i-th row of the output tensor:
$$y[i] = x_{k}[i]$$ $$
y[i] = x_{k}[i]
$$
where `y` is the output tensor, `x_{k}` is the k-th input tensor, where $y$ is the output tensor, $x_{k}$ is the k-th input tensor,
and `k = Ids[i]`. and $k = Ids[i]$.
)DOC"); )DOC");
} }

@ -78,11 +78,15 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase { class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
protected: protected:
void Apply() override { void Apply() override {
AddAttr<std::string>("filename", "The filename of record io reader"); AddAttr<std::string>(
"filename",
"The filename of record file. This file will given to reader.");
AddComment(R"DOC( AddComment(R"DOC(
CreateRecordIOReader Operator Open a recordio file and return the reader object. The returned reader object
is thread-safe.
Create a reader from a record io file NOTE: This is a very low-level API. It is used for debugging data file or
training. Please use `open_files` instead of this API for production usage.
)DOC"); )DOC");
} }
}; };

@ -54,7 +54,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
} }
void FileReaderMakerBase::Make() { void FileReaderMakerBase::Make() {
AddOutput("Out", "(ReaderHolder) The created random reader.").AsDuplicable(); AddOutput("Out", "(ReaderHolder): The created random reader.").AsDuplicable();
AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes."); AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ranks", "ranks",

@ -78,23 +78,23 @@ class RowConvOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", AddInput("X",
"(LoDTensor), the input(X) is a LodTensor, which supports " "the input(X) is a LodTensor, which supports "
"variable time-length input sequences. The underlying tensor " "variable time-length input sequences. The underlying tensor "
"in this LoDTensor is a matrix with shape (T x N), where T " "in this LoDTensor is a matrix with shape (T x N), where T "
"is the total time steps in this mini-batch and N is the input " "is the total time steps in this mini-batch and N is the input "
"data dimension."); "data dimension.");
AddInput("Filter", AddInput("Filter",
"(Tensor), the input(Filter) is a learnable parameter. It " "the input(Filter) is a learnable parameter. It "
"is a 2-D tensor with shape (future_context x N), where, " "is a 2-D tensor with shape (future_context x N), where, "
"future_context is the future context length and N is the data " "future_context is the future context length and N is the data "
"dimension."); "dimension.");
AddOutput("Out", AddOutput("Out",
"(LoDTensor), the output(Out) is a LodTensor, which supports " "the output(Out) is a LodTensor, which supports "
"variable time-length input sequences. The underlying tensor " "variable time-length input sequences. The underlying tensor "
"in this LodTensor is a matrix with shape T x N, i.e., the " "in this LodTensor is a matrix with shape T x N, i.e., the "
"same shape as X."); "same shape as X.");
AddComment(R"DOC( AddComment(R"DOC(
Row-convolution Operator. :strong:`Row-convolution operator`
The row convolution is called lookahead convolution. This operator was The row convolution is called lookahead convolution. This operator was
introduced in the following paper for DeepSpeech2: introduced in the following paper for DeepSpeech2:
@ -114,9 +114,23 @@ and a filter ($W$) of size $context \times d$,
the output sequence is convolved as: the output sequence is convolved as:
$$ $$
out_{i, :} = \sum_{j=i}^{i + context} in_{j,:} \dot W_{i-j, :} out_{i, :} = \\sum_{j=i}^{i + context} in_{j,:} \\cdot W_{i-j, :}
$$ $$
In the above equation:
* $Out_{i}$: The i-th row of output variable with shape [1, D].
* $\\tau$: Future context size.
* $X_{j}$: The j-th row of input variable with shape [1, D].
* $W_{i-j}$: The (i-j)-th row of parameters with shape [1, D].
More details about row_conv please refer to
the design document
https://github.com/PaddlePaddle/Paddle/issues/2228#issuecomment-303903645 .
)DOC"); )DOC");
} }
}; };

@ -115,4 +115,7 @@ USE_CPU_ONLY_OP(concat);
REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker); REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker);
REGISTER_OP_CPU_KERNEL(split, REGISTER_OP_CPU_KERNEL(split,
ops::SplitOpKernel<paddle::platform::CPUPlace, float>); ops::SplitOpKernel<paddle::platform::CPUPlace, double>,
ops::SplitOpKernel<paddle::platform::CPUPlace, float>,
ops::SplitOpKernel<paddle::platform::CPUPlace, int64_t>,
ops::SplitOpKernel<paddle::platform::CPUPlace, int>);

@ -15,4 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/split_op.h" #include "paddle/fluid/operators/split_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
split, ops::SplitOpKernel<paddle::platform::CUDADeviceContext, float>); split, ops::SplitOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SplitOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SplitOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SplitOpKernel<paddle::platform::CUDADeviceContext, int>);

@ -86,32 +86,24 @@ class UniformRandomOp : public framework::OperatorWithKernel {
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddOutput("Out", "(Tensor) The output tensor of uniform random op"); AddOutput("Out", "The output tensor of uniform random op");
AddComment(R"DOC( AddComment(R"DOC(
Uniform random operator.
This operator initializes a tensor with random values sampled from a This operator initializes a tensor with random values sampled from a
uniform distribution. uniform distribution. The random result is in set [min, max].
)DOC"); )DOC");
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>("shape", "The shape of the output tensor");
"(vector<int>) The shape of the output tensor"); AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].")
AddAttr<float>("min",
"(float, default -1.0) "
"Minimum value of uniform random")
.SetDefault(-1.0f); .SetDefault(-1.0f);
AddAttr<float>("max", AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].")
"(float, default 1.0) "
"Maximun value of uniform random")
.SetDefault(1.0f); .SetDefault(1.0f);
AddAttr<int>("seed", AddAttr<int>("seed",
"(int, default 0) "
"Random seed used for generating samples. " "Random seed used for generating samples. "
"0 means use a seed generated by the system." "0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always " "Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.") "generate the same random numbers every time. [default 0].")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type") AddAttr<int>("dtype", "Output tensor data type. [default 5(FP32)].")
.SetDefault(framework::proto::VarType::FP32); .SetDefault(framework::proto::VarType::FP32);
} }
}; };

File diff suppressed because it is too large Load Diff

@ -210,53 +210,68 @@ def bipartite_match(dist_matrix,
dist_threshold=None, dist_threshold=None,
name=None): name=None):
""" """
**Bipartite matchint operator** This operator implements a greedy bipartite matching algorithm, which is
used to obtain the matching with the maximum distance based on the input
This operator is a greedy bipartite matching algorithm, which is used to
obtain the matching with the maximum distance based on the input
distance matrix. For input 2D matrix, the bipartite matching algorithm can distance matrix. For input 2D matrix, the bipartite matching algorithm can
find the matched column for each row, also can find the matched row for find the matched column for each row (matched means the largest distance),
each column. And this operator only calculate matched indices from column also can find the matched row for each column. And this operator only
to row. For each instance, the number of matched indices is the number of calculate matched indices from column to row. For each instance,
of columns of the input ditance matrix. the number of matched indices is the column number of the input distance
matrix.
There are two outputs to save matched indices and distance.
A simple description, this algothrim matched the best (maximum distance) There are two outputs, matched indices and distance.
A simple description, this algorithm matched the best (maximum distance)
row entity to the column entity and the matched indices are not duplicated row entity to the column entity and the matched indices are not duplicated
in each row of ColToRowMatchIndices. If the column entity is not matched in each row of ColToRowMatchIndices. If the column entity is not matched
any row entity, set -1 in ColToRowMatchIndices. any row entity, set -1 in ColToRowMatchIndices.
Please note that the input DistMat can be LoDTensor (with LoD) or Tensor. NOTE: the input DistMat can be LoDTensor (with LoD) or Tensor.
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size. If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
If Tensor, the height of ColToRowMatchIndices is 1. If Tensor, the height of ColToRowMatchIndices is 1.
NOTE: This API is a very low level API. It is used by :code:`ssd_loss`
layer. Please consider to use :code:`ssd_loss` instead.
Args: Args:
dist_matrix(Variable): This input is a 2-D LoDTensor with shape dist_matrix(Variable): This input is a 2-D LoDTensor with shape
[K, M]. It is pair-wise distance matrix between the entities [K, M]. It is pair-wise distance matrix between the entities
represented by each row and each column. For example, assumed one represented by each row and each column. For example, assumed one
entity is A with shape [K], another entity is B with shape [M]. The entity is A with shape [K], another entity is B with shape [M]. The
dist_matirx[i][j] is the distance between A[i] and B[j]. The bigger dist_matrix[i][j] is the distance between A[i] and B[j]. The bigger
the distance is, the better macthing the pairs are. Please note, the distance is, the better matching the pairs are.
This tensor can contain LoD information to represent a batch of
inputs. One instance of this batch can contain different numbers of NOTE: This tensor can contain LoD information to represent a batch
entities. of inputs. One instance of this batch can contain different numbers
of entities.
match_type(string|None): The type of matching method, should be match_type(string|None): The type of matching method, should be
'bipartite' or 'per_prediction', 'bipartite' by defalut. 'bipartite' or 'per_prediction'. [default 'bipartite'].
dist_threshold(float|None): If `match_type` is 'per_prediction', dist_threshold(float|None): If `match_type` is 'per_prediction',
this threshold is to determine the extra matching bboxes based this threshold is to determine the extra matching bboxes based
on the maximum distance, 0.5 by defalut. on the maximum distance, 0.5 by default.
Returns: Returns:
match_indices(Variable): A 2-D Tensor with shape [N, M] in int type. tuple: a tuple with two elements is returned. The first is
matched_indices, the second is matched_distance.
The matched_indices is a 2-D Tensor with shape [N, M] in int type.
N is the batch size. If match_indices[i][j] is -1, it N is the batch size. If match_indices[i][j] is -1, it
means B[j] does not match any entity in i-th instance. means B[j] does not match any entity in i-th instance.
Otherwise, it means B[j] is matched to row Otherwise, it means B[j] is matched to row
match_indices[i][j] in i-th instance. The row number of match_indices[i][j] in i-th instance. The row number of
i-th instance is saved in match_indices[i][j]. i-th instance is saved in match_indices[i][j].
match_distance(Variable): A 2-D Tensor with shape [N, M] in float type.
N is batch size. If match_indices[i][j] is -1, The matched_distance is a 2-D Tensor with shape [N, M] in float type
. N is batch size. If match_indices[i][j] is -1,
match_distance[i][j] is also -1.0. Otherwise, assumed match_distance[i][j] is also -1.0. Otherwise, assumed
match_distance[i][j] = d, and the row offsets of each instance match_distance[i][j] = d, and the row offsets of each instance
are called LoD. Then match_distance[i][j] = dist_matrix[d+LoD[i]][j]. are called LoD. Then match_distance[i][j] =
dist_matrix[d+LoD[i]][j].
Examples:
>>> x = fluid.layers.data(name='x', shape=[4], dtype='float32')
>>> y = fluid.layers.data(name='y', shape=[4], dtype='float32')
>>> iou = fluid.layers.iou_similarity(x=x, y=y)
>>> matched_indices, matched_dist = fluid.layers.bipartite_match(iou)
""" """
helper = LayerHelper('bipartite_match', **locals()) helper = LayerHelper('bipartite_match', **locals())
match_indices = helper.create_tmp_variable(dtype='int32') match_indices = helper.create_tmp_variable(dtype='int32')
@ -364,7 +379,7 @@ def ssd_loss(location,
normalize=True, normalize=True,
sample_size=None): sample_size=None):
""" """
**Multi-box loss layer for object dection algorithm of SSD** **Multi-box loss layer for object detection algorithm of SSD**
This layer is to compute dection loss for SSD given the location offset This layer is to compute dection loss for SSD given the location offset
predictions, confidence predictions, prior boxes and ground-truth boudding predictions, confidence predictions, prior boxes and ground-truth boudding
@ -372,21 +387,35 @@ def ssd_loss(location,
is a weighted sum of the localization loss (or regression loss) and is a weighted sum of the localization loss (or regression loss) and
confidence loss (or classification loss) by performing the following steps: confidence loss (or classification loss) by performing the following steps:
1. Find matched boundding box by bipartite matching algorithm. 1. Find matched bounding box by bipartite matching algorithm.
1.1 Compute IOU similarity between ground-truth boxes and prior boxes. 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
1.2 Compute matched boundding box by bipartite matching algorithm. 1.2 Compute matched boundding box by bipartite matching algorithm.
2. Compute confidence for mining hard examples 2. Compute confidence for mining hard examples
2.1. Get the target label based on matched indices. 2.1. Get the target label based on matched indices.
2.2. Compute confidence loss. 2.2. Compute confidence loss.
3. Apply hard example mining to get the negative example indices and update 3. Apply hard example mining to get the negative example indices and update
the matched indices. the matched indices.
4. Assign classification and regression targets 4. Assign classification and regression targets
4.1. Encoded bbox according to the prior boxes. 4.1. Encoded bbox according to the prior boxes.
4.2. Assign regression targets. 4.2. Assign regression targets.
4.3. Assign classification targets. 4.3. Assign classification targets.
5. Compute the overall objective loss. 5. Compute the overall objective loss.
5.1 Compute confidence loss. 5.1 Compute confidence loss.
5.1 Compute localization loss. 5.1 Compute localization loss.
5.3 Compute the overall weighted loss. 5.3 Compute the overall weighted loss.
Args: Args:
@ -421,39 +450,36 @@ def ssd_loss(location,
mining_type (str): The hard example mining type, should be 'hard_example' mining_type (str): The hard example mining type, should be 'hard_example'
or 'max_negative', now only support `max_negative`. or 'max_negative', now only support `max_negative`.
normalize (bool): Whether to normalize the SSD loss by the total number normalize (bool): Whether to normalize the SSD loss by the total number
of output locations, True by defalut. of output locations, True by default.
sample_size (int): The max sample size of negative box, used only when sample_size (int): The max sample size of negative box, used only when
mining_type is 'hard_example'. mining_type is 'hard_example'.
Returns: Returns:
Variable: The weighted sum of the localization loss and confidence loss, The weighted sum of the localization loss and confidence loss, with \
with shape [N * Np, 1], N and Np are the same as they are shape [N * Np, 1], N and Np are the same as they are in `location`.
in `location`.
Raises: Raises:
ValueError: If mining_type is 'hard_example', now only support ValueError: If mining_type is 'hard_example', now only support mining \
mining type of `max_negative`. type of `max_negative`.
Examples: Examples:
.. code-block:: python >>> pb = fluid.layers.data(
>>> name='prior_box',
pb = layers.data( >>> shape=[10, 4],
name='prior_box', >>> append_batch_size=False,
shape=[10, 4], >>> dtype='float32')
append_batch_size=False, >>> pbv = fluid.layers.data(
dtype='float32') >>> name='prior_box_var',
pbv = layers.data( >>> shape=[10, 4],
name='prior_box_var', >>> append_batch_size=False,
shape=[10, 4], >>> dtype='float32')
append_batch_size=False, >>> loc = fluid.layers.data(name='target_box', shape=[10, 4], dtype='float32')
dtype='float32') >>> scores = fluid.layers.data(name='scores', shape=[10, 21], dtype='float32')
loc = layers.data(name='target_box', shape=[10, 4], dtype='float32') >>> gt_box = fluid.layers.data(
scores = layers.data(name='scores', shape=[10, 21], dtype='float32') >>> name='gt_box', shape=[4], lod_level=1, dtype='float32')
gt_box = layers.data( >>> gt_label = fluid.layers.data(
name='gt_box', shape=[4], lod_level=1, dtype='float32') >>> name='gt_label', shape=[1], lod_level=1, dtype='float32')
gt_label = layers.data( >>> loss = fluid.layers.ssd_loss(loc, scores, gt_box, gt_label, pb, pbv)
name='gt_label', shape=[1], lod_level=1, dtype='float32')
loss = layers.ssd_loss(loc, scores, gt_box, gt_label, pb, pbv)
""" """
helper = LayerHelper('ssd_loss', **locals()) helper = LayerHelper('ssd_loss', **locals())

@ -292,6 +292,7 @@ def _copy_reader_create_op_(block, op):
return new_op return new_op
@templatedoc(op_type='create_recordio_file_reader')
def open_recordio_file(filename, def open_recordio_file(filename,
shapes, shapes,
lod_levels, lod_levels,
@ -299,34 +300,30 @@ def open_recordio_file(filename,
pass_num=1, pass_num=1,
for_parallel=True): for_parallel=True):
""" """
Open a RecordIO file ${comment}
This layer takes a RecordIO file to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from the given RecordIO file.
Args: Args:
filename(str): The RecordIO file's name. filename(${filename_type}): ${filename_comment}.
shapes(list): List of tuples which declaring data shapes. shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level. lod_levels(${lod_levels_type}): ${lod_levels_comment}.
dtypes(list): List of strs which declaring data type. dtypes(list): List of strs which declaring data type.
pass_num(int): Number of passes to run. pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel. subsequent operators in parallel.
Returns: Returns:
Variable: A Reader Variable via which we can get RecordIO file data. ${out_comment}.
Examples: Examples:
.. code-block:: python
reader = fluid.layers.io.open_recordio_file( >>> import paddle.fluid as fluid
filename='./data.recordio', >>> reader = fluid.layers.io.open_recordio_file(
shapes=[(3,224,224), (1)], >>> filename='./data.recordio',
lod_levels=[0, 0], >>> shapes=[(3,224,224), (1)],
dtypes=['float32', 'int64']) >>> lod_levels=[0, 0],
>>> dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data: >>> # Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader) >>> image, label = fluid.layers.io.read_file(reader)
""" """
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = [] shape_concat = []
@ -554,6 +551,29 @@ def batch(reader, batch_size):
def double_buffer(reader, place=None, name=None): def double_buffer(reader, place=None, name=None):
"""
Wrap a double buffer reader. The data will copy to target place with a
double buffer queue. If the target place is None, the place that executor
perform on will be used.
Args:
reader(Variable): the reader variable need to be wrapped.
place(Place): the place of target data. Default is the sample place of
executor perform.
name(str): Variable name. None if the user does not care.
Returns:
wrapped reader with double buffer.
Examples:
>>> reader = fluid.layers.open_files(filenames=['somefile'],
>>> shapes=[[-1, 784], [-1, 1]],
>>> dtypes=['float32', 'int64'])
>>> reader = fluid.layers.double_buffer(reader)
>>> img, label = fluid.layers.read_file(reader)
"""
attrs = dict() attrs = dict()
if place is not None: if place is not None:
attrs['place'] = str(place).upper() attrs['place'] = str(place).upper()
@ -587,6 +607,26 @@ def read_file(file_obj):
class Preprocessor(object): class Preprocessor(object):
"""
A block for data pre-processing in reader.
Args:
reader (Variable): A reader variable.
name (str, default None): The name of the reader.
Examples:
.. code-block:: python
preprocessor = fluid.layers.io.Preprocessor(reader=reader)
with preprocessor.block():
img, lbl = preprocessor.inputs()
img_out = img / 2
lbl_out = lbl + 1
preprocessor.outputs(img_out, lbl_out)
data_file = fluid.layers.io.double_buffer(preprocessor())
"""
BEFORE_SUB_BLOCK = 0 BEFORE_SUB_BLOCK = 0
IN_SUB_BLOCK = 1 IN_SUB_BLOCK = 1
AFTER_SUB_BLOCK = 2 AFTER_SUB_BLOCK = 2

File diff suppressed because it is too large Load Diff

@ -40,8 +40,6 @@ __activations__ = [
'relu6', 'relu6',
'pow', 'pow',
'stanh', 'stanh',
'hard_shrink',
'thresholded_relu',
'hard_sigmoid', 'hard_sigmoid',
'swish', 'swish',
] ]
@ -64,11 +62,9 @@ __all__ = [
'logical_or', 'logical_or',
'logical_xor', 'logical_xor',
'logical_not', 'logical_not',
'uniform_random',
'uniform_random_batch_size_like', 'uniform_random_batch_size_like',
'gaussian_random', 'gaussian_random',
'gaussian_random_batch_size_like', 'gaussian_random_batch_size_like',
'cumsum',
'scatter', 'scatter',
'sum', 'sum',
'slice', 'slice',
@ -79,3 +75,88 @@ __all__ = [
for _OP in set(__all__): for _OP in set(__all__):
globals()[_OP] = generate_layer_fn(_OP) globals()[_OP] = generate_layer_fn(_OP)
__all__ += ["uniform_random"]
_uniform_random_ = generate_layer_fn('uniform_random')
def uniform_random(shape, dtype=None, min=None, max=None, seed=None):
kwargs = dict()
for name in locals():
val = locals()[name]
if val is not None:
kwargs[name] = val
return _uniform_random_(**kwargs)
uniform_random.__doc__ = _uniform_random_.__doc__ + """
Examples:
>>> result = fluid.layers.uniform_random(shape=[32, 784])
"""
__all__ += ['hard_shrink']
_hard_shrink_ = generate_layer_fn('hard_shrink')
def hard_shrink(x, threshold=None):
kwargs = dict()
for name in locals():
val = locals()[name]
if val is not None:
kwargs[name] = val
return _hard_shrink_(**kwargs)
hard_shrink.__doc__ = _hard_shrink_.__doc__ + """
Examples:
>>> data = fluid.layers.data(name="input", shape=[784])
>>> result = fluid.layers.hard_shrink(x=data, threshold=0.3)
"""
__all__ += ['cumsum']
_cum_sum_ = generate_layer_fn('cumsum')
def cumsum(x, axis=None, exclusive=None, reverse=None):
kwargs = dict()
for name in locals():
val = locals()[name]
if val is not None:
kwargs[name] = val
return _cum_sum_(**kwargs)
cumsum.__doc__ = _cum_sum_.__doc__ + """
Examples:
>>> data = fluid.layers.data(name="input", shape=[32, 784])
>>> result = fluid.layers.cumsum(data, axis=0)
"""
__all__ += ['thresholded_relu']
_thresholded_relu_ = generate_layer_fn('thresholded_relu')
def thresholded_relu(x, threshold=None):
kwargs = dict()
for name in locals():
val = locals()[name]
if val is not None:
kwargs[name] = val
_thresholded_relu_(**kwargs)
thresholded_relu.__doc__ = _thresholded_relu_.__doc__ + """
Examples:
>>> data = fluid.layers.data(name="input", shape=[1])
>>> result = fluid.layers.thresholded_relu(data, threshold=0.4)
"""

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save