diff --git a/doc/faq/index_cn.rst b/doc/faq/index_cn.rst index acbf4c87ae..b3ecfba791 100644 --- a/doc/faq/index_cn.rst +++ b/doc/faq/index_cn.rst @@ -390,4 +390,125 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数 * 如果发现最早的报错就是网络通信的问题,很有可能是非独占方式执行导致的端口冲突,可以联系OP,看当前MPI集群是否支持resource=full参数提交,如果支持增加此参数提交,并更换job 端口。 -* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。 \ No newline at end of file +* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。 + +19. PaddlePaddle如何输出多个层 +------------------------------ + +* 将需要输出的层作为 :code:`paddle.inference.Inference()` 接口的 :code:`output_layer` 参数输入,代码如下: + +.. code-block:: python + + inferer = paddle.inference.Inference(output_layer=[layer1, layer2], parameters=parameters) + +* 指定要输出的字段进行输出。以输出 :code:`value` 字段为例,代码如下: + +.. code-block:: python + + out = inferer.infer(input=data_batch, flatten_result=False, field=["value"]) + +这里设置 :code:`flatten_result=False`,得到的输出结果是元素个数等于输出字段数的 :code:`list`,该 :code:`list` 的每个元素是由所有输出层相应字段结果组成的 :code:`list`,每个字段结果的类型是 :code:`numpy.array`。:code:`flatten_result` 的默认值为 :code:`True`,该情况下,PaddlePaddle会分别对每个字段将所有输出层的结果按行进行拼接,如果各输出层该字段 :code:`numpy.array` 结果的相应维数不匹配,程序将不能正常运行。 + +20. :code:`paddle.layer.memory` 的参数 :code:`name` 如何使用 +------------------------------------------------------------- + +* :code:`paddle.layer.memory` 用于获取特定layer上一时间步的输出,该layer是通过参数 :code:`name` 指定,即,:code:`paddle.layer.memory` 会关联参数 :code:`name` 取值相同的layer,并将该layer上一时间步的输出作为自身当前时间步的输出。 + +* PaddlePaddle的所有layer都有唯一的name,用户通过参数 :code:`name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。而 :code:`paddle.layer.memory` 不是真正的layer,其name由参数 :code:`memory_name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。:code:`paddle.layer.memory` 的参数 :code:`name` 用于指定其要关联的layer,需要用户显式设定。 + +21. dropout 使用 +----------------- + +* 在PaddlePaddle中使用dropout有两种方式 + + * 在相应layer的 :code:`layer_attr` 设置 :code:`drop_rate`,以 :code:`paddle.layer.fc` 为例,代码如下: + + .. 
code-block:: python + + fc = paddle.layer.fc(input=input, layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=0.5)) + + * 使用 :code:`paddle.layer.dropout`,以 :code:`paddle.layer.fc` 为例,代码如下: + + .. code-block:: python + + fc = paddle.layer.fc(input=input) + drop_fc = paddle.layer.dropout(input=fc, dropout_rate=0.5) + +* :code:`paddle.layer.dropout` 实际上使用了 :code:`paddle.layer.add_to`,并在该layer里采用第一种方式设置 :code:`drop_rate` 来使用dropout的。这种方式对内存消耗较大。 + +* PaddlePaddle在激活函数里实现dropout,而不是在layer里实现。 + +* :code:`paddle.layer.lstmemory`、:code:`paddle.layer.grumemory`、:code:`paddle.layer.recurrent` 不是通过一般的方式来实现对输出的激活,所以不能采用第一种方式在这几个layer里设置 :code:`drop_rate` 来使用dropout。若要对这几个layer使用dropout,可采用第二种方式,即使用 :code:`paddle.layer.dropout`。 + +22. 如何设置学习率退火(learning rate annealing) +------------------------------------------------ + +在相应的优化算法里设置learning_rate_schedule及相关参数,以使用Adam算法为例,代码如下: + +.. code-block:: python + + optimizer = paddle.optimizer.Adam( + learning_rate=1e-3, + learning_rate_decay_a=0.5, + learning_rate_decay_b=0.75, + learning_rate_schedule="poly",) + +PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedule及其对应学习率计算方式如下: + +* "constant" + + lr = learning_rate + +* "poly" + + lr = learning_rate * pow(1 + learning_rate_decay_a * num_samples_processed, -learning_rate_decay_b) + + 其中,num_samples_processed为已训练样本数,下同。 + +* "caffe_poly" + + lr = learning_rate * pow(1.0 - num_samples_processed / learning_rate_decay_a, learning_rate_decay_b) + +* "exp" + + lr = learning_rate * pow(learning_rate_decay_a, num_samples_processed / learning_rate_decay_b) + +* "discexp" + + lr = learning_rate * pow(learning_rate_decay_a, floor(num_samples_processed / learning_rate_decay_b)) + +* "linear" + + lr = max(learning_rate - learning_rate_decay_a * num_samples_processed, learning_rate_decay_b) + +* "manual" + + 这是一种按已训练样本数分段取值的学习率退火方法。使用该learning_rate_schedule时,用户通过参数 :code:`learning_rate_args` 设置学习率衰减因子分段函数,当前的学习率为所设置 :code:`learning_rate` 与当前的衰减因子的乘积。以使用Adam算法为例,代码如下: + + .. 
code-block:: python + + optimizer = paddle.optimizer.Adam( + learning_rate=1e-3, + learning_rate_schedule="manual", + learning_rate_args="1000:1.0,2000:0.9,3000:0.8",) + + 在该示例中,当已训练样本数小于等于1000时,学习率为 :code:`1e-3 * 1.0`;当已训练样本数大于1000小于等于2000时,学习率为 :code:`1e-3 * 0.9`;当已训练样本数大于2000时,学习率为 :code:`1e-3 * 0.8`。 + +* "pass_manual" + + 这是一种按已训练pass数分段取值的学习率退火方法。使用该learning_rate_schedule时,用户通过参数 :code:`learning_rate_args` 设置学习率衰减因子分段函数,当前的学习率为所设置 :code:`learning_rate` 与当前的衰减因子的乘积。以使用Adam算法为例,代码如下: + + .. code-block:: python + + optimizer = paddle.optimizer.Adam( + learning_rate=1e-3, + learning_rate_schedule="pass_manual", + learning_rate_args="1:1.0,2:0.9,3:0.8",) + + 在该示例中,当已训练pass数小于等于1时,学习率为 :code:`1e-3 * 1.0`;当已训练pass数大于1小于等于2时,学习率为 :code:`1e-3 * 0.9`;当已训练pass数大于2时,学习率为 :code:`1e-3 * 0.8`。 + +23. 出现 :code:`Duplicated layer name` 错误怎么办 +-------------------------------------------------- + +出现该错误的原因一般是用户对不同layer的参数 :code:`name` 设置了相同的取值。遇到该错误时,先找出参数 :code:`name` 取值相同的layer,然后将这些layer的参数 :code:`name` 设置为不同的值。 + diff --git a/paddle/operators/crop_op.h b/paddle/operators/crop_op.h index 2f40c05903..ac3aeaf41e 100644 --- a/paddle/operators/crop_op.h +++ b/paddle/operators/crop_op.h @@ -38,10 +38,10 @@ class CropKernel : public framework::OpKernel { auto out_stride = framework::stride(out->dims()); auto offsets = context.Attr>("offsets"); PADDLE_ENFORCE_EQ( - x->dims().size(), offsets.size(), + x->dims().size(), static_cast(offsets.size()), "Offsets size should be equal to dimension size of input tensor."); int64_t offset = 0; - for (int i = 0; i < offsets.size(); ++i) { + for (size_t i = 0; i < offsets.size(); ++i) { offset += (x_stride[i] * offsets[i]); } StridedMemcpy(context.device_context(), x_data + offset, x_stride, @@ -57,7 +57,7 @@ void CropGradFunction(const framework::ExecutionContext& context) { d_x->mutable_data(context.GetPlace()); auto offsets = context.Attr>("offsets"); Eigen::array, D> paddings; - for (int i = 0; i < D; ++i) { + for (size_t i = 0; i < D; ++i) { 
paddings[i].first = offsets[i]; paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i]; } diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index def4b01da0..ba653afa2c 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -48,6 +48,32 @@ void gemm(const platform::DeviceContext& context, beta, C, ldc); } +template <> +void gemm(const platform::DeviceContext& context, + const bool transA, const bool transB, + const int M, const int N, const int K, + const float alpha, const float* A, + const int lda, const float* B, + const int ldb, const float beta, float* C, + const int ldc) { + cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); +} + +template <> +void gemm(const platform::DeviceContext& context, + const bool transA, const bool transB, + const int M, const int N, const int K, + const double alpha, const double* A, + const int lda, const double* B, + const int ldb, const double beta, + double* C, const int ldc) { + cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? 
CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); +} + template <> void matmul( const platform::DeviceContext& context, const framework::Tensor& matrix_a, diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 71563b77b4..649f1f352c 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -63,6 +63,42 @@ void gemm(const platform::DeviceContext& context, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } +template <> +void gemm(const platform::DeviceContext& context, + const bool transA, const bool transB, + const int M, const int N, const int K, + const float alpha, const float* A, + const int lda, const float* B, + const int ldb, const float beta, float* C, + const int ldc) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T; + PADDLE_ENFORCE(platform::dynload::cublasSgemm( + reinterpret_cast(context) + .cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); +} + +template <> +void gemm(const platform::DeviceContext& context, + const bool transA, const bool transB, + const int M, const int N, const int K, + const double alpha, const double* A, + const int lda, const double* B, + const int ldb, const double beta, + double* C, const int ldc) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = transB == false ? 
CUBLAS_OP_N : CUBLAS_OP_T; + PADDLE_ENFORCE(platform::dynload::cublasDgemm( + reinterpret_cast(context) + .cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); +} + template <> void matmul( const platform::DeviceContext& context, const framework::Tensor& matrix_a, diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index d8518e77fa..43306fca73 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -70,6 +70,13 @@ void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const T alpha, const T* A, const T* B, const T beta, T* C); +// gemm wrapper with stride args for matrix uncontinuous in memory +template +void gemm(const platform::DeviceContext& context, const bool transA, + const bool transB, const int M, const int N, const int K, + const T alpha, const T* A, const int lda, const T* B, const int ldb, + const T beta, T* C, const int ldc); + // matrix multiply with continuous memory template void matmul(const platform::DeviceContext& context, diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index 7e339457f7..f272f7e513 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -72,4 +72,174 @@ TEST(math_function, trans_mul_notrans) { EXPECT_EQ(out_ptr[8], 29); delete gpu_place; } + +TEST(math_function, gemm_notrans_cublas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor input3_gpu; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + 
memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({3, 4}, *cpu_place); + float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::CUDADeviceContext context(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input2, *gpu_place); + input3_gpu.CopyFrom(input3, *gpu_place); + float* a = input1_gpu.data(); + float* b = input2_gpu.data(); + float* c = input3_gpu.mutable_data(*gpu_place); + + paddle::operators::math::gemm( + context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); + + input3.CopyFrom(input3_gpu, *cpu_place); + + // numpy code: + // a = np.arange(6).reshape(2, 3) + // b = np.arange(12).reshape(3, 4)[:, 1:] + // c = np.arange(8).reshape(2, 4)[:, 1:] + // out = np.arange(8).reshape(2, 4) + // out[:, 1:] = np.dot(a, b) + c + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); + delete gpu_place; +} + +TEST(math_function, gemm_trans_cublas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor input3_gpu; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({4, 3}, *cpu_place); + float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 
7, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::CUDADeviceContext context(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input2, *gpu_place); + input3_gpu.CopyFrom(input3, *gpu_place); + float* a = input1_gpu.data(); + float* b = input2_gpu.data(); + float* c = input3_gpu.mutable_data(*gpu_place); + + paddle::operators::math::gemm( + context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); + + input3.CopyFrom(input3_gpu, *cpu_place); + + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); + delete gpu_place; +} #endif + +TEST(math_function, gemm_notrans_cblas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({3, 4}, *cpu_place); + float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + paddle::platform::CPUDeviceContext context(*cpu_place); + paddle::operators::math::gemm( + context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1, + input3_ptr + 1, 4); + + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + 
EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); +} + +TEST(math_function, gemm_trans_clbas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({4, 3}, *cpu_place); + float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + paddle::platform::CPUDeviceContext context(*cpu_place); + paddle::operators::math::gemm( + context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1, + input3_ptr + 1, 4); + + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); +} diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index e0c19a3190..44be9b38ce 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -25,24 +25,30 @@ class MultiplexOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"), + "Input(Ids) shouldn't be null."); PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), - "Input(X) should not be null."); + "MultiInput(X) shouldn't be empty."); PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), "Output(Out) shouldn't be 
null."); + auto ids_dim = ctx.Input("Ids")->dims(); + PADDLE_ENFORCE( + ids_dim.size() == 2 && ids_dim[1] == 1, + "The index tensor must be a vector with size batchSize x 1."); + auto ins = ctx.MultiInput("X"); auto *out = ctx.Output("Out"); auto num_ins = ins.size(); - PADDLE_ENFORCE(num_ins > 2, - "multiplex operator should have more than 2 inputs."); - PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1, - "The first input must be a index vector."); - auto in_dim = ins[1]->dims(); + PADDLE_ENFORCE(num_ins > 1, + "multiplex operator should have more than " + "one candidate input tensors."); - for (size_t i = 2; i < num_ins; i++) { + auto in_dim = ins[0]->dims(); + PADDLE_ENFORCE(in_dim.size() == 2, "Candidate tensors must be matrix."); + for (size_t i = 1; i < num_ins; i++) { auto dim = ins[i]->dims(); PADDLE_ENFORCE(in_dim == dim, - "All the input tensors except the first one must have the " - "same size."); + "All the candidate tensors must have the same size."); } out->Resize(in_dim); } @@ -53,25 +59,26 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { MultiplexOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The input tensors of multiplex operator.").AsDuplicable(); + AddInput("Ids", "The index tensor of multiplex operator."); + AddInput("X", "The candidate tensors of multiplex operator.") + .AsDuplicable(); AddOutput("Out", "The output tensor of multiplex operator."); AddComment(R"DOC(Multiplex operator Multiplex multiple tensors according to the index provided by the first input tensor. -ins[0]: the index tensor. -ins[1:N]: the candidate output tensors. +Ids: the index tensor. +X[0 : N - 1]: the candidate tensors for output (N >= 2). For each index i from 0 to batchSize - 1, the output is the i-th row of the -the (index[i] + 1)-th tensor. +the (Ids[i])-th tensor. For i-th row of the output tensor: -y[i][j] = x_{k}[i][j], j = 0,1, ... 
, (x_{1}.width - 1) +y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{0}.width - 1) where y is the output tensor. `x_{k}` is the k-th input tensor -and `k = x{0}[i] + 1`. - +and `k = Ids[i]`. )DOC"); } }; @@ -90,8 +97,8 @@ class MultiplexGradOp : public framework::OperatorWithKernel { "Input(Out@GRAD) shouldn't be null."); auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); auto ins = ctx.MultiInput("X"); - // don't compute gradient for index (ins[0]) - for (size_t i = 1; i < ins.size(); i++) { + // No need to compute gradient for Input(Ids) + for (size_t i = 0; i < ins.size(); i++) { if (d_ins[i]) { d_ins[i]->Resize(ins[i]->dims()); } diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu index ae4c7d183a..d990b227e7 100644 --- a/paddle/operators/multiplex_op.cu +++ b/paddle/operators/multiplex_op.cu @@ -25,21 +25,23 @@ class MultiplexGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto ins = ctx.MultiInput("X"); + auto* ids = ctx.Input("Ids"); auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->dims()[1]; // copy index to cpu Tensor index_t_cpu; - index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); - auto* index = index_t_cpu.data(); + index_t_cpu.CopyFrom(*ids, platform::CPUPlace()); + auto* index = index_t_cpu.data(); auto stream = reinterpret_cast( ctx.device_context()) .stream(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { - size_t k = (size_t)index[i] + 1; + int32_t k = index[i]; + PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative."); PADDLE_ENFORCE_LT(k, ins.size(), "index exceeds the number of candidate tensors."); memory::Copy(place, out->data() + i * cols, place, @@ -54,8 +56,9 @@ class MultiplexGradGPUKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { 
auto* d_out = ctx.Input(framework::GradVarName("Out")); auto ins = ctx.MultiInput("X"); + auto* ids = ctx.Input("Ids"); auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); - for (size_t i = 1; i < d_ins.size(); i++) { + for (size_t i = 0; i < d_ins.size(); i++) { if (d_ins[i]) { d_ins[i]->mutable_data(ctx.GetPlace()); auto t = framework::EigenVector::Flatten(*d_ins[i]); @@ -63,19 +66,19 @@ class MultiplexGradGPUKernel : public framework::OpKernel { } } - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->dims()[1]; // copy index to cpu Tensor index_t_cpu; - index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); - auto* index = index_t_cpu.data(); + index_t_cpu.CopyFrom(*ids, platform::CPUPlace()); + auto* index = index_t_cpu.data(); auto stream = reinterpret_cast( ctx.device_context()) .stream(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { - size_t k = (size_t)index[i] + 1; + size_t k = static_cast(index[i]); if (d_ins[k]) { memory::Copy(place, d_ins[k]->data() + i * cols, place, d_out->data() + i * cols, cols * sizeof(T), stream); diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h index 98b8ec930d..c39684920c 100644 --- a/paddle/operators/multiplex_op.h +++ b/paddle/operators/multiplex_op.h @@ -27,17 +27,19 @@ class MultiplexCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto ins = ctx.MultiInput("X"); + auto ids = ctx.Input("Ids"); auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; - auto* index = ins[0]->data(); + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->dims()[1]; + auto index = ids->data(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { - size_t k = (size_t)index[i] + 1; - PADDLE_ENFORCE_LT(k, ins.size(), + int32_t k = index[i]; + 
PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative."); + PADDLE_ENFORCE_LT(static_cast(k), ins.size(), "index exceeds the number of candidate tensors."); memory::Copy(place, out->data() + i * cols, place, ins[k]->data() + i * cols, cols * sizeof(T)); @@ -50,10 +52,11 @@ class MultiplexGradCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* ids = ctx.Input("Ids"); auto ins = ctx.MultiInput("X"); auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); - for (size_t i = 1; i < d_ins.size(); i++) { + for (size_t i = 0; i < d_ins.size(); i++) { if (d_ins[i]) { d_ins[i]->mutable_data(ctx.GetPlace()); auto t = framework::EigenVector::Flatten(*d_ins[i]); @@ -61,12 +64,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel { } } - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; - auto* index = ins[0]->data(); + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->dims()[1]; + auto* index = ids->data(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { - size_t k = (size_t)index[i] + 1; + size_t k = static_cast(index[i]); if (d_ins[k]) { memory::Copy(place, d_ins[k]->data() + i * cols, place, d_out->data() + i * cols, cols * sizeof(T)); @@ -74,5 +77,5 @@ class MultiplexGradCPUKernel : public framework::OpKernel { } } }; -} -} +} // namespace operators +} // namespace paddle diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index c97e6c0a36..74025d2a7b 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -921,7 +921,7 @@ def data_layer(name, size, depth=None, height=None, width=None, data = data_layer(name="input", size=1000) - :param name: The name of this layer. It is optional. + :param name: The name of this layer. :type name: basestring :param size: Size of this data layer. 
:type size: int @@ -3668,6 +3668,7 @@ def gru_step_naive_layer(input, :param param_attr: :param layer_attr: :return: + :rtype: LayerOutput """ if input.size % 3 != 0: raise ValueError("GruStep input size must be divided by 3") diff --git a/python/paddle/v2/framework/tests/test_multiplex_op.py b/python/paddle/v2/framework/tests/test_multiplex_op.py index f2b3881cde..5937eb5aa4 100644 --- a/python/paddle/v2/framework/tests/test_multiplex_op.py +++ b/python/paddle/v2/framework/tests/test_multiplex_op.py @@ -6,20 +6,22 @@ from op_test import OpTest class TestMultiplexOp(OpTest): def setUp(self): self.op_type = "multiplex" - rows = 3 - index = np.array([3, 1, 0]) + rows = 4 + index = np.arange(0, rows).astype('int32') + np.random.shuffle(index) + index = np.reshape(index, (rows, 1)) ins1 = np.random.random((rows, 10)).astype("float32") ins2 = np.random.random((rows, 10)).astype("float32") ins3 = np.random.random((rows, 10)).astype("float32") ins4 = np.random.random((rows, 10)).astype("float32") self.inputs = { - 'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3), - ('x4', ins4)] + 'Ids': index, + 'X': [('x1', ins1), ('x2', ins2), ('x3', ins3), ('x4', ins4)] } # multiplex output output = np.zeros_like(ins1) for i in range(0, rows): - k = index[i] + 1 + k = index[i][0] output[i] = self.inputs['X'][k][1][i] self.outputs = {'Out': output}