From f14a7966b0f982a703f24940b1e9ae5325ee83c9 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Thu, 21 Sep 2017 03:18:42 +0000
Subject: [PATCH 1/4] Initialize the sequence softmax operator.

---
 paddle/operators/sequence_avg_pool_op.cc      |  2 +-
 paddle/operators/sequence_softmax_op.cc       | 82 +++++++++++++++++++
 paddle/operators/sequence_softmax_op.cu      | 25 ++++++
 paddle/operators/sequence_softmax_op.h       | 62 ++++++++++++++
 .../tests/test_sequence_softmax_op.py        | 35 ++++++++
 5 files changed, 205 insertions(+), 1 deletion(-)
 create mode 100644 paddle/operators/sequence_softmax_op.cc
 create mode 100644 paddle/operators/sequence_softmax_op.cu
 create mode 100644 paddle/operators/sequence_softmax_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_sequence_softmax_op.py

diff --git a/paddle/operators/sequence_avg_pool_op.cc b/paddle/operators/sequence_avg_pool_op.cc
index 9815b8f3a8..ff354c62b6 100644
--- a/paddle/operators/sequence_avg_pool_op.cc
+++ b/paddle/operators/sequence_avg_pool_op.cc
@@ -36,7 +36,7 @@ class SequenceAvgPoolOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_GE(
         dims[0],
         /*batch size = */ static_cast<int64_t>(lod[0].size() - 1),
-        "The first dimension of Input(X) must be large than batch size.");
+        "The first dimension of Input(X) must be larger than batch size.");
     dims[0] = lod[0].size() - 1;
     ctx.Output<framework::LoDTensor>("Out")->Resize({dims});
   }
diff --git a/paddle/operators/sequence_softmax_op.cc b/paddle/operators/sequence_softmax_op.cc
new file mode 100644
index 0000000000..0a99717440
--- /dev/null
+++ b/paddle/operators/sequence_softmax_op.cc
@@ -0,0 +1,82 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#include "paddle/operators/sequence_softmax_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceSoftmaxOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(
+        ctx.InputVar("X"), "Input(X) of SequenceSoftmaxOp should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(
+        ctx.OutputVar("Out"),
+        "Output(Out) of SequenceSoftmaxOp should not be null.");
+
+    auto *x = ctx.Input<LoDTensor>("X");
+    auto dims = x->dims();
+    auto lod = x->lod();
+    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
+    PADDLE_ENFORCE_GE(
+        dims[0],
+        /* batch_size */ static_cast<int64_t>(lod[0].size() - 1),
+        "The first dimension of Input(X) should be larger than batch size.");
+    PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[0].size() - 1),
+                      "The width of each timestep in Input(X) of "
+                      "SequenceSoftmaxOp should be 1.");
+
+    dims[0] = lod[0].size() - 1;
+    ctx.Output<LoDTensor>("Out")->Resize({dims});
+  }
+};
+
+class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SequenceSoftmaxOpMaker(framework::OpProto *proto,
+                         framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "(LoDTensor)");
+    AddOutput("Out", "(LoDTensor)");
+    AddComment(R"DOC(
+Softmax of Sequence.
+)DOC");
+  }
+};
+
+class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sequence_softmax, ops::SequenceSoftmaxOp,
+            ops::SequenceSoftmaxOpMaker, sequence_softmax_grad,
+            ops::SequenceSoftmaxGradOp);
+REGISTER_OP_CPU_KERNEL(
+    sequence_softmax,
+    ops::SequenceSoftmaxKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    sequence_softmax_grad,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sequence_softmax_op.cu b/paddle/operators/sequence_softmax_op.cu
new file mode 100644
index 0000000000..f2a1e3d5e3
--- /dev/null
+++ b/paddle/operators/sequence_softmax_op.cu
@@ -0,0 +1,25 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+
+#include "paddle/operators/sequence_softmax_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    sequence_softmax,
+    ops::SequenceSoftmaxKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    sequence_softmax_grad,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sequence_softmax_op.h b/paddle/operators/sequence_softmax_op.h
new file mode 100644
index 0000000000..54d8265271
--- /dev/null
+++ b/paddle/operators/sequence_softmax_op.h
@@ -0,0 +1,62 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/softmax_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class SequenceSoftmaxKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* out = ctx.Output<LoDTensor>("Out");
+
+    auto lod = x->lod();
+    const size_t level = lod.size();
+
+    out->mutable_data<T>(ctx.GetPlace());
+    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+      Tensor x_i = x->Slice<T>(start_pos, end_pos);
+      Tensor out_i = out->Slice<T>(start_pos, end_pos);
+
+      math::SoftmaxFunctor<Place, T>()(&x_i, &out_i, ctx);
+    }
+  }
+};
+
+template <typename Place, typename T>
+class SequenceSoftmaxGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_sequence_softmax_op.py b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py
new file mode 100644
index 0000000000..d0667c1308
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py
@@ -0,0 +1,35 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def stable_softmax(x):
+    """Compute the softmax of vector x in a numerically stable way."""
+    shiftx = x - np.max(x)
+    exps = np.exp(shiftx)
+    return exps / np.sum(exps)
+
+
+class TestSequenceSoftmaxOp(OpTest):
+    def setUp(self):
+        self.op_type = "sequence_softmax"
+        x = np.random.uniform(0.1, 1, (11, 1)).astype("float32")
+        lod = [[0, 4, 5, 8, 11]]
+
+        out = np.zeros((11, 1)).astype("float32")
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            sub_x = sub_x.reshape(1, lod[0][i + 1] - lod[0][i])
+            sub_out = stable_softmax(sub_x)
+            out[lod[0][i]:lod[0][i + 1], :] = sub_out.reshape(
+                lod[0][i + 1] - lod[0][i], 1)
+
+        self.inputs = {"X": (x, lod)}
+        self.outputs = {"Out": out}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == "__main__":
+    unittest.main()

From 12f2b8eb07b034a24e3a0e0538a757e389fb4c45 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Mon, 25 Sep 2017 04:12:04 +0000
Subject: [PATCH 2/4] Correct the forward of sequence_softmax_op.
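
The forward pass treats each LoD sequence as one softmax group: every
[lod[i], lod[i+1]) slice of the (N, 1) input is reshaped to a 1 x len
row, passed through SoftmaxFunctor, and written back. As a rough
reference, here is a minimal NumPy sketch of the intended computation
(function and variable names are illustrative only, not part of the
operator):

    import numpy as np

    def sequence_softmax_forward(x, lod0):
        # x: (N, 1) input; lod0: level-0 offsets, e.g. [0, 4, 5, 8, 11].
        out = np.zeros_like(x)
        for i in range(len(lod0) - 1):
            seq = x[lod0[i]:lod0[i + 1], 0]   # one sequence
            e = np.exp(seq - np.max(seq))     # shift for numerical stability
            out[lod0[i]:lod0[i + 1], 0] = e / e.sum()
        return out

This mirrors the stable_softmax reference used in the Python unit test.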
---
 paddle/operators/reshape_op.cc          |  3 +--
 paddle/operators/sequence_softmax_op.cc | 10 ++++++----
 paddle/operators/sequence_softmax_op.h  |  6 +++++-
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc
index 0d05e34414..d5d04dac27 100644
--- a/paddle/operators/reshape_op.cc
+++ b/paddle/operators/reshape_op.cc
@@ -42,8 +42,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
     int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1,
                                        std::multiplies<int>());
     auto *in = ctx.Input<framework::Tensor>("X");
-    int64_t in_size = framework::product(in->dims());
-    PADDLE_ENFORCE_EQ(capacity, in_size,
+    PADDLE_ENFORCE_EQ(capacity, in->numel(),
                       "The size of Input(X) mismatches with Attr(shape).");
     // resize output
     std::vector<int64_t> shape_int64(shape.size(), 0);
diff --git a/paddle/operators/sequence_softmax_op.cc b/paddle/operators/sequence_softmax_op.cc
index 0a99717440..58ef77b1a3 100644
--- a/paddle/operators/sequence_softmax_op.cc
+++ b/paddle/operators/sequence_softmax_op.cc
@@ -30,18 +30,20 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
         "Output(Out) of SequenceSoftmaxOp should not be null.");
 
     auto *x = ctx.Input<LoDTensor>("X");
-    auto dims = x->dims();
     auto lod = x->lod();
-    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
+    auto dims = x->dims();
     PADDLE_ENFORCE_GE(
         dims[0],
         /* batch_size */ static_cast<int64_t>(lod[0].size() - 1),
         "The first dimension of Input(X) should be larger than batch size.");
-    PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[0].size() - 1),
+
+    const size_t level = lod.size() - 1;
+    PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[level].back()),
                       "The width of each timestep in Input(X) of "
                       "SequenceSoftmaxOp should be 1.");
 
-    dims[0] = lod[0].size() - 1;
+    std::cout << DebugString() << std::endl;
+
     ctx.Output<LoDTensor>("Out")->Resize({dims});
   }
 };
diff --git a/paddle/operators/sequence_softmax_op.h b/paddle/operators/sequence_softmax_op.h
index 54d8265271..f39c2ec6c3 100644
--- a/paddle/operators/sequence_softmax_op.h
+++ b/paddle/operators/sequence_softmax_op.h
@@ -38,7 +38,7 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
     auto* out = ctx.Output<LoDTensor>("Out");
 
     auto lod = x->lod();
-    const size_t level = lod.size();
+    const size_t level = lod.size() - 1;
 
     out->mutable_data<T>(ctx.GetPlace());
     for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
@@ -47,6 +47,10 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
       Tensor x_i = x->Slice<T>(start_pos, end_pos);
       Tensor out_i = out->Slice<T>(start_pos, end_pos);
 
+      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
+      framework::DDim dims = framework::make_ddim({1UL, end_pos - start_pos});
+      x_i.Resize(dims);
+      out_i.Resize(dims);
       math::SoftmaxFunctor<Place, T>()(&x_i, &out_i, ctx);
     }
   }

From 05ed8ee8ab35c5861a187deeca076322a2f9de34 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Thu, 28 Sep 2017 06:30:34 +0000
Subject: [PATCH 3/4] Add SoftmaxGradFunctor, and use SoftmaxGradFunctor in
 softmax_op instead.
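
For y = softmax(x), the backward rule the new functor implements is
dx = (dy - sum(dy * y over the class axis)) * y, computed row by row
with Eigen. A hedged NumPy sketch of the same formula (names are
illustrative, not the functor's interface):

    import numpy as np

    def softmax_grad(y, dy):
        # y: (batch, classes) softmax output; dy: gradient w.r.t. y.
        dot = np.sum(y * dy, axis=1, keepdims=True)  # per-row <y, dy>
        return (dy - dot) * y                        # gradient w.r.t. x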
--- paddle/operators/CMakeLists.txt | 4 +-- paddle/operators/math/CMakeLists.txt | 9 +++-- paddle/operators/math/softmax.cc | 19 ++++++----- paddle/operators/math/softmax.cu | 19 ++++++----- paddle/operators/math/softmax.h | 49 +++++++++++++++++++++++----- paddle/operators/softmax_op.h | 31 +++++------------- 6 files changed, 74 insertions(+), 57 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index e56895c63a..da39c2cb55 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -94,8 +94,8 @@ set(DEPS_OPS op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor net_op) op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) -op_library(cross_entropy_op DEPS cross_entropy_function) -op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function) +op_library(cross_entropy_op DEPS cross_entropy) +op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 91ae3d49f1..b60e945aa8 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,15 +1,14 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) - nv_library(softmax_function SRCS softmax.cc softmax.cu - DEPS operator) - nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu + nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) + nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) - cc_library(softmax_function SRCS softmax.cc DEPS operator) - cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator) + cc_library(softmax SRCS softmax.cc DEPS operator) + cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc index ac9f3c4bf6..0ba8197ab8 100644 --- a/paddle/operators/math/softmax.cc +++ b/paddle/operators/math/softmax.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
 
 #include "paddle/operators/math/softmax.h"
 
@@ -19,6 +19,7 @@ namespace operators {
 namespace math {
 
 template class SoftmaxFunctor<platform::CPUPlace, float>;
+template class SoftmaxGradFunctor<platform::CPUPlace, float>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/softmax.cu b/paddle/operators/math/softmax.cu
index 4c3df0550e..99f988d51e 100644
--- a/paddle/operators/math/softmax.cu
+++ b/paddle/operators/math/softmax.cu
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #define EIGEN_USE_GPU
 
@@ -21,6 +21,7 @@ namespace operators {
 namespace math {
 
 template class SoftmaxFunctor<platform::GPUPlace, float>;
+template class SoftmaxGradFunctor<platform::GPUPlace, float>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h
index 3d2f0d0aec..3c05a86bce 100644
--- a/paddle/operators/math/softmax.h
+++ b/paddle/operators/math/softmax.h
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
 
 #pragma once
 #include "paddle/framework/eigen.h"
@@ -68,6 +68,37 @@ class SoftmaxFunctor {
                            .broadcast(one_by_class));
   }
 };
+
+template <typename Place, typename T>
+class SoftmaxGradFunctor {
+ public:
+  void operator()(const framework::ExecutionContext& context,
+                  const framework::Tensor* y, const framework::Tensor* y_grad,
+                  framework::Tensor* x_grad) {
+    auto softmax = EigenMatrix<T>::From(*y);
+    auto softmax_grad = EigenMatrix<T>::From(*y_grad);
+    auto logits_grad = EigenMatrix<T>::From(*x_grad);
+
+    const int kBatchDim = 0;
+    const int kClassDim = 1;
+
+    const int batch_size = softmax.dimension(kBatchDim);
+    const int num_classes = softmax.dimension(kClassDim);
+
+    Eigen::DSizes<int, 1> along_class(kClassDim);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+
+    auto dot = (softmax * softmax_grad)
+                   .sum(along_class)
+                   .eval()
+                   .reshape(batch_by_one)
+                   .broadcast(one_by_class);
+    logits_grad.device(context.GetEigenDevice<Place>()) =
+        (softmax_grad - dot) * softmax;
+  }
+};
+
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 7220f486be..3d35507a9a 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -29,8 +29,8 @@ template <typename Place, typename T>
 class SoftmaxKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto X = context.Input<Tensor>("X");
-    auto Y = context.Output<Tensor>("Y");
+    auto* X = context.Input<Tensor>("X");
+    auto* Y = context.Output<Tensor>("Y");
 
     // allocate memory on device.
     Y->mutable_data<T>(context.GetPlace());
@@ -43,29 +43,14 @@ template <typename Place, typename T>
 class SoftmaxGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto Y = context.Input<Tensor>("Y");
-    auto dY = context.Input<Tensor>(framework::GradVarName("Y"));
-    auto dX = context.Output<Tensor>(framework::GradVarName("X"));
-    dX->mutable_data<T>(context.GetPlace());
-
-    const int batch_size = Y->dims()[0];
-    const int class_num = Y->dims()[1];
-
-    Eigen::DSizes<int, 1> along_class(1);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, class_num);
+    auto* Y = context.Input<Tensor>("Y");
+    auto* dY = context.Input<Tensor>(framework::GradVarName("Y"));
+    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
 
-    auto Y_eigen = EigenMatrix<T>::From(*Y);
-    auto dY_eigen = EigenMatrix<T>::From(*dY);
-    auto dX_eigen = EigenMatrix<T>::From(*dX);
-    auto place = context.GetEigenDevice<Place>();
+    // allocate memory on device.
+    dX->mutable_data<T>(context.GetPlace());
 
-    auto dot = (Y_eigen * dY_eigen)
-                   .sum(along_class)
-                   .eval()
-                   .reshape(batch_by_one)
-                   .broadcast(one_by_class);
-    dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen;
+    math::SoftmaxGradFunctor<Place, T>()(context, Y, dY, dX);
   }
 };
 

From 03897f251dc40ae3ded98a84caa3b40fed164de9 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Thu, 28 Sep 2017 06:39:23 +0000
Subject: [PATCH 4/4] Finish the SequenceSoftmaxGradKernel, using
 SoftmaxGradFunctor.

---
 paddle/operators/mul_op.cc              | 32 ++++----
 paddle/operators/sequence_softmax_op.cc | 79 ++++++++++++-------
 paddle/operators/sequence_softmax_op.h  | 53 ++++++++++---
 .../tests/test_sequence_softmax_op.py   |  5 +-
 4 files changed, 111 insertions(+), 58 deletions(-)

diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 9858c4d9c2..3c8fe04d2e 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #include "paddle/operators/mul_op.h"
 
@@ -35,12 +35,14 @@ class MulOp : public framework::OperatorWithKernel {
     int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
     int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
 
-    PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
-                   "The rank of input tensor X should be larger than "
-                   "`mul_op`'s `x_num_col_dims`.");
-    PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
-                   "The rank of input tensor Y should be larger than "
-                   "`mul_op`'s `y_num_col_dims`.");
+    PADDLE_ENFORCE_GT(
+        x_dims.size(), x_num_col_dims,
+        "The input tensor X's rank of MulOp should be larger than "
+        "x_num_col_dims.");
+    PADDLE_ENFORCE_GT(
+        y_dims.size(), y_num_col_dims,
+        "The input tensor Y's rank of MulOp should be larger than "
+        "y_num_col_dims.");
 
     auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
     auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
diff --git a/paddle/operators/sequence_softmax_op.cc b/paddle/operators/sequence_softmax_op.cc
index 58ef77b1a3..e85b587a94 100644
--- a/paddle/operators/sequence_softmax_op.cc
+++ b/paddle/operators/sequence_softmax_op.cc
@@ -22,41 +22,42 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.InputVar("X"), "Input(X) of SequenceSoftmaxOp should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.OutputVar("Out"),
-        "Output(Out) of SequenceSoftmaxOp should not be null.");
-
-    auto *x = ctx.Input<LoDTensor>("X");
-    auto lod = x->lod();
-    auto dims = x->dims();
-    PADDLE_ENFORCE_GE(
-        dims[0],
-        /* batch_size */ static_cast<int64_t>(lod[0].size() - 1),
-        "The first dimension of Input(X) should be larger than batch size.");
-
-    const size_t level = lod.size() - 1;
-    PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[level].back()),
-                      "The width of each timestep in Input(X) of "
-                      "SequenceSoftmaxOp should be 1.");
-
-    std::cout << DebugString() << std::endl;
-
-    ctx.Output<LoDTensor>("Out")->Resize({dims});
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceSoftmaxOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceSoftmaxOp should not be null.");
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
 
 class SequenceSoftmaxOpMaker : public
framework::OpProtoAndCheckerMaker {
  public:
-  SequenceSoftmaxOpMaker(framework::OpProto *proto,
-                         framework::OpAttrChecker *op_checker)
+  SequenceSoftmaxOpMaker(framework::OpProto* proto,
+                         framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "(LoDTensor)");
-    AddOutput("Out", "(LoDTensor)");
+    AddInput("X",
+             "(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
+             "of length 1.");
+    AddOutput("Out",
+              "(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd "
+              "dimension of length 1.");
     AddComment(R"DOC(
-Softmax of Sequence.
+SequenceSoftmaxOp computes softmax activation among all time-steps for each
+sequence. The dimension of each time-step should be 1. Thus, the shape of
+input Tensor can be either [N, 1] or [N], where N is the sum of all sequences'
+lengths.
+
+Equation:
+    for i-th sequence in a mini-batch:
+        Out(X[lod[i]:lod[i+1]], :) =
+            exp(X[lod[i]:lod[i+1], :]) / sum(exp(X[lod[i]:lod[i+1], :]))
+
+For example, for a mini-batch of 3 sequences with variable-length,
+each containing 2, 3, 2 time-steps, the lod of which is [0, 2, 5, 7],
+then softmax will be computed among X[0:2, :], X[2:5, :], X[5:7, :]
+and N turns out to be 7.
 )DOC");
   }
 };
@@ -66,7 +67,25 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {}
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Out"),
+                   "Input(Out) of SequenceSoftmaxGradOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput(framework::GradVarName("Out")),
+        "Input(Out@GRAD) of SequenceSoftmaxGradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceSoftmaxOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) of SequenceSoftmaxOp should not be null.");
+
+    PADDLE_ENFORCE_EQ(
+        ctx->GetInputDim("Out"),
+        ctx->GetInputDim(framework::GradVarName("Out")),
+        "Input(Out) and Input(Out@GRAD) of SequenceSoftmaxGradOp should be of "
+        "the same shape.");
+
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
 };
 
 }  // namespace operators
@@ -81,4 +100,4 @@ REGISTER_OP_CPU_KERNEL(
     ops::SequenceSoftmaxKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     sequence_softmax_grad,
-    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUPlace, float>);
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sequence_softmax_op.h b/paddle/operators/sequence_softmax_op.h
index 54d8265271..ca5cef4fc6 100644
--- a/paddle/operators/sequence_softmax_op.h
+++ b/paddle/operators/sequence_softmax_op.h
@@ -16,19 +16,13 @@ limitations under the License.
*/
 
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/softmax_function.h"
+#include "paddle/operators/math/softmax.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
 class SequenceSoftmaxKernel : public framework::OpKernel {
  public:
@@ -38,7 +32,17 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
     auto* out = ctx.Output<LoDTensor>("Out");
 
     auto lod = x->lod();
+    auto dims = x->dims();
+
+    PADDLE_ENFORCE_GE(
+        dims[0],
+        /* batch_size */ static_cast<int64_t>(lod[0].size() - 1),
+        "The first dimension of Input(X) should be larger than batch size.");
+
     const size_t level = lod.size() - 1;
+    PADDLE_ENFORCE_EQ(x->numel(), static_cast<int64_t>(lod[level].back()),
+                      "The width of each timestep in Input(X) of "
+                      "SequenceSoftmaxOp should be 1.");
 
     out->mutable_data<T>(ctx.GetPlace());
     for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
@@ -48,10 +52,10 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
       Tensor out_i = out->Slice<T>(start_pos, end_pos);
 
       // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
-      framework::DDim dims = framework::make_ddim({1UL, end_pos - start_pos});
-      x_i.Resize(dims);
-      out_i.Resize(dims);
-      math::SoftmaxFunctor<Place, T>()(&x_i, &out_i, ctx);
+      framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
+      x_i.Resize(dims_i);
+      out_i.Resize(dims_i);
+      math::SoftmaxFunctor<Place, T>()(ctx, &x_i, &out_i);
     }
   }
 };
@@ -59,7 +63,32 @@ class SequenceSoftmaxKernel : public framework::OpKernel {
 template <typename Place, typename T>
 class SequenceSoftmaxGradKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {}
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out = ctx.Input<LoDTensor>("Out");
+    auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+
+    auto lod = x->lod();
+    const size_t level = lod.size() - 1;
+
+    x_grad->mutable_data<T>(ctx.GetPlace());
+    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+
+      Tensor out_i = out->Slice<T>(start_pos, end_pos);
+      Tensor out_grad_i = out_grad->Slice<T>(start_pos, end_pos);
+      Tensor x_grad_i = x_grad->Slice<T>(start_pos, end_pos);
+
+      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
+      framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
+      out_i.Resize(dims_i);
+      out_grad_i.Resize(dims_i);
+      x_grad_i.Resize(dims_i);
+      math::SoftmaxGradFunctor<Place, T>()(ctx, &out_i, &out_grad_i, &x_grad_i);
+    }
+  }
 };
 
 }  // namespace operators
diff --git a/python/paddle/v2/framework/tests/test_sequence_softmax_op.py b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py
index d0667c1308..b54a56aa6d 100644
--- a/python/paddle/v2/framework/tests/test_sequence_softmax_op.py
+++ b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py
@@ -5,7 +5,7 @@ from op_test import OpTest
 
 def stable_softmax(x):
     """Compute the softmax of vector x in a numerically stable way."""
-    shiftx = x - np.max(x)
+    shiftx = (x - np.max(x)).clip(-64.)
exps = np.exp(shiftx) return exps / np.sum(exps) @@ -30,6 +30,9 @@ class TestSequenceSoftmaxOp(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(["X"], "Out", max_relative_error=0.01) + if __name__ == "__main__": unittest.main()
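
Note on the new gradient test: check_grad compares the operator's
analytic gradient with a numeric estimate and enforces the given
max_relative_error. A minimal sketch of that style of check, assuming
a central-difference estimate (the helper below is illustrative, not
the actual op_test internals):

    import numpy as np

    def numeric_grad(f, x, eps=1e-3):
        # Central-difference gradient of a scalar function f at x.
        grad = np.zeros_like(x)
        for idx in np.ndindex(*x.shape):
            orig = x[idx]
            x[idx] = orig + eps
            plus = f(x)
            x[idx] = orig - eps
            minus = f(x)
            x[idx] = orig                  # restore the entry
            grad[idx] = (plus - minus) / (2 * eps)
        return grad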