From 695b10377e3905f8ac519668e005b8deaa8f2ed9 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Wed, 15 Nov 2017 20:05:08 +0800
Subject: [PATCH 1/8] port hsigmoid layer

---
 paddle/operators/hierarchical_sigmoid_op.cc | 121 ++++++++++++++++++++
 paddle/operators/hierarchical_sigmoid_op.h  |  35 ++++++
 paddle/operators/math/CMakeLists.txt        |   1 +
 paddle/operators/math/matrix_bit_code.cc    |  84 ++++++++++++++
 paddle/operators/math/matrix_bit_code.h     |  64 +++++++++++
 5 files changed, 305 insertions(+)
 create mode 100644 paddle/operators/hierarchical_sigmoid_op.cc
 create mode 100644 paddle/operators/hierarchical_sigmoid_op.h
 create mode 100644 paddle/operators/math/matrix_bit_code.cc
 create mode 100644 paddle/operators/math/matrix_bit_code.h

diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc
new file mode 100644
index 0000000000..1f77ff1268
--- /dev/null
+++ b/paddle/operators/hierarchical_sigmoid_op.cc
@@ -0,0 +1,121 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "hierarchical_sigmoid_op.h"
+
+namespace paddle {
+namespace operators {
+
+/**
+ * Organize the classes into a binary tree. At each node, a sigmoid function
+ * is used to calculate the probability of belonging to the right branch.
+ * This idea is from "F. Morin, Y. Bengio (AISTATS 05):
+ * Hierarchical Probabilistic Neural Network Language Model."
+ *
+ * Here we use a simple way of making the binary tree.
+ * Assuming the number of classes C = 6,
+ * the classes are organized as a binary tree in the following way:
+ *
+ * @code{.py}
+ * *-*-*- 2
+ * | | |- 3
+ * | |
+ * | |-*- 4
+ * |   |- 5
+ * |
+ * |-*- 0
+ *   |- 1
+ * @endcode
+ *
+ * where * indicates an internal node, and each leaf node represents a class.
+ * - Node 0 ... C-2 are internal nodes.
+ * - Node C-1 ... 2C-2 are leaf nodes.
+ * - Class c is represented by leaf node \f$c+C-1\f$.
+ *
+ * We assign an id to each node:
+ * - the id of the root is 0.
+ * - the left child of a node i is 2*i+1.
+ * - the right child of a node i is 2*i+2.
+ *
+ * It's easy to see that:
+ * - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
+ * - the j-th level ancestor of node i is
+ *   \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
+ * - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
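+ *
+ * As a quick check of these identities, a minimal sketch (the helper names
+ * are illustrative only and are not part of this operator):
+ *
+ * @code{.cpp}
+ * size_t parent(size_t i) { return (i - 1) / 2; }     // parent of node i > 0
+ * size_t left_child(size_t i) { return 2 * i + 1; }
+ * size_t right_child(size_t i) { return 2 * i + 2; }
+ * // With C = 6: left_child(0) == 1, right_child(1) == 4, parent(4) == 1,
+ * // and (4 - 1) % 2 == 1, so node 4 is a right child.
+ * @endcode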
+ * + */ + +class HierarchicalSigmoidOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); + const int64_t batch_size = ctx->GetInputsDim("X")[0][0]; + const int64_t size = ctx->GetInputsDim("X").size(); + std::vector output_shape({batch_size, size}); + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; + +class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override {} +}; + +class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { + public: + HierarchicalSigmoidOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(TensorArray, required) The input array. Each Tensor has the " + "same shape with [N * D]." + .AsDuplicable(); + AddInput("Label", + "(Tensor, required), The labels of training data. It's a" + "1-D tensor."); + AddInput("Bias", + "(Tensor, optional), The bias is a 1-D tensor, " + "which is applied to the output"); + AddOutput("Out", + "(Tensor, required) The output of hierarchical sigmoid operator."); + AddAttr("num_classes", + "(int, required)", + "The number of classes"); + AddComment(R"DOC( +The hierarchical sigmoid operator organize the classes into a binary tree. +At each node, a sigmoid function is used to caculate the probability of +belonging to the right branch. This idea is from +"F. Morin, Y. Bengio (AISTATS 05): +Hierarchical Probabilistic Neural Network Language Model." + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, + ops::HierarchicalSigmoidOpMaker, hierarchical_sigmoid_grad, + ops::HierarchicalSigmoidGradOp); +REGISTER_OP_CPU_KERNEL( + hierarchical_sigmoid, + ops::HierarchicalSigmoidOpKernel); +REGISTER_OP_CPU_KERNEL( + hierarchical_sigmoid_grad, + ops::HierarchicalSigmoidGradOpKernel); diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h new file mode 100644 index 0000000000..8a753605d6 --- /dev/null +++ b/paddle/operators/hierarchical_sigmoid_op.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/matrix_bit_code.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {}
+};
+
+template <typename Place, typename T>
+class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index ab7f23f570..cc96b27c25 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -26,6 +26,7 @@ else()
     cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
     cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
     cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
+    cc_library(matrix_bit_code SRCS matrix_bit_code.cc)
 endif()
 
 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc
new file mode 100644
index 0000000000..3f1dbbf399
--- /dev/null
+++ b/paddle/operators/math/matrix_bit_code.cc
@@ -0,0 +1,84 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "matrix_bit_code.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/**
+ * CodeTable class should support 3 functions:
+ *
+ * size_t size()
+ *   return the number of codes
+ *
+ * int getMaxCodeLength()
+ *   return the maximal code length
+ *
+ * Code operator()(size_t i)
+ *   return the i-th code. Code class is described below.
+ *
+ * Code class should support 3 functions:
+ *
+ * int getLength()
+ *   return the length of the code
+ *
+ * size_t calcIndex(int bit)
+ *   bit ranges from 0 to getLength() - 1
+ *   return the index of the (1+bit)-level parent
+ *
+ * bool calcBit(int bit)
+ *   return true if the bit-level parent is the right child of the
+ *   (1+bit)-level parent
+ *
+ */
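+
+/*
+ * A worked example of this interface as implemented by SimpleCode in
+ * matrix_bit_code.h (illustration only): with num_classes = 6 and class 3,
+ * the stored value is c_ = 3 + 6 = 9 (binary 1001), so
+ *
+ *   get_length()  == FindLastSet(9) - 1 == 3
+ *   calc_index(0) == (9 >> 1) - 1 == 3   // lowest internal node on the path
+ *   calc_index(1) == (9 >> 2) - 1 == 1
+ *   calc_index(2) == (9 >> 3) - 1 == 0   // the root
+ *   calc_bit(0)   == (9 & 1) != 0, i.e. true: a right child at that level.
+ */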
+
+/*
+  for i:
+    for j < codeLength:
+      op(a(i, j), b(0, index(i, j)))
+*/
+template <class CodeTable, class Op, class T>
+static void AddByBitCodeT(Op op, CodeTable code_table,
+                          const framework::Tensor& codes, framework::Tensor& a,
+                          framework::Tensor& b) {
+  size_t num_classes = code_table.size();
+  size_t max_code_length = code_table.get_max_code_length();
+  size_t num_sample = a.dims()[0].size();
+  size_t width = a.dims()[1].size();
+
+  for (size_t i = 0; i < num_sample; ++i) {
+    auto code = code_table(codes.data()[i]) int code_length =
+        code.get_length();
+    for (int j = 0; j < code_length; + j) {
+      size_t index = code.calc_index(j);
+      op(a.data()[i * width + j], b.data()[index]);
+    }
+  }
+}
+
+/* For j < codeLength:
+   a(i, j) += b(0, index(i, j))
+*/
+template <class T>
+void AddByBitCode(size_t num_classes, const framework::Tensor& codes,
+                  framework::Tensor& a, const framework::Tensor& b) {
+  auto op = [](T& t, T& v) { t += v; };
+  AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, a, b);
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h
new file mode 100644
index 0000000000..a0dd89ebe0
--- /dev/null
+++ b/paddle/operators/math/matrix_bit_code.h
@@ -0,0 +1,64 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/tensor.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/**
+ * return the 1-based index of the highest bit set
+ *
+ * for x > 0:
+ * \f[
+ *    findLastSet(x) = 1 + \left\lfloor \log_{2} x \right\rfloor
+ * \f]
+ */
+inline constexpr size_t FindLastSet(size_t x) {
+  return std::is_same<size_t, unsigned int>::value
+             ? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
+             : (std::is_same<size_t, unsigned long>::value  // NOLINT
+                    ? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
+                    : (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
+}
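+
+/*
+ * For example (illustration only): FindLastSet(1) == 1, FindLastSet(5) == 3,
+ * and FindLastSet(9) == 4, since 9 is 1001 in binary. Hence a table over
+ * num_classes = 6 has get_max_code_length() == FindLastSet(5) == 3.
+ */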
+
+struct SimpleCode {
+  SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
+  inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
+  inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
+  inline int get_length() const { return FindLastSet(c_) - 1; }
+
+ private:
+  size_t c_;
+};
+
+struct SimpleCodeTable {
+  explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
+  SimpleCode operator()(size_t code) const {
+    return SimpleCode(code, num_classes_);
+  }
+  size_t size() const { return num_classes_; }
+  int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
+
+ private:
+  size_t num_classes_;
+  int max_code_length_;
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle

From a25c3aeba6d1a370e4b361f26f0e112fd00e7c4e Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Fri, 17 Nov 2017 10:31:44 +0800
Subject: [PATCH 2/8] add forward

---
 paddle/operators/hierarchical_sigmoid_op.cc | 13 ++++++-------
 paddle/operators/hierarchical_sigmoid_op.h  | 16 +++++++++++++++-
 paddle/operators/math/matrix_bit_code.cc    |  6 +++---
 paddle/operators/math/matrix_bit_code.h     |  4 ++++
 4 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc
index 1f77ff1268..9b7af92662 100644
--- a/paddle/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/operators/hierarchical_sigmoid_op.cc
@@ -83,19 +83,18 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
              "(TensorArray, required) The input array. Each Tensor has the "
-             "same shape with [N * D]."
-        .AsDuplicable();
+             "same shape with [N * D].")
+        .AsDuplicable();
     AddInput("Label",
              "(Tensor, required), The labels of training data. It's a"
              "1-D tensor.");
     AddInput("Bias",
              "(Tensor, optional), The bias is a 1-D tensor, "
             "which is applied to the output");
-    AddOutput("Out",
-        "(Tensor, required) The output of hierarchical sigmoid operator.");
-    AddAttr("num_classes",
-        "(int, required)",
-        "The number of classes");
+    AddOutput(
+        "Out",
+        "(Tensor, required) The output of hierarchical sigmoid operator.");
+    AddAttr("num_classes", "(int, required)", "The number of classes");
     AddComment(R"DOC(
 The hierarchical sigmoid operator organize the classes into a binary tree.
At each node, a sigmoid function is used to caculate the probability of diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h index 8a753605d6..11a553a403 100644 --- a/paddle/operators/hierarchical_sigmoid_op.h +++ b/paddle/operators/hierarchical_sigmoid_op.h @@ -22,7 +22,21 @@ template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override {} + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + auto* label = ctx.Input("Label"); + auto* bias = ctx.Input("Bias"); + size_t num_classes = static_cast(ctx.Attr("num_classes")); + int64_t batch_size = ins[0]->dims()[0]; + int64_t size = ins.size(); + framework::Tensor pre_out; + std::vector pre_out_dims({batch_size, size}); + pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace()); + + if (bias != NULL) { + math::AddByBitCode(num_classes, *label, pre_out, *bias); + } + } }; template diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc index 3f1dbbf399..30c2ffc2cf 100644 --- a/paddle/operators/math/matrix_bit_code.cc +++ b/paddle/operators/math/matrix_bit_code.cc @@ -50,7 +50,7 @@ namespace math { for j < codeLength: op(a(i, j), b(0, index(i, j))) */ -template +template static void AddByBitCodeT(Op op, CodeTable code_table, const framework::Tensor& codes, framework::Tensor& a, framework::Tensor& b) { @@ -72,11 +72,11 @@ static void AddByBitCodeT(Op op, CodeTable code_table, /* For j < codeLength: a(i, j) += b(0, index(i, j)) */ -template +template void AddByBitCode(size_t num_classes, const framework::Tensor& codes, framework::Tensor& a, const framework::Tensor& b) { auto op = [](T& t, T& v) { t += v; }; - AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, a, b); + AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, a, b); } } // namespace math diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h index a0dd89ebe0..bb0599aa17 100644 --- a/paddle/operators/math/matrix_bit_code.h +++ b/paddle/operators/math/matrix_bit_code.h @@ -59,6 +59,10 @@ struct SimpleCodeTable { int max_code_length_; }; +template +void AddByBitCode(size_t num_classes, const framework::Tensor& codes, + framework::Tensor& a, const framework::Tensor& b); + } // namespace math } // namespace operators } // namespace paddle From 1abd3b3a29b6964323d679d47dea31830f5b5e6a Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 27 Nov 2017 19:28:57 +0800 Subject: [PATCH 3/8] implement forward --- paddle/operators/CMakeLists.txt | 4 +- paddle/operators/hierarchical_sigmoid_op.cc | 6 +- paddle/operators/hierarchical_sigmoid_op.h | 39 ++++++++++- paddle/operators/math/math_function.cc | 2 + paddle/operators/math/math_function.cu | 2 + paddle/operators/math/math_function.h | 6 ++ paddle/operators/math/math_function_impl.h | 14 ++++ paddle/operators/math/matrix_bit_code.cc | 77 +++++++++++++++++++-- paddle/operators/math/matrix_bit_code.h | 19 ++++- 9 files changed, 157 insertions(+), 12 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a719da2560..93ec763424 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -185,7 +185,8 @@ set(DEPS_OPS tensor_array_read_write_op gru_op adagrad_op - sgd_op) + sgd_op + hierarchical_sigmoid_op) op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) @@ -203,6 
+204,7 @@ op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table) op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op) op_library(array_to_lod_tensor_op SRCS array_to_lod_tensor_op.cc DEPS lod_rank_table_op) op_library(tensor_array_read_write_op SRCS tensor_array_read_write_op.cc) +op_library(hierarchical_sigmoid_op DEPS matrix_bit_code) if(WITH_GPU) op_library(nccl_op DEPS nccl_common) endif() diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc index 9b7af92662..f81f3d34d1 100644 --- a/paddle/operators/hierarchical_sigmoid_op.cc +++ b/paddle/operators/hierarchical_sigmoid_op.cc @@ -85,12 +85,16 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { "(TensorArray, required) The input array. Each Tensor has the " "same shape with [N * D].") .AsDuplicable(); + AddInput("Parameters", + "(Tensor, required), The parameters of hierarchical " + "sigmoid operator, each of them is s a 2-D tensor.") + .AsDuplicable(); AddInput("Label", "(Tensor, required), The labels of training data. It's a" "1-D tensor."); AddInput("Bias", "(Tensor, optional), The bias is a 1-D tensor, " - "which is applied to the output"); + "which is applied to the output."); AddOutput( "Out", "(Tensor, required) The output of hierarchical sigmoid operator."); diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h index 11a553a403..baf655f214 100644 --- a/paddle/operators/hierarchical_sigmoid_op.h +++ b/paddle/operators/hierarchical_sigmoid_op.h @@ -14,28 +14,61 @@ limitations under the License. */ #pragma once #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/matrix_bit_code.h" namespace paddle { namespace operators { -template +template +using EigenMatrix = framework::EigenMatrix; + +template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto ins = ctx.MultiInput("X"); + auto params = ctx.MultiInput("Parameters"); auto* label = ctx.Input("Label"); auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); size_t num_classes = static_cast(ctx.Attr("num_classes")); + + framework::Tensor sum; + framework::Tensor pre_out; + auto place = ctx.GetEigenDevice(); + auto& device_ctx = ctx.device_context(); + math::ColwiseSum col_sum; + math::RowwiseSum row_sum; + + auto pre_out_mat = EigenMatrix::From(pre_out); int64_t batch_size = ins[0]->dims()[0]; int64_t size = ins.size(); - framework::Tensor pre_out; + std::vector pre_out_dims({batch_size, size}); pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace()); + std::vector sum_dims({batch_size, 1UL}); + sum.mutable_data(framework::make_ddim(sum_dims), ctx.GetPlace()); + out->mutable_data(ctx.GetPlace()); - if (bias != NULL) { + if (bias) { math::AddByBitCode(num_classes, *label, pre_out, *bias); } + + for (size_t i = 0; i < ins.size(); ++i) { + math::MulByBitCode(num_classes, *label, pre_out, *params[i], *ins[i]); + } + // clip the matrix with (-40, 40) + pre_out_mat.device(place) = + pre_out_mat.abs().cwiseMax(static_cast(40.0)); + math::SumByBitCode(num_classes, *label, *out, pre_out, + static_cast(-1)); + // softrelu + pre_out_mat.device(place) = (static_cast(1) + pre_out_mat.exp()).log(); + + row_sum(device_ctx, pre_out, &sum); + col_sum(device_ctx, *out, &sum); } }; diff --git a/paddle/operators/math/math_function.cc 
b/paddle/operators/math/math_function.cc index 2e333a8cde..3bc0945fe3 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -314,6 +314,8 @@ template struct RowwiseAdd; template struct RowwiseAdd; template struct ColwiseSum; template struct ColwiseSum; +template struct RowwiseSum; +template struct RowwiseSum; } // namespace math } // namespace operators diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 58356a4b77..1a226821f7 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -298,6 +298,8 @@ template struct RowwiseAdd; template struct RowwiseAdd; template struct ColwiseSum; template struct ColwiseSum; +template struct RowwiseSum; +template struct RowwiseSum; } // namespace math } // namespace operators diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index ffb99f5380..c21a20fc32 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -130,6 +130,12 @@ struct ColwiseSum { const framework::Tensor& input, framework::Tensor* vec); }; +template +struct RowwiseSum { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor* vec); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h index 4dc17a4e52..8c1971fc61 100644 --- a/paddle/operators/math/math_function_impl.h +++ b/paddle/operators/math/math_function_impl.h @@ -78,6 +78,20 @@ void ColwiseSum::operator()(const platform::DeviceContext& context, in.sum(Eigen::array({{0}})).reshape(shape); } +template +void RowwiseSum::operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor* vector) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[1]; + PADDLE_ENFORCE_EQ(vector->numel(), size); + + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenMatrix::From(*vector); + Eigen::array shape({{static_cast(size), 1}}); + vec.reshape(shape).device(*context.GetEigenDevice()) = + in.sum(Eigen::array({{0}})).reshape(shape); +} } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc index 30c2ffc2cf..8f68e2f79d 100644 --- a/paddle/operators/math/matrix_bit_code.cc +++ b/paddle/operators/math/matrix_bit_code.cc @@ -53,18 +53,18 @@ namespace math { template static void AddByBitCodeT(Op op, CodeTable code_table, const framework::Tensor& codes, framework::Tensor& a, - framework::Tensor& b) { + const framework::Tensor& b) { size_t num_classes = code_table.size(); size_t max_code_length = code_table.get_max_code_length(); - size_t num_sample = a.dims()[0].size(); - size_t width = a.dims()[1].size(); + size_t num_sample = a.dims()[0]; + size_t width = a.dims()[1]; for (size_t i = 0; i < num_sample; ++i) { - auto code = code_table(codes.data()[i]) int code_length = - code.get_length(); + auto code = code_table(codes.data()[i]); + int code_length = code.get_length(); for (int j = 0; j < code_length; + j) { size_t index = code.calc_index(j); - op(a.data()[i * width + j], b.data()[index]); + op(a.data()[i * width + j], b.data()[index]); } } } @@ -79,6 +79,71 @@ void AddByBitCode(size_t num_classes, const framework::Tensor& codes, AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, 
a, b); } +template +void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes, + framework::Tensor& tmat, framework::Tensor& sum, + const T& scale_sum) { + size_t max_code_length = code_table.get_max_code_length(); + size_t num_samples = tmat.dims()[0]; + size_t o_width = tmat.dims()[1]; + for (size_t i = 0; i < num_samples; ++i) { + T sm = 0; + auto code = code_table(codes.data()[i]); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + sm += tmat.data()[i * o_width + j]; + } + } + sum.data()[i] = scale_sum * sm; + } +} +/* For j < codeLength: + sum(i, 0) = \sum_j bit(i, j) * input(i, j) +*/ +template +void SumByBitCode(size_t num_classes, const framework::Tensor& codes, + framework::Tensor& tmat, framework::Tensor& sum, + T scale_sum) { + SumByBitCodeT(SimpleCodeTable(num_classes), codes, tmat, scale_sum); +} + +template +void MulByBitCodeT(Op op, CodeTable code_table, const framework::Tensor& codes, + framework::Tensor& tmat, framework::Tensor& weight, + framework::Tensor& input) { + size_t num_classes = code_table.size(); + size_t max_code_length = code_table.get_max_code_length(); + size_t num_samples = tmat.dims()[0]; + size_t input_dim = input.dims()[1]; + size_t o_width = tmat.dims()[1]; + + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(codes.data()[i]); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); + op(tmat.data()[i * o_width + j], + weight.data() + index * weight.dims()[1], + input.data() + i * input.dims()[1], input_dim); + } + } +} + +template +void MulByBitCode(size_t num_classes, const framework::Tensor& codes, + framework::Tensor& tmat, const framework::Tensor& weight, + const framework::Tensor& input) { + auto op = [](T& t, const T* weight_row, const T* input_row, + size_t input_dim) { + T sum = 0; + for (size_t k = 0; k < input_dim; ++k) { + sum += weight_row[k] * input_row[k]; + } + t += sum; + }; + MulByBitCode(op, SimpleCodeTable(num_classes), codes, tmat, weight, input); +} } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h index bb0599aa17..7bef5077b9 100644 --- a/paddle/operators/math/matrix_bit_code.h +++ b/paddle/operators/math/matrix_bit_code.h @@ -59,10 +59,27 @@ struct SimpleCodeTable { int max_code_length_; }; +/* For j < codeLength + tmat(i, j) += vec(0, index(i, j)) +*/ template void AddByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& a, const framework::Tensor& b); + framework::Tensor& tmat, const framework::Tensor& vec); +/* For j < codeLength + sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) +*/ +template +void SumByBitCode(size_t num_classes, const framework::Tensor& codes, + framework::Tensor& tmat, framework::Tensor& sum, T scale_sum); + +/* For j < codeLength + input.row(i) += tmat(i, j) * weight.row(index(i, j)) +*/ +template +void MulByBitCode(size_t num_classes, const framework::Tensor& codes, + framework::Tensor& tmat, const framework::Tensor& weight, + const framework::Tensor& input); } // namespace math } // namespace operators } // namespace paddle From 1f9426fd47d4cc3911e9f0a2f23274d69dd104e8 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Wed, 29 Nov 2017 20:17:22 +0800 Subject: [PATCH 4/8] add backward --- paddle/operators/hierarchical_sigmoid_op.h | 52 ++++++++++++++-- paddle/operators/math/matrix_bit_code.cc | 71 +++++++++++++++++++--- 
paddle/operators/math/matrix_bit_code.h | 36 ++++++++++- 3 files changed, 144 insertions(+), 15 deletions(-) diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h index baf655f214..186c767932 100644 --- a/paddle/operators/hierarchical_sigmoid_op.h +++ b/paddle/operators/hierarchical_sigmoid_op.h @@ -44,9 +44,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { auto pre_out_mat = EigenMatrix::From(pre_out); int64_t batch_size = ins[0]->dims()[0]; - int64_t size = ins.size(); + int64_t code_length = math::FindLastSet(num_classes - 1); - std::vector pre_out_dims({batch_size, size}); + std::vector pre_out_dims({batch_size, code_length}); pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace()); std::vector sum_dims({batch_size, 1UL}); sum.mutable_data(framework::make_ddim(sum_dims), ctx.GetPlace()); @@ -64,8 +64,11 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { pre_out_mat.abs().cwiseMax(static_cast(40.0)); math::SumByBitCode(num_classes, *label, *out, pre_out, static_cast(-1)); - // softrelu - pre_out_mat.device(place) = (static_cast(1) + pre_out_mat.exp()).log(); + + // softrelu with threshold is 40.0 + pre_out_mat.device(place) = + pre_out_mat.abs().cwiseMax(static_cast(40.0)); + pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); row_sum(device_ctx, pre_out, &sum); col_sum(device_ctx, *out, &sum); @@ -75,7 +78,46 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { template class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override {} + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + auto ins_grad = + ctx.MultiOutput(framework::GradVarName("X")); + auto params = ctx.MultiOutput( + framework::GradVarName("Parameters")); + auto* bias = ctx.Output(framework::GradVarName("Bias")); + auto* label = + ctx.Output(framework::GradVarName("Label")); + size_t num_classes = static_cast(ctx.Attr("num_classes")); + + framework::Tensor pre_out; + auto place = ctx.GetEigenDevice(); + auto& dev_ctx = ctx.device_context(); + int64_t batch_size = ins_grad.size(); + int64_t code_length = math::FindLastSet(num_classes - 1); + auto pre_out_mat = EigenMatrix::From(pre_out); + + // init pre_out matrix with {1.0} + std::vector pre_out_dims({batch_size, code_length}); + pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace()); + math::SetConstant set; + set(dev_ctx, &pre_out, static_cast(1.0)); + // softrelu derivative + pre_out_mat.device(place) = + pre_out_mat * (static_cast(1.0) - static_cast(1.0) / pre_out_mat); + + math::SubByBitCode(num_classes, *label, pre_out); + + if (bias) { + math::AddByBitCodeGrad(num_classes, *label, pre_out, *bias); + } + + for (size_t i = 0; i < ins_grad.size(); ++i) { + math::MulByBitCodeGradWeight(num_classes, *label, pre_out, *params[i], + *ins[i]); + math::MulByBitCodeGradError(num_classes, *label, pre_out, *params[i], + *ins_grad[i]); + } + } }; } // namespace operators diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc index 8f68e2f79d..996e0b819f 100644 --- a/paddle/operators/math/matrix_bit_code.cc +++ b/paddle/operators/math/matrix_bit_code.cc @@ -69,19 +69,23 @@ static void AddByBitCodeT(Op op, CodeTable code_table, } } -/* For j < codeLength: - a(i, j) += b(0, index(i, j)) -*/ template void AddByBitCode(size_t num_classes, const 
framework::Tensor& codes, - framework::Tensor& a, const framework::Tensor& b) { + framework::Tensor& tmat, const framework::Tensor& vec) { auto op = [](T& t, T& v) { t += v; }; - AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, a, b); + AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, vec); +} + +template +void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes, + const framework::Tensor& tmat, framework::Tensor& vec) { + auto op = [](T& t, T& v) { v += t; }; + AddByBitCode(op, SimpleCodeTable(num_classes), codes, tmat, vec); } template void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes, - framework::Tensor& tmat, framework::Tensor& sum, + framework::Tensor& tmat, const framework::Tensor& sum, const T& scale_sum) { size_t max_code_length = code_table.get_max_code_length(); size_t num_samples = tmat.dims()[0]; @@ -142,8 +146,61 @@ void MulByBitCode(size_t num_classes, const framework::Tensor& codes, } t += sum; }; - MulByBitCode(op, SimpleCodeTable(num_classes), codes, tmat, weight, input); + MulByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, weight, + input); +} + +template +void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes, + const framework::Tensor& tmat, + framework::Tensor& weight, + const framework::Tensor& input) { + auto op = [](const T t, T* weight_row, const T* input_row, size_t input_dim) { + for (size_t k = 0; k < input_dim; ++k) { + weight_row[k] += t * input_row[k]; + } + }; + MulByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, weight, + input); } + +template +void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes, + const framework::Tensor& tmat, + const framework::Tensor& weight, + framework::Tensor& input) { + auto op = [](const T t, const T* weight_row, T* input_row, size_t input_dim) { + for (size_t k = 0; k < input_dim; ++k) { + input_row[k] += t * weight_row[k]; + } + }; + MulByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, weight, + input); +} + +template +void SubByBitCodeT(CodeTable code_table, const framework::Tensor& codes, + framework::Tensor& tmat) { + size_t max_code_length = code_table.get_max_code_length(); + size_t num_samples = tmat.dims()[0]; + size_t o_width = tmat.dims()[1]; + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(codes.data()[i]); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + tmat.data()[i * o_width + j] -= 1; + } + } + } +} + +template +void SubByBitCode(size_t num_classes, const framework::Tensor& codes, + framework::Tensor& tmat) { + SubByBitCodeT(SimpleCodeTable(num_classes), codes, tmat); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h index 7bef5077b9..43c9d43d89 100644 --- a/paddle/operators/math/matrix_bit_code.h +++ b/paddle/operators/math/matrix_bit_code.h @@ -59,27 +59,57 @@ struct SimpleCodeTable { int max_code_length_; }; -/* For j < codeLength +/* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ template void AddByBitCode(size_t num_classes, const framework::Tensor& codes, framework::Tensor& tmat, const framework::Tensor& vec); -/* For j < codeLength +/* For j < code_length + vec(0, index(i, j)) += tmat(i, j) +*/ +template +void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes, + const framework::Tensor& tmat, framework::Tensor& vec); +/* For j < code_length sum(i, 0) = 
\sum_j bit(i, j) * tmat(i, j) */ template void SumByBitCode(size_t num_classes, const framework::Tensor& codes, framework::Tensor& tmat, framework::Tensor& sum, T scale_sum); -/* For j < codeLength +/* For j < code_length input.row(i) += tmat(i, j) * weight.row(index(i, j)) */ template void MulByBitCode(size_t num_classes, const framework::Tensor& codes, framework::Tensor& tmat, const framework::Tensor& weight, const framework::Tensor& input); + +/* For index(i, j) >= 0: + weight.row(index(i, j)) += tmat(i, j) * input.row(i) +*/ +template +void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes, + const framework::Tensor& tmat, + framework::Tensor& weight, + const framework::Tensor& input); +/* For j < code_length + input.row(i) += tmat(i, j) * weight.row(index(i, j)) +*/ +template +void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes, + const framework::Tensor& tmat, + const framework::Tensor& weight, + framework::Tensor& input); + +/* For j < code_length + tmat(i, j) -= bit(i, j) +*/ +template +void SubByBitCode(size_t num_classes, const framework::Tensor& codes, + framework::Tensor& tmat); } // namespace math } // namespace operators } // namespace paddle From f8395631e12e19d433ce4d5b6ddcefc0b04db6e1 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 14 Dec 2017 13:12:00 +0800 Subject: [PATCH 5/8] fix invalid dims --- paddle/operators/hierarchical_sigmoid_op.cc | 28 +++++++-------- paddle/operators/hierarchical_sigmoid_op.h | 26 +++++++------- paddle/operators/math/matrix_bit_code.cc | 11 +++--- .../paddle/v2/fluid/tests/test_hsigmoid_op.py | 34 +++++++++++++++++++ 4 files changed, 67 insertions(+), 32 deletions(-) create mode 100644 python/paddle/v2/fluid/tests/test_hsigmoid_op.py diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc index f81f3d34d1..063f8576e6 100644 --- a/paddle/operators/hierarchical_sigmoid_op.cc +++ b/paddle/operators/hierarchical_sigmoid_op.cc @@ -60,12 +60,11 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null."); + PADDLE_ENFORCE(ctx->hasInput("X"), "Input(X) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); - const int64_t batch_size = ctx->GetInputsDim("X")[0][0]; - const int64_t size = ctx->GetInputsDim("X").size(); - std::vector output_shape({batch_size, size}); + const int64_t batch_size = ctx->GetInputDim("X")[0]; + std::vector output_shape({batch_size, num_classes_ - 1}); ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); } }; @@ -82,22 +81,23 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "(TensorArray, required) The input array. 
Each Tensor has the " - "same shape with [N * D].") - .AsDuplicable(); + "(Tensor, required) The input Tensor, which the shape is" + "[N * D], which N is the size of mini-batch," + "D is the embded size"); AddInput("Parameters", "(Tensor, required), The parameters of hierarchical " - "sigmoid operator, each of them is s a 2-D tensor.") - .AsDuplicable(); + "sigmoid operator, each of them is s a 3-D tensor, the shape is" + "[N, num_classes - 1, D]"); AddInput("Label", "(Tensor, required), The labels of training data. It's a" - "1-D tensor."); + "1-D tensor, which the shape is [1, N]"); AddInput("Bias", "(Tensor, optional), The bias is a 1-D tensor, " - "which is applied to the output."); - AddOutput( - "Out", - "(Tensor, required) The output of hierarchical sigmoid operator."); + "which is applied to the output, the shape is" + "[1, num_classes -1]"); + AddOutput("Out", + "(Tensor, required) The output of hierarchical sigmoid operator." + "the shape is [N, 1]"); AddAttr("num_classes", "(int, required)", "The number of classes"); AddComment(R"DOC( The hierarchical sigmoid operator organize the classes into a binary tree. diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h index 186c767932..e3f0bcacd8 100644 --- a/paddle/operators/hierarchical_sigmoid_op.h +++ b/paddle/operators/hierarchical_sigmoid_op.h @@ -28,8 +28,8 @@ template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto ins = ctx.MultiInput("X"); - auto params = ctx.MultiInput("Parameters"); + auto* in = ctx.Input("X"); + auto* param = ctx.Input("Parameter"); auto* label = ctx.Input("Label"); auto* bias = ctx.Input("Bias"); auto* out = ctx.Output("Out"); @@ -56,8 +56,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { math::AddByBitCode(num_classes, *label, pre_out, *bias); } - for (size_t i = 0; i < ins.size(); ++i) { - math::MulByBitCode(num_classes, *label, pre_out, *params[i], *ins[i]); + for (size_t i = 0; i < in.dims()[0]; ++i) { + math::MulByBitCode(num_classes, *label, pre_out, + *params->Slice(i, i + 1), *in->Slice(i, i + 1)); } // clip the matrix with (-40, 40) pre_out_mat.device(place) = @@ -79,11 +80,10 @@ template class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto ins = ctx.MultiInput("X"); - auto ins_grad = - ctx.MultiOutput(framework::GradVarName("X")); - auto params = ctx.MultiOutput( - framework::GradVarName("Parameters")); + auto* in = ctx.Input("X"); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + auto* params = + ctx.Output(framework::GradVarName("Parameters")); auto* bias = ctx.Output(framework::GradVarName("Bias")); auto* label = ctx.Output(framework::GradVarName("Label")); @@ -92,7 +92,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { framework::Tensor pre_out; auto place = ctx.GetEigenDevice(); auto& dev_ctx = ctx.device_context(); - int64_t batch_size = ins_grad.size(); + int64_t batch_size = in_grad.dims()[0]; int64_t code_length = math::FindLastSet(num_classes - 1); auto pre_out_mat = EigenMatrix::From(pre_out); @@ -111,11 +111,11 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { math::AddByBitCodeGrad(num_classes, *label, pre_out, *bias); } - for (size_t i = 0; i < ins_grad.size(); ++i) { + for (size_t i = 0; i < in_grad.dims()[0]; ++i) { math::MulByBitCodeGradWeight(num_classes, 
*label, pre_out, *params[i], - *ins[i]); + *in[i]->Slice(i, i + 1)); math::MulByBitCodeGradError(num_classes, *label, pre_out, *params[i], - *ins_grad[i]); + *ins_grad[i]->Slice(i, i + 1)); } } }; diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc index 996e0b819f..df98851054 100644 --- a/paddle/operators/math/matrix_bit_code.cc +++ b/paddle/operators/math/matrix_bit_code.cc @@ -52,19 +52,20 @@ namespace math { */ template static void AddByBitCodeT(Op op, CodeTable code_table, - const framework::Tensor& codes, framework::Tensor& a, - const framework::Tensor& b) { + const framework::Tensor& codes, + framework::Tensor& tmat, + const framework::Tensor& vec) { size_t num_classes = code_table.size(); size_t max_code_length = code_table.get_max_code_length(); - size_t num_sample = a.dims()[0]; - size_t width = a.dims()[1]; + size_t num_sample = tmat.dims()[0]; + size_t width = vec.dims()[1]; for (size_t i = 0; i < num_sample; ++i) { auto code = code_table(codes.data()[i]); int code_length = code.get_length(); for (int j = 0; j < code_length; + j) { size_t index = code.calc_index(j); - op(a.data()[i * width + j], b.data()[index]); + op(tmat.data()[i * width + j], vec.data()[index]); } } } diff --git a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py new file mode 100644 index 0000000000..25c13aabe9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py @@ -0,0 +1,34 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestHSigmoidOp(OpTest): + def setUp(self): + self.op_type = "hierarchical_sigmoid_op" + num_classes = 6 + embded_size = 10 + batch_size = 5 + x = np.random.random((batch_size, embded_size)).astype("float32") + parameter = np.random.random( + (batch_size, num_classes - 1, embded_size)).astype("float32") + label = np.random.randint(0, num_classes, batch_size).astype("int64") + bias = np.random.random((1, num_classes - 1)) + self.inputs = { + 'X': x, + 'Parameters': parameter, + 'Label': label, + 'Bias': bias + } + self.attrs = {'num_classes': num_classes} + self.outputs = {'Out': label} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['x0'], 'Out') + + +if __name__ == '__main__': + unittest.main() From fb9c08f0438fe0a25d0d2517e5e770ea5a22555d Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 25 Dec 2017 08:51:04 +0800 Subject: [PATCH 6/8] make forward work --- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/hierarchical_sigmoid_op.cc | 50 ++++- paddle/operators/hierarchical_sigmoid_op.h | 95 ++++---- paddle/operators/math/CMakeLists.txt | 2 +- paddle/operators/math/math_function.cc | 12 +- paddle/operators/math/math_function.h | 6 +- paddle/operators/math/math_function_impl.h | 16 +- paddle/operators/math/matrix_bit_code.cc | 211 ++++++++++-------- paddle/operators/math/matrix_bit_code.h | 86 ++++--- paddle/pybind/pybind.cc | 2 + python/paddle/v2/fluid/tests/op_test.py | 11 +- .../paddle/v2/fluid/tests/test_hsigmoid_op.py | 15 +- 12 files changed, 285 insertions(+), 223 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index d79f19e670..5fb14cc6d4 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -207,7 +207,7 @@ set(DEPS_OPS gru_op adagrad_op sgd_op - hierarchical_sigmoid_op) + hierarchical_sigmoid_op save_op load_op send_op diff --git a/paddle/operators/hierarchical_sigmoid_op.cc 
b/paddle/operators/hierarchical_sigmoid_op.cc index 063f8576e6..fa816d9215 100644 --- a/paddle/operators/hierarchical_sigmoid_op.cc +++ b/paddle/operators/hierarchical_sigmoid_op.cc @@ -60,19 +60,48 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->hasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Parameters"), + "Input(Parameters)" + "should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); const int64_t batch_size = ctx->GetInputDim("X")[0]; - std::vector output_shape({batch_size, num_classes_ - 1}); + std::vector output_shape({batch_size, 1}); ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.GetPlace()); + } }; class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Parameters"), + "Input(Parameters)" + "should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input(Label)" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Parameters")), + "Input(Parameters@Grad should not be null.)"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X"))); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.GetPlace()); + } }; class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { @@ -98,7 +127,8 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor, required) The output of hierarchical sigmoid operator." "the shape is [N, 1]"); - AddAttr("num_classes", "(int, required)", "The number of classes"); + AddAttr("num_classes", "(int, required)", "The number of classes") + .SetDefault(2); AddComment(R"DOC( The hierarchical sigmoid operator organize the classes into a binary tree. 
At each node, a sigmoid function is used to caculate the probability of @@ -116,9 +146,9 @@ namespace ops = paddle::operators; REGISTER_OP(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, ops::HierarchicalSigmoidOpMaker, hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp); -REGISTER_OP_CPU_KERNEL( - hierarchical_sigmoid, - ops::HierarchicalSigmoidOpKernel); -REGISTER_OP_CPU_KERNEL( - hierarchical_sigmoid_grad, - ops::HierarchicalSigmoidGradOpKernel); +REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid, + ops::HierarchicalSigmoidOpKernel< + paddle::platform::CPUDeviceContext, float>); +REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid_grad, + ops::HierarchicalSigmoidGradOpKernel< + paddle::platform::CPUDeviceContext, float>); diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h index e3f0bcacd8..531fd9f7fc 100644 --- a/paddle/operators/hierarchical_sigmoid_op.h +++ b/paddle/operators/hierarchical_sigmoid_op.h @@ -14,8 +14,10 @@ limitations under the License. */ #pragma once #include "paddle/framework/op_registry.h" +#include "paddle/operators/clip_op.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/matrix_bit_code.h" +#include "paddle/platform/transform.h" namespace paddle { namespace operators { @@ -23,60 +25,64 @@ namespace operators { template using EigenMatrix = framework::EigenMatrix; +using platform::Transform; -template +template class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); - auto* param = ctx.Input("Parameter"); + auto* params = ctx.Input("Parameters"); auto* label = ctx.Input("Label"); auto* bias = ctx.Input("Bias"); auto* out = ctx.Output("Out"); size_t num_classes = static_cast(ctx.Attr("num_classes")); - framework::Tensor sum; + int64_t code_length = math::FindLastSet(num_classes - 1); + int64_t batch_size = in->dims()[0]; + auto* ids = label->data(); framework::Tensor pre_out; - auto place = ctx.GetEigenDevice(); - auto& device_ctx = ctx.device_context(); - math::ColwiseSum col_sum; - math::RowwiseSum row_sum; - + framework::Tensor sum; + auto pre_out_data = pre_out.mutable_data( + framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); auto pre_out_mat = EigenMatrix::From(pre_out); - int64_t batch_size = ins[0]->dims()[0]; - int64_t code_length = math::FindLastSet(num_classes - 1); - std::vector pre_out_dims({batch_size, code_length}); - pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace()); + auto& place = *ctx.template device_context().eigen_device(); + auto& device_ctx = ctx.template device_context(); + math::RowwiseSum row_sum; + math::MatrixBitCodeFunctor bit_code; + std::vector sum_dims({batch_size, 1UL}); sum.mutable_data(framework::make_ddim(sum_dims), ctx.GetPlace()); + auto sum_mat = EigenMatrix::From(sum); out->mutable_data(ctx.GetPlace()); + auto out_mat = framework::EigenVector::Flatten(*out); if (bias) { - math::AddByBitCode(num_classes, *label, pre_out, *bias); + bit_code.Add(num_classes, ids, pre_out, *bias); } - - for (size_t i = 0; i < in.dims()[0]; ++i) { - math::MulByBitCode(num_classes, *label, pre_out, - *params->Slice(i, i + 1), *in->Slice(i, i + 1)); + for (int i = 0; i < in->dims()[0]; ++i) { + bit_code.Mul(num_classes, ids, pre_out, params->Slice(i, i + 1), + in->Slice(i, i + 1)); } // clip the matrix with (-40, 40) - pre_out_mat.device(place) = - pre_out_mat.abs().cwiseMax(static_cast(40.0)); - 
math::SumByBitCode(num_classes, *label, *out, pre_out, - static_cast(-1)); - + Transform trans; + trans(ctx.template device_context(), pre_out_data, + pre_out_data + pre_out.numel(), pre_out_data, + ClipFunctor(static_cast(-40.0), static_cast(40.0))); + bit_code.Sum(num_classes, ids, pre_out, *out, static_cast(-1)); // softrelu with threshold is 40.0 - pre_out_mat.device(place) = - pre_out_mat.abs().cwiseMax(static_cast(40.0)); + trans(ctx.template device_context(), pre_out_data, + pre_out_data + pre_out.numel(), pre_out_data, + ClipFunctor(static_cast(-40.0), static_cast(40.0))); pre_out_mat.device(place) = (static_cast(1.0) + pre_out_mat.exp()).log(); row_sum(device_ctx, pre_out, &sum); - col_sum(device_ctx, *out, &sum); + out_mat.device(place) = sum_mat + out_mat; } }; -template +template class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -85,37 +91,40 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { auto* params = ctx.Output(framework::GradVarName("Parameters")); auto* bias = ctx.Output(framework::GradVarName("Bias")); - auto* label = - ctx.Output(framework::GradVarName("Label")); + auto* label = ctx.Input("Label"); size_t num_classes = static_cast(ctx.Attr("num_classes")); + int64_t code_length = math::FindLastSet(num_classes - 1); + int64_t batch_size = in->dims()[0]; framework::Tensor pre_out; - auto place = ctx.GetEigenDevice(); - auto& dev_ctx = ctx.device_context(); - int64_t batch_size = in_grad.dims()[0]; - int64_t code_length = math::FindLastSet(num_classes - 1); + pre_out.mutable_data(framework::make_ddim({batch_size, code_length}), + ctx.GetPlace()); + auto& place = *ctx.template device_context().eigen_device(); + auto& device_ctx = ctx.template device_context(); auto pre_out_mat = EigenMatrix::From(pre_out); + auto* ids = label->data(); // init pre_out matrix with {1.0} - std::vector pre_out_dims({batch_size, code_length}); - pre_out.mutable_data(framework::make_ddim(pre_out_dims), ctx.GetPlace()); - math::SetConstant set; - set(dev_ctx, &pre_out, static_cast(1.0)); + math::SetConstant one; + math::MatrixBitCodeFunctor bit_code; + one(device_ctx, &pre_out, static_cast(1.0)); // softrelu derivative pre_out_mat.device(place) = pre_out_mat * (static_cast(1.0) - static_cast(1.0) / pre_out_mat); - math::SubByBitCode(num_classes, *label, pre_out); + bit_code.Sub(num_classes, ids, pre_out); if (bias) { - math::AddByBitCodeGrad(num_classes, *label, pre_out, *bias); + bit_code.AddGrad(num_classes, ids, pre_out, *bias); } - for (size_t i = 0; i < in_grad.dims()[0]; ++i) { - math::MulByBitCodeGradWeight(num_classes, *label, pre_out, *params[i], - *in[i]->Slice(i, i + 1)); - math::MulByBitCodeGradError(num_classes, *label, pre_out, *params[i], - *ins_grad[i]->Slice(i, i + 1)); + for (int i = 0; i < in_grad->dims()[0]; ++i) { + auto p_sliced = params->Slice(i, i + 1); + auto in_sliced = in->Slice(i, i + 1); + auto in_grad_sliced = in_grad->Slice(i, i + 1); + bit_code.MulGradWeight(num_classes, ids, pre_out, p_sliced, in_sliced); + bit_code.MulGradError(num_classes, ids, pre_out, p_sliced, + in_grad_sliced); } } }; diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 6467d8ddb3..82ba24f35b 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -27,7 +27,7 @@ else() cc_library(context_project SRCS context_project.cc DEPS device_context math_function) cc_library(sequence2batch SRCS 
sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) - cc_library(matrix_bit_code SRCS matrix_bit_code.cc) + cc_library(matrix_bit_code SRCS matrix_bit_code.cc DEPS device_context) cc_library(maxouting SRCS maxouting.cc DEPS device_context) cc_library(unpooling SRCS unpooling.cc DEPS device_context) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index ead0fe1971..474fd0b0a9 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -302,12 +302,12 @@ void set_constant(const platform::DeviceContext& context, #endif } -template struct RowwiseAdd; -template struct RowwiseAdd; -template struct ColwiseSum; -template struct ColwiseSum; -template struct RowwiseSum; -template struct RowwiseSum; +template struct RowwiseAdd; +template struct RowwiseAdd; +template struct ColwiseSum; +template struct ColwiseSum; +template struct RowwiseSum; +template struct RowwiseSum; } // namespace math } // namespace operators diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 51e0fd9ad7..b49294e621 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -128,10 +128,10 @@ struct ColwiseSum { framework::Tensor* vec); }; -template +template struct RowwiseSum { - void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor* vec); + void operator()(const DeviceContext& context, const framework::Tensor& input, + framework::Tensor* vec); }; } // namespace math diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h index 4d5e848101..2b3b6c335b 100644 --- a/paddle/operators/math/math_function_impl.h +++ b/paddle/operators/math/math_function_impl.h @@ -79,19 +79,19 @@ void ColwiseSum::operator()(const DeviceContext& context, in.sum(Eigen::array({{0}})).reshape(shape); } -template -void RowwiseSum::operator()(const platform::DeviceContext& context, - const framework::Tensor& input, - framework::Tensor* vector) { +template +void RowwiseSum::operator()(const DeviceContext& context, + const framework::Tensor& input, + framework::Tensor* vector) { auto in_dims = input.dims(); auto size = input.numel() / in_dims[1]; PADDLE_ENFORCE_EQ(vector->numel(), size); - auto in = framework::EigenMatrix::From(input); - auto vec = framework::EigenMatrix::From(*vector); + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenMatrix::From(*vector); Eigen::array shape({{static_cast(size), 1}}); - vec.reshape(shape).device(*context.GetEigenDevice()) = - in.sum(Eigen::array({{0}})).reshape(shape); + vec.reshape(shape).device(*context.eigen_device()) = + in.sum(Eigen::array({{1}})).reshape(shape); } } // namespace math } // namespace operators diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc index df98851054..9e3836b06d 100644 --- a/paddle/operators/math/matrix_bit_code.cc +++ b/paddle/operators/math/matrix_bit_code.cc @@ -50,50 +50,52 @@ namespace math { for j < codeLength: op(a(i, j), b(0, index(i, j))) */ -template -static void AddByBitCodeT(Op op, CodeTable code_table, - const framework::Tensor& codes, - framework::Tensor& tmat, +template +static void AddByBitCodeT(Op op, CodeTable code_table, const int64_t* codes, + const framework::Tensor& 
tmat, const framework::Tensor& vec) { - size_t num_classes = code_table.size(); - size_t max_code_length = code_table.get_max_code_length(); size_t num_sample = tmat.dims()[0]; size_t width = vec.dims()[1]; for (size_t i = 0; i < num_sample; ++i) { - auto code = code_table(codes.data()[i]); + auto code = code_table(static_cast(codes[i])); int code_length = code.get_length(); - for (int j = 0; j < code_length; + j) { + for (int j = 0; j < code_length; ++j) { size_t index = code.calc_index(j); - op(tmat.data()[i * width + j], vec.data()[index]); + auto t = tmat.data()[i * width + j]; + auto v = vec.data()[index]; + op(t, v); } } } -template -void AddByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& vec) { - auto op = [](T& t, T& v) { t += v; }; - AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, vec); -} - -template -void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, framework::Tensor& vec) { - auto op = [](T& t, T& v) { v += t; }; - AddByBitCode(op, SimpleCodeTable(num_classes), codes, tmat, vec); +template +void SubByBitCodeT(CodeTable code_table, const int64_t* codes, + framework::Tensor& tmat) { + // size_t max_code_length = code_table.get_max_code_length(); + size_t num_samples = tmat.dims()[0]; + size_t o_width = tmat.dims()[1]; + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(static_cast(codes[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + if (code.calc_bit(j)) { + tmat.data()[i * o_width + j] -= 1; + } + } + } } -template -void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& sum, +template +void SumByBitCodeT(CodeTable code_table, const int64_t* codes, + framework::Tensor& tmat, framework::Tensor& sum, const T& scale_sum) { - size_t max_code_length = code_table.get_max_code_length(); + // size_t max_code_length = code_table.get_max_code_length(); size_t num_samples = tmat.dims()[0]; size_t o_width = tmat.dims()[1]; for (size_t i = 0; i < num_samples; ++i) { - T sm = 0; - auto code = code_table(codes.data()[i]); + T sm = static_cast(0.0); + auto code = code_table(static_cast(codes[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { if (code.calc_bit(j)) { @@ -103,105 +105,124 @@ void SumByBitCodeT(CodeTable code_table, const framework::Tensor& codes, sum.data()[i] = scale_sum * sm; } } -/* For j < codeLength: - sum(i, 0) = \sum_j bit(i, j) * input(i, j) -*/ + template -void SumByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, framework::Tensor& sum, - T scale_sum) { - SumByBitCodeT(SimpleCodeTable(num_classes), codes, tmat, scale_sum); +void MatrixBitCodeFunctor::Add(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat, + const framework::Tensor& vec) { + auto op = [](T& t, const T& v) { t += v; }; + AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, vec); } -template -void MulByBitCodeT(Op op, CodeTable code_table, const framework::Tensor& codes, - framework::Tensor& tmat, framework::Tensor& weight, - framework::Tensor& input) { - size_t num_classes = code_table.size(); - size_t max_code_length = code_table.get_max_code_length(); - size_t num_samples = tmat.dims()[0]; - size_t input_dim = input.dims()[1]; - size_t o_width = tmat.dims()[1]; +template +void MatrixBitCodeFunctor::AddGrad(size_t num_classes, const int64_t* codes, + 
framework::Tensor& tmat, + framework::Tensor& vec) { + auto op = [](T& t, T& v) { v += t; }; + AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, vec); +} +template +void MatrixBitCodeFunctor::Sum(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat, + framework::Tensor& sum, T scale_sum) { + SumByBitCodeT(SimpleCodeTable(num_classes), codes, tmat, sum, scale_sum); +} + +template +void MatrixBitCodeFunctor::Mul(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat, + const framework::Tensor& weight, + const framework::Tensor& input) { + size_t num_samples = tmat.dims()[0]; + size_t tmat_width = tmat.dims()[1]; + size_t input_width = input.dims()[1]; + size_t weight_width = weight.dims()[1]; + auto tmat_p = tmat.data(); + auto weight_p = weight.data(); + auto input_p = input.data(); + auto code_table = SimpleCodeTable(num_classes); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(codes.data()[i]); + auto code = code_table(static_cast(codes[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { size_t index = code.calc_index(j); - op(tmat.data()[i * o_width + j], - weight.data() + index * weight.dims()[1], - input.data() + i * input.dims()[1], input_dim); + + T sum = static_cast(0.0); + for (size_t k = 0; k < input_width; ++k) { + sum += + weight_p[weight_width * index + k] * input_p[input_width * i + k]; + } + std::cout << sum << std::endl; + tmat_p[i * tmat_width + j] += sum; } } } template -void MulByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& weight, - const framework::Tensor& input) { - auto op = [](T& t, const T* weight_row, const T* input_row, - size_t input_dim) { - T sum = 0; - for (size_t k = 0; k < input_dim; ++k) { - sum += weight_row[k] * input_row[k]; - } - t += sum; - }; - MulByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, weight, - input); -} +void MatrixBitCodeFunctor::MulGradWeight(size_t num_classes, + const int64_t* codes, + const framework::Tensor& tmat, + framework::Tensor& weight, + const framework::Tensor& input) { + size_t num_samples = tmat.dims()[0]; + size_t input_width = input.dims()[1]; + size_t weight_width = weight.dims()[1]; + auto tmat_p = tmat.data(); + auto weight_p = weight.data(); + auto input_p = input.data(); + auto code_table = SimpleCodeTable(num_classes); + for (size_t i = 0; i < num_samples; ++i) { + auto code = code_table(static_cast(codes[i])); + int code_length = code.get_length(); + for (int j = 0; j < code_length; ++j) { + size_t index = code.calc_index(j); -template -void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, - framework::Tensor& weight, - const framework::Tensor& input) { - auto op = [](const T t, T* weight_row, const T* input_row, size_t input_dim) { - for (size_t k = 0; k < input_dim; ++k) { - weight_row[k] += t * input_row[k]; + for (size_t k = 0; k < input_width; ++k) { + weight_p[weight_width * index * k] += + tmat_p[i * weight_width * j] * input_p[input_width * i + k]; + } } - }; - MulByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, weight, - input); + } } template -void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, - const framework::Tensor& weight, - framework::Tensor& input) { - auto op = [](const T t, const T* weight_row, T* input_row, size_t input_dim) { - for (size_t k = 0; k < input_dim; ++k) { - input_row[k] += t * weight_row[k]; 
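// Editor's annotation (not part of the original patch): this lambda is the
// error back-propagation of the forward inner product. For each code
// position j of sample i it performs
//   input.row(i) += tmat(i, j) * weight.row(index(i, j))
// which matches the doc comment in matrix_bit_code.h; the MulGradError
// method that replaces it below keeps the same semantics.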
- } - }; - MulByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, weight, - input); -} - -template -void SubByBitCodeT(CodeTable code_table, const framework::Tensor& codes, - framework::Tensor& tmat) { - size_t max_code_length = code_table.get_max_code_length(); +void MatrixBitCodeFunctor::MulGradError(size_t num_classes, + const int64_t* codes, + const framework::Tensor& tmat, + const framework::Tensor& weight, + framework::Tensor& input) { size_t num_samples = tmat.dims()[0]; - size_t o_width = tmat.dims()[1]; + size_t input_width = input.dims()[1]; + size_t weight_width = weight.dims()[1]; + auto tmat_p = tmat.data(); + auto weight_p = weight.data(); + auto input_p = input.data(); + auto code_table = SimpleCodeTable(num_classes); + for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(codes.data()[i]); + auto code = code_table(static_cast(codes[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { - if (code.calc_bit(j)) { - tmat.data()[i * o_width + j] -= 1; + size_t index = code.calc_index(j); + + for (size_t k = 0; k < input_width; ++k) { + input_p[weight_width * index * k] += + tmat_p[i * weight_width * j] * weight_p[weight_width * i + k]; } } } } template -void SubByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat) { +void MatrixBitCodeFunctor::Sub(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat) { SubByBitCodeT(SimpleCodeTable(num_classes), codes, tmat); } +template class MatrixBitCodeFunctor; +template class MatrixBitCodeFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h index 43c9d43d89..d2ebf182c8 100644 --- a/paddle/operators/math/matrix_bit_code.h +++ b/paddle/operators/math/matrix_bit_code.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace operators { @@ -59,57 +60,50 @@ struct SimpleCodeTable { int max_code_length_; }; -/* For j < code_length - tmat(i, j) += vec(0, index(i, j)) -*/ template -void AddByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& vec); +class MatrixBitCodeFunctor { + public: + /* For j < code_length + tmat(i, j) += vec(0, index(i, j)) + */ + void Add(size_t num_classes, const int64_t* codes, framework::Tensor& tmat, + const framework::Tensor& vec); -/* For j < code_length - vec(0, index(i, j)) += tmat(i, j) -*/ -template -void AddByBitCodeGrad(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, framework::Tensor& vec); -/* For j < code_length + /* For j < code_length + vec(0, index(i, j)) += tmat(i, j) + */ + void AddGrad(size_t num_classes, const int64_t* codes, + framework::Tensor& tmat, framework::Tensor& vec); + + /* For j < code_length sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) -*/ -template -void SumByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, framework::Tensor& sum, T scale_sum); + */ + void Sum(size_t num_classes, const int64_t* codes, framework::Tensor& tmat, + framework::Tensor& sum, T scale_sum); -/* For j < code_length - input.row(i) += tmat(i, j) * weight.row(index(i, j)) -*/ -template -void MulByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat, const framework::Tensor& weight, - const framework::Tensor& input); + /* For j < code_length + tmat(i, j) -= bit(i, j) + */ + void Sub(size_t num_classes, const int64_t* codes, framework::Tensor& tmat); + /* For j < code_length + input.row(i) += tmat(i, j) * weight.row(index(i, j)) + */ + void Mul(size_t num_classes, const int64_t* codes, framework::Tensor& tmat, + const framework::Tensor& weight, const framework::Tensor& input); -/* For index(i, j) >= 0: - weight.row(index(i, j)) += tmat(i, j) * input.row(i) -*/ -template -void MulByBitCodeGradWeight(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, - framework::Tensor& weight, - const framework::Tensor& input); -/* For j < code_length - input.row(i) += tmat(i, j) * weight.row(index(i, j)) -*/ -template -void MulByBitCodeGradError(size_t num_classes, const framework::Tensor& codes, - const framework::Tensor& tmat, - const framework::Tensor& weight, - framework::Tensor& input); - -/* For j < code_length - tmat(i, j) -= bit(i, j) -*/ -template -void SubByBitCode(size_t num_classes, const framework::Tensor& codes, - framework::Tensor& tmat); + /* For index(i, j) >= 0: + weight.row(index(i, j)) += tmat(i, j) * input.row(i) + */ + void MulGradWeight(size_t num_classes, const int64_t* codes, + const framework::Tensor& tmat, framework::Tensor& weight, + const framework::Tensor& input); + /* For j < code_length + input.row(i) += tmat(i, j) * weight.row(index(i, j)) + */ + void MulGradError(size_t num_classes, const int64_t* codes, + const framework::Tensor& tmat, + const framework::Tensor& weight, framework::Tensor& input); +}; } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index c16d3e0cbe..a05fcd0451 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -126,6 +126,8 @@ PYBIND11_PLUGIN(core) { .def("shape", [](Tensor &self) { return 
vectorize(self.dims()); }) .def("set_float_element", TensorSetElement) .def("get_float_element", TensorGetElement) + .def("set_int64_element", TensorSetElement) + .def("get_int64_element", TensorGetElement) .def("set_double_element", TensorSetElement) .def("get_double_element", TensorGetElement) .def("dtype", [](Tensor &self) { return ToDataType(self.type()); }); diff --git a/python/paddle/v2/fluid/tests/op_test.py b/python/paddle/v2/fluid/tests/op_test.py index e83c4a0622..edf68075be 100644 --- a/python/paddle/v2/fluid/tests/op_test.py +++ b/python/paddle/v2/fluid/tests/op_test.py @@ -49,7 +49,6 @@ def create_op(scope, op_type, inputs, outputs, attrs): for attr_name in Operator.get_op_attr_names(op_type): if attr_name in attrs: kwargs[attr_name] = attrs[attr_name] - return Operator(op_type, **kwargs) @@ -107,6 +106,8 @@ def get_numeric_gradient(scope, tensor_to_check_dtype = np.float32 elif tensor_to_check_dtype == core.DataType.FP64: tensor_to_check_dtype = np.float64 + elif tensor_to_check_dtype == core.DataType.INT64: + tensor_to_check_dtype = np.int64 else: raise ValueError("Not supported data type " + str( tensor_to_check_dtype)) @@ -116,12 +117,16 @@ def get_numeric_gradient(scope, def __get_elem__(tensor, i): if tensor_to_check_dtype == np.float32: return tensor.get_float_element(i) + elif tensor_to_check_dtype == np.int64: + return tensor.get_int64_element(i) else: return tensor.get_double_element(i) def __set_elem__(tensor, i, e): if tensor_to_check_dtype == np.float32: tensor.set_float_element(i, e) + elif tensor_to_check_dtype == np.int64: + tensor.set_int64_element(i, e) else: tensor.set_double_element(i, e) @@ -355,13 +360,11 @@ class OpTest(unittest.TestCase): op_attrs = self.attrs if hasattr(self, "attrs") else dict() self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, op_attrs) - if no_grad_set is None: no_grad_set = set() if not type(output_names) is list: output_names = [output_names] - numeric_grads = user_defined_grads or [ get_numeric_gradient( self.scope, @@ -457,9 +460,7 @@ class OpTest(unittest.TestCase): # infer variable type and infer shape in compile-time op.desc.infer_var_type(block.desc) op.desc.infer_shape(block.desc) - mean_inputs = map(block.var, output_names) - if len(mean_inputs) == 1: loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1]) op = block.append_op( diff --git a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py index 25c13aabe9..194d5e315f 100644 --- a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py +++ b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py @@ -5,15 +5,15 @@ from op_test import OpTest class TestHSigmoidOp(OpTest): def setUp(self): - self.op_type = "hierarchical_sigmoid_op" + self.op_type = "hierarchical_sigmoid" num_classes = 6 embded_size = 10 batch_size = 5 x = np.random.random((batch_size, embded_size)).astype("float32") parameter = np.random.random( (batch_size, num_classes - 1, embded_size)).astype("float32") - label = np.random.randint(0, num_classes, batch_size).astype("int64") - bias = np.random.random((1, num_classes - 1)) + label = np.random.randint(0, num_classes, batch_size) + bias = np.random.random((1, num_classes - 1)).astype("float32") self.inputs = { 'X': x, 'Parameters': parameter, @@ -21,13 +21,18 @@ class TestHSigmoidOp(OpTest): 'Bias': bias } self.attrs = {'num_classes': num_classes} - self.outputs = {'Out': label} + self.outputs = { + 'Out': np.random.random((batch_size, 1)).astype("float32") + } def test_check_output(self): 
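        # Editor's annotation, a hedged summary rather than text from the
        # patch: per sample i with path nodes index(i, j) and branch bits
        # bit(i, j), the operator roughly computes
        #   pre(i, j) = bias[0, index(i, j)] + dot(w[index(i, j)], x[i])
        #   out(i) = -sum_j bit(i, j) * pre(i, j)
        #            + sum_j log(1 + exp(pre(i, j)))
        # with pre clipped to [-40, 40] first, i.e. the negative
        # log-likelihood of the root-to-leaf path of label i.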
self.check_output()

     def test_check_grad(self):
-        self.check_grad(['x0'], 'Out')
+        self.check_grad(
+            ['X', 'Parameters', 'Label', 'Bias'],
+            'Out',
+            no_grad_set=set(['Label']))

 if __name__ == '__main__':

From f07164912bca60a36a72dc6ce22f8e00caa99301 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Wed, 3 Jan 2018 20:00:07 +0800
Subject: [PATCH 7/8] fix backward

---
 paddle/operators/hierarchical_sigmoid_op.cc   | 28 +++++++-------
 paddle/operators/hierarchical_sigmoid_op.h    | 38 +++++++++----------
 paddle/operators/math/matrix_bit_code.cc      |  1 -
 paddle/pybind/pybind.cc                       |  2 -
 python/paddle/v2/fluid/executor.py            |  1 -
 python/paddle/v2/fluid/tests/op_test.py      |  2 -
 .../paddle/v2/fluid/tests/test_hsigmoid_op.py | 16 ++------
 7 files changed, 37 insertions(+), 51 deletions(-)

diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc
index 4b3487f8b9..bc6ceb9874 100644
--- a/paddle/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/operators/hierarchical_sigmoid_op.cc
@@ -61,10 +61,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Parameters"),
-                   "Input(Parameters)"
-                   "should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
     const int64_t batch_size = ctx->GetInputDim("X")[0];
     std::vector<int64_t> output_shape({batch_size, 1});
@@ -84,15 +82,17 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
 public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Parameters"),
-                   "Input(Parameters)"
-                   "should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"),
-                   "Input(Label)"
-                   "should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Parameters")),
-                   "Input(Parameters@Grad should not be null.)");
+    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
+                   "Output(W@Grad) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
+    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Bias"),
+                        ctx->GetInputDim("Bias"));
+    }
+    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }

 protected:
@@ -112,11 +112,11 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor, required) The input Tensor, whose shape is"
              "[N * D], where N is the size of the mini-batch and"
              "D is the embedding size");
-    AddInput("Parameters",
+    AddInput("W",
              "(Tensor, required), The parameters of the hierarchical "
              "sigmoid operator, each of them is a 3-D tensor, the shape is"
             "[N, num_classes - 1, D]");
-    AddInput("Label",
+    AddInput("Ids",
             "(Tensor, required), The labels of training data. 
It's a" "1-D tensor, which the shape is [1, N]"); AddInput("Bias", diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h index 531fd9f7fc..1b8d21c095 100644 --- a/paddle/operators/hierarchical_sigmoid_op.h +++ b/paddle/operators/hierarchical_sigmoid_op.h @@ -32,15 +32,14 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); - auto* params = ctx.Input("Parameters"); - auto* label = ctx.Input("Label"); + auto* w = ctx.Input("W"); + auto* ids = ctx.Input("Ids"); auto* bias = ctx.Input("Bias"); auto* out = ctx.Output("Out"); size_t num_classes = static_cast(ctx.Attr("num_classes")); int64_t code_length = math::FindLastSet(num_classes - 1); int64_t batch_size = in->dims()[0]; - auto* ids = label->data(); framework::Tensor pre_out; framework::Tensor sum; auto pre_out_data = pre_out.mutable_data( @@ -59,18 +58,19 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { auto out_mat = framework::EigenVector::Flatten(*out); if (bias) { - bit_code.Add(num_classes, ids, pre_out, *bias); + bit_code.Add(num_classes, ids->data(), pre_out, *bias); } for (int i = 0; i < in->dims()[0]; ++i) { - bit_code.Mul(num_classes, ids, pre_out, params->Slice(i, i + 1), - in->Slice(i, i + 1)); + bit_code.Mul(num_classes, ids->data(), pre_out, + w->Slice(i, i + 1), in->Slice(i, i + 1)); } // clip the matrix with (-40, 40) Transform trans; trans(ctx.template device_context(), pre_out_data, pre_out_data + pre_out.numel(), pre_out_data, ClipFunctor(static_cast(-40.0), static_cast(40.0))); - bit_code.Sum(num_classes, ids, pre_out, *out, static_cast(-1)); + bit_code.Sum(num_classes, ids->data(), pre_out, *out, + static_cast(-1)); // softrelu with threshold is 40.0 trans(ctx.template device_context(), pre_out_data, pre_out_data + pre_out.numel(), pre_out_data, @@ -88,10 +88,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* in = ctx.Input("X"); auto* in_grad = ctx.Output(framework::GradVarName("X")); - auto* params = - ctx.Output(framework::GradVarName("Parameters")); + auto* w = ctx.Output(framework::GradVarName("W")); auto* bias = ctx.Output(framework::GradVarName("Bias")); - auto* label = ctx.Input("Label"); + auto* ids = ctx.Input("Ids"); size_t num_classes = static_cast(ctx.Attr("num_classes")); int64_t code_length = math::FindLastSet(num_classes - 1); int64_t batch_size = in->dims()[0]; @@ -102,8 +101,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { auto& place = *ctx.template device_context().eigen_device(); auto& device_ctx = ctx.template device_context(); auto pre_out_mat = EigenMatrix::From(pre_out); - auto* ids = label->data(); - // init pre_out matrix with {1.0} math::SetConstant one; math::MatrixBitCodeFunctor bit_code; @@ -112,19 +109,22 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { pre_out_mat.device(place) = pre_out_mat * (static_cast(1.0) - static_cast(1.0) / pre_out_mat); - bit_code.Sub(num_classes, ids, pre_out); + bit_code.Sub(num_classes, ids->data(), pre_out); if (bias) { - bit_code.AddGrad(num_classes, ids, pre_out, *bias); + bias->mutable_data(ctx.GetPlace()); + bit_code.AddGrad(num_classes, ids->data(), pre_out, *bias); } - + in_grad->mutable_data(ctx.GetPlace()); + w->mutable_data(ctx.GetPlace()); for (int i = 0; i < in_grad->dims()[0]; ++i) { - auto p_sliced = params->Slice(i, 
i + 1); + auto p_sliced = w->Slice(i, i + 1); auto in_sliced = in->Slice(i, i + 1); auto in_grad_sliced = in_grad->Slice(i, i + 1); - bit_code.MulGradWeight(num_classes, ids, pre_out, p_sliced, in_sliced); - bit_code.MulGradError(num_classes, ids, pre_out, p_sliced, - in_grad_sliced); + bit_code.MulGradWeight(num_classes, ids->data(), pre_out, + p_sliced, in_sliced); + bit_code.MulGradError(num_classes, ids->data(), pre_out, + p_sliced, in_grad_sliced); } } }; diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc index 4ad0a00008..b192183b10 100644 --- a/paddle/operators/math/matrix_bit_code.cc +++ b/paddle/operators/math/matrix_bit_code.cc @@ -56,7 +56,6 @@ static void AddByBitCodeT(Op op, CodeTable code_table, const int64_t* codes, const framework::Tensor& vec) { size_t num_sample = tmat.dims()[0]; size_t width = vec.dims()[1]; - for (size_t i = 0; i < num_sample; ++i) { auto code = code_table(static_cast(codes[i])); int code_length = code.get_length(); diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 921b316a69..de6b24f70d 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -109,8 +109,6 @@ PYBIND11_PLUGIN(core) { .def("shape", [](Tensor &self) { return vectorize(self.dims()); }) .def("set_float_element", TensorSetElement) .def("get_float_element", TensorGetElement) - .def("set_int64_element", TensorSetElement) - .def("get_int64_element", TensorGetElement) .def("set_double_element", TensorSetElement) .def("get_double_element", TensorGetElement) .def("dtype", [](Tensor &self) { return ToDataType(self.type()); }); diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py index cdd576294f..a054d5eafb 100644 --- a/python/paddle/v2/fluid/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -148,7 +148,6 @@ class Executor(object): inputs={'X': [var]}, outputs={'Out': [fetch_var]}, attrs={'col': i}) - self.executor.run(program.desc, scope, 0, True, True) outs = [ core.get_fetch_variable(scope, fetch_var_name, i) diff --git a/python/paddle/v2/fluid/tests/op_test.py b/python/paddle/v2/fluid/tests/op_test.py index 287dc29804..0493a0c206 100644 --- a/python/paddle/v2/fluid/tests/op_test.py +++ b/python/paddle/v2/fluid/tests/op_test.py @@ -123,8 +123,6 @@ def get_numeric_gradient(scope, def __set_elem__(tensor, i, e): if tensor_to_check_dtype == np.float32: tensor.set_float_element(i, e) - elif tensor_to_check_dtype == np.int64: - tensor.set_int64_element(i, e) else: tensor.set_double_element(i, e) diff --git a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py index 194d5e315f..b6d961b631 100644 --- a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py +++ b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py @@ -10,16 +10,11 @@ class TestHSigmoidOp(OpTest): embded_size = 10 batch_size = 5 x = np.random.random((batch_size, embded_size)).astype("float32") - parameter = np.random.random( + w = np.random.random( (batch_size, num_classes - 1, embded_size)).astype("float32") - label = np.random.randint(0, num_classes, batch_size) + ids = np.random.randint(0, num_classes, batch_size) bias = np.random.random((1, num_classes - 1)).astype("float32") - self.inputs = { - 'X': x, - 'Parameters': parameter, - 'Label': label, - 'Bias': bias - } + self.inputs = {'X': x, 'W': w, 'Ids': ids, 'Bias': bias} self.attrs = {'num_classes': num_classes} self.outputs = { 'Out': np.random.random((batch_size, 1)).astype("float32") @@ -29,10 +24,7 @@ class 
TestHSigmoidOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(
-            ['X', 'Parameters', 'Label', 'Bias'],
-            'Out',
-            no_grad_set=set(['Label']))
+        self.check_grad(['X', 'W', 'Bias'], 'Out', no_grad_set=set(['Ids']))

 if __name__ == '__main__':

From 80ce7edbb79244f6946cf38e233d3914ef40ddf5 Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Tue, 9 Jan 2018 20:26:37 +0800
Subject: [PATCH 8/8] make forward correct

---
 paddle/operators/hierarchical_sigmoid_op.cc   |   4 +-
 paddle/operators/hierarchical_sigmoid_op.h    |  35 ++--
 paddle/operators/math/math_function_impl.h    |   8 +-
 paddle/operators/math/matrix_bit_code.cc      | 156 ++++++++----------
 paddle/operators/math/matrix_bit_code.h       |  25 ++-
 python/paddle/v2/fluid/tests/op_test.py      |   9 +-
 .../paddle/v2/fluid/tests/test_hsigmoid_op.py |  70 +++++++-
 7 files changed, 170 insertions(+), 137 deletions(-)

diff --git a/paddle/operators/hierarchical_sigmoid_op.cc b/paddle/operators/hierarchical_sigmoid_op.cc
index bc6ceb9874..e2ba65d6f9 100644
--- a/paddle/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/operators/hierarchical_sigmoid_op.cc
@@ -70,7 +70,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
@@ -96,7 +96,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
   }

 protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),

diff --git a/paddle/operators/hierarchical_sigmoid_op.h b/paddle/operators/hierarchical_sigmoid_op.h
index 1b8d21c095..f5b1b97169 100644
--- a/paddle/operators/hierarchical_sigmoid_op.h
+++ b/paddle/operators/hierarchical_sigmoid_op.h
@@ -49,34 +49,31 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto& device_ctx = ctx.template device_context<DeviceContext>();
     math::RowwiseSum<DeviceContext, T> row_sum;
-    math::MatrixBitCodeFunctor<T> bit_code;
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, ids->data<int64_t>());
     std::vector<int64_t> sum_dims({batch_size, 1UL});
     sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
     auto sum_mat = EigenMatrix<T>::From(sum);
     out->mutable_data<T>(ctx.GetPlace());
     auto out_mat = framework::EigenVector<T>::Flatten(*out);
-
     if (bias) {
-      bit_code.Add(num_classes, ids->data<int64_t>(), pre_out, *bias);
+      bit_code.Add(pre_out, *bias);
     }
-    for (int i = 0; i < in->dims()[0]; ++i) {
-      bit_code.Mul(num_classes, ids->data<int64_t>(), pre_out,
-                   w->Slice(i, i + 1), in->Slice(i, i + 1));
+    for (int64_t i = 0; i < batch_size; ++i) {
+      auto w_i = w->Slice(i, i + 1);
+      bit_code.Mul(pre_out, w_i, *in);
     }
     // clip the matrix with (-40, 40)
     Transform<DeviceContext> trans;
     trans(ctx.template device_context<DeviceContext>(), pre_out_data,
           pre_out_data + pre_out.numel(), pre_out_data,
           ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
-    bit_code.Sum(num_classes, ids->data<int64_t>(), pre_out, *out,
-                 static_cast<T>(-1));
+    bit_code.Sum(pre_out, *out, static_cast<T>(-1));
     // softrelu with threshold is 40.0
     trans(ctx.template device_context<DeviceContext>(), pre_out_data,
           pre_out_data + pre_out.numel(), pre_out_data,
           ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
     pre_out_mat.device(place) =
         (static_cast<T>(1.0) + pre_out_mat.exp()).log();
-
     row_sum(device_ctx, pre_out, &sum);
     out_mat.device(place) 
= sum_mat + out_mat; } @@ -103,28 +100,26 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { auto pre_out_mat = EigenMatrix::From(pre_out); // init pre_out matrix with {1.0} math::SetConstant one; - math::MatrixBitCodeFunctor bit_code; + math::MatrixBitCodeFunctor bit_code(num_classes, ids->data()); one(device_ctx, &pre_out, static_cast(1.0)); // softrelu derivative pre_out_mat.device(place) = pre_out_mat * (static_cast(1.0) - static_cast(1.0) / pre_out_mat); - bit_code.Sub(num_classes, ids->data(), pre_out); + bit_code.Sub(pre_out); if (bias) { bias->mutable_data(ctx.GetPlace()); - bit_code.AddGrad(num_classes, ids->data(), pre_out, *bias); + bit_code.AddGrad(pre_out, *bias); } in_grad->mutable_data(ctx.GetPlace()); w->mutable_data(ctx.GetPlace()); - for (int i = 0; i < in_grad->dims()[0]; ++i) { - auto p_sliced = w->Slice(i, i + 1); - auto in_sliced = in->Slice(i, i + 1); - auto in_grad_sliced = in_grad->Slice(i, i + 1); - bit_code.MulGradWeight(num_classes, ids->data(), pre_out, - p_sliced, in_sliced); - bit_code.MulGradError(num_classes, ids->data(), pre_out, - p_sliced, in_grad_sliced); + for (int i = 0; i < batch_size; ++i) { + auto w_i = w->Slice(i, i + 1); + // auto in_i = in->Slice(i, i + 1); + // auto in_grad_i = in_grad->Slice(i, i + 1); + bit_code.MulGradWeight(pre_out, w_i, *in); + bit_code.MulGradError(pre_out, w_i, *in_grad); } } }; diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h index 98722ff5d2..63fb7182df 100644 --- a/paddle/operators/math/math_function_impl.h +++ b/paddle/operators/math/math_function_impl.h @@ -62,13 +62,13 @@ void ColwiseSum::operator()(const DeviceContext& context, template void RowwiseSum::operator()(const DeviceContext& context, const framework::Tensor& input, - framework::Tensor* vector) { + framework::Tensor* out) { auto in_dims = input.dims(); auto size = input.numel() / in_dims[1]; - PADDLE_ENFORCE_EQ(vector->numel(), size); + PADDLE_ENFORCE_EQ(out->numel(), size); - auto in = framework::EigenMatrix::From(input); - auto vec = framework::EigenMatrix::From(*vector); + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenVector::Flatten(*out); vec.device(*context.eigen_device()) = in.sum(Eigen::array({{1}})); } diff --git a/paddle/operators/math/matrix_bit_code.cc b/paddle/operators/math/matrix_bit_code.cc index b192183b10..34f5f6ef61 100644 --- a/paddle/operators/math/matrix_bit_code.cc +++ b/paddle/operators/math/matrix_bit_code.cc @@ -22,7 +22,7 @@ namespace math { * CodeTable class should support 3 functions: * * size_t size() - * return the number of codes + * return the number of ids * * int getMaxCodeLength() * return the maximal code length @@ -45,56 +45,47 @@ namespace math { * */ -/* - for i: - for j < codeLength: - op(a(i, j), b(0, index(i, j))) -*/ -template -static void AddByBitCodeT(Op op, CodeTable code_table, const int64_t* codes, - const framework::Tensor& tmat, - const framework::Tensor& vec) { - size_t num_sample = tmat.dims()[0]; - size_t width = vec.dims()[1]; - for (size_t i = 0; i < num_sample; ++i) { - auto code = code_table(static_cast(codes[i])); +template +void MatrixBitCodeFunctor::Add(framework::Tensor& tmat, + const framework::Tensor& vec) { + SimpleCodeTable code_table(num_classes_); + size_t batch_size = tmat.dims()[0]; + size_t width = tmat.dims()[1]; + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table(static_cast(ids_[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { 
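      // Editor's annotation (not in the original patch): j walks the path
      // from the root to the leaf encoding ids_[i]; calc_index(j) below is
      // the column of `vec` owned by the internal node at depth j, so this
      // loop is exactly tmat(i, j) += vec(0, index(i, j)) from the header
      // documentation.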
size_t index = code.calc_index(j); - auto t = tmat.data()[i * width + j]; - auto v = vec.data()[index]; - op(t, v); + tmat.data()[i * width + j] += vec.data()[index]; } } } -template -void SubByBitCodeT(CodeTable code_table, const int64_t* codes, - framework::Tensor& tmat) { - // size_t max_code_length = code_table.get_max_code_length(); - size_t num_samples = tmat.dims()[0]; - size_t o_width = tmat.dims()[1]; - for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(codes[i])); +template +void MatrixBitCodeFunctor::AddGrad(framework::Tensor& tmat, + framework::Tensor& vec) { + SimpleCodeTable code_table(num_classes_); + size_t batch_size = tmat.dims()[0]; + size_t width = tmat.dims()[1]; + for (size_t i = 0; i < batch_size; ++i) { + auto code = code_table(static_cast(ids_[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { - if (code.calc_bit(j)) { - tmat.data()[i * o_width + j] -= 1; - } + size_t index = code.calc_index(j); + vec.data()[index] += tmat.data()[i * width + j]; } } } -template -void SumByBitCodeT(CodeTable code_table, const int64_t* codes, - framework::Tensor& tmat, framework::Tensor& sum, - const T& scale_sum) { - // size_t max_code_length = code_table.get_max_code_length(); +template +void MatrixBitCodeFunctor::Sum(framework::Tensor& tmat, + framework::Tensor& sum, T scale_sum) { + SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t o_width = tmat.dims()[1]; for (size_t i = 0; i < num_samples; ++i) { T sm = static_cast(0.0); - auto code = code_table(static_cast(codes[i])); + auto code = code_table(static_cast(ids_[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { if (code.calc_bit(j)) { @@ -106,116 +97,99 @@ void SumByBitCodeT(CodeTable code_table, const int64_t* codes, } template -void MatrixBitCodeFunctor::Add(size_t num_classes, const int64_t* codes, - framework::Tensor& tmat, - const framework::Tensor& vec) { - auto op = [](T& t, const T& v) { t += v; }; - AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, vec); -} - -template -void MatrixBitCodeFunctor::AddGrad(size_t num_classes, const int64_t* codes, - framework::Tensor& tmat, - framework::Tensor& vec) { - auto op = [](T& t, T& v) { v += t; }; - AddByBitCodeT(op, SimpleCodeTable(num_classes), codes, tmat, vec); -} - -template -void MatrixBitCodeFunctor::Sum(size_t num_classes, const int64_t* codes, - framework::Tensor& tmat, - framework::Tensor& sum, T scale_sum) { - SumByBitCodeT(SimpleCodeTable(num_classes), codes, tmat, sum, scale_sum); -} - -template -void MatrixBitCodeFunctor::Mul(size_t num_classes, const int64_t* codes, - framework::Tensor& tmat, +void MatrixBitCodeFunctor::Mul(framework::Tensor& tmat, const framework::Tensor& weight, const framework::Tensor& input) { + SimpleCodeTable code_table(num_classes_); size_t num_samples = tmat.dims()[0]; size_t tmat_width = tmat.dims()[1]; size_t input_width = input.dims()[1]; - size_t weight_width = weight.dims()[1]; - auto tmat_p = tmat.data(); - auto weight_p = weight.data(); - auto input_p = input.data(); - auto code_table = SimpleCodeTable(num_classes); + size_t weight_width = weight.dims()[2]; + auto tmat_value = tmat.data(); + auto weight_value = weight.data(); + auto input_value = input.data(); for (size_t i = 0; i < num_samples; ++i) { - auto code = code_table(static_cast(codes[i])); + auto code = code_table(static_cast(ids_[i])); int code_length = code.get_length(); for (int j = 0; j < code_length; ++j) { size_t index = 
code.calc_index(j);

       T sum = static_cast<T>(0.0);
       for (size_t k = 0; k < input_width; ++k) {
-        sum +=
-            weight_p[weight_width * index + k] * input_p[input_width * i + k];
+        sum += weight_value[weight_width * index + k] *
+               input_value[input_width * i + k];
       }
-      tmat_p[i * tmat_width + j] += sum;
+      tmat_value[i * tmat_width + j] += sum;
     }
   }
 }

 template <typename T>
-void MatrixBitCodeFunctor<T>::MulGradWeight(size_t num_classes,
-                                            const int64_t* codes,
-                                            const framework::Tensor& tmat,
+void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                             framework::Tensor& weight,
                                             const framework::Tensor& input) {
+  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat.dims()[0];
+  size_t tmat_width = tmat.dims()[1];
   size_t input_width = input.dims()[1];
   size_t weight_width = weight.dims()[1];
-  auto tmat_p = tmat.data<T>();
-  auto weight_p = weight.data<T>();
-  auto input_p = input.data<T>();
-  auto code_table = SimpleCodeTable(num_classes);
+  auto tmat_value = tmat.data<T>();
+  auto weight_value = weight.data<T>();
+  auto input_value = input.data<T>();
   for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(codes[i]));
+    auto code = code_table(static_cast<size_t>(ids_[i]));
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);

       for (size_t k = 0; k < input_width; ++k) {
-        weight_p[weight_width * index * k] +=
-            tmat_p[i * weight_width * j] * input_p[input_width * i + k];
+        weight_value[weight_width * index + k] +=
+            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
       }
     }
   }
 }

 template <typename T>
-void MatrixBitCodeFunctor<T>::MulGradError(size_t num_classes,
-                                           const int64_t* codes,
-                                           const framework::Tensor& tmat,
+void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
                                            const framework::Tensor& weight,
                                            framework::Tensor& input) {
+  SimpleCodeTable code_table(num_classes_);
   size_t num_samples = tmat.dims()[0];
+  size_t tmat_width = tmat.dims()[1];
   size_t input_width = input.dims()[1];
   size_t weight_width = weight.dims()[1];
-  auto tmat_p = tmat.data<T>();
-  auto weight_p = weight.data<T>();
-  auto input_p = input.data<T>();
-  auto code_table = SimpleCodeTable(num_classes);
+  auto tmat_value = tmat.data<T>();
+  auto weight_value = weight.data<T>();
+  auto input_value = input.data<T>();
   for (size_t i = 0; i < num_samples; ++i) {
-    auto code = code_table(static_cast<size_t>(codes[i]));
+    auto code = code_table(static_cast<size_t>(ids_[i]));
     int code_length = code.get_length();
     for (int j = 0; j < code_length; ++j) {
       size_t index = code.calc_index(j);

       for (size_t k = 0; k < input_width; ++k) {
-        input_p[weight_width * index * k] +=
-            tmat_p[i * weight_width * j] * weight_p[weight_width * i + k];
+        input_value[input_width * i + k] +=
+            tmat_value[i * tmat_width + j] *
+            weight_value[weight_width * index + k];
       }
     }
   }
 }

 template <typename T>
-void MatrixBitCodeFunctor<T>::Sub(size_t num_classes, const int64_t* codes,
-                                  framework::Tensor& tmat) {
+void MatrixBitCodeFunctor<T>::Sub(framework::Tensor& tmat) {
+  SimpleCodeTable code_table(num_classes_);
+  size_t num_samples = tmat.dims()[0];
+  size_t o_width = tmat.dims()[1];
+  for (size_t i = 0; i < num_samples; ++i) {
+    auto code = code_table(static_cast<size_t>(ids_[i]));
+    int code_length = code.get_length();
+    for (int j = 0; j < code_length; ++j) {
+      if (code.calc_bit(j)) {
+        tmat.data<T>()[i * o_width + j] -= 1;
+      }
+    }
+  }
 }

 template class MatrixBitCodeFunctor<float>;
 template class MatrixBitCodeFunctor<double>;

diff --git a/paddle/operators/math/matrix_bit_code.h b/paddle/operators/math/matrix_bit_code.h
index d2ebf182c8..43c676f5cc 100644
--- a/paddle/operators/math/matrix_bit_code.h
+++ 
b/paddle/operators/math/matrix_bit_code.h @@ -63,46 +63,45 @@ struct SimpleCodeTable { template class MatrixBitCodeFunctor { public: + explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) + : num_classes_(num_classes), ids_(ids) {} /* For j < code_length tmat(i, j) += vec(0, index(i, j)) */ - void Add(size_t num_classes, const int64_t* codes, framework::Tensor& tmat, - const framework::Tensor& vec); + void Add(framework::Tensor& tmat, const framework::Tensor& vec); /* For j < code_length vec(0, index(i, j)) += tmat(i, j) */ - void AddGrad(size_t num_classes, const int64_t* codes, - framework::Tensor& tmat, framework::Tensor& vec); + void AddGrad(framework::Tensor& tmat, framework::Tensor& vec); /* For j < code_length sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) */ - void Sum(size_t num_classes, const int64_t* codes, framework::Tensor& tmat, - framework::Tensor& sum, T scale_sum); + void Sum(framework::Tensor& tmat, framework::Tensor& sum, T scale_sum); /* For j < code_length tmat(i, j) -= bit(i, j) */ - void Sub(size_t num_classes, const int64_t* codes, framework::Tensor& tmat); + void Sub(framework::Tensor& tmat); /* For j < code_length input.row(i) += tmat(i, j) * weight.row(index(i, j)) */ - void Mul(size_t num_classes, const int64_t* codes, framework::Tensor& tmat, - const framework::Tensor& weight, const framework::Tensor& input); + void Mul(framework::Tensor& tmat, const framework::Tensor& weight, + const framework::Tensor& input); /* For index(i, j) >= 0: weight.row(index(i, j)) += tmat(i, j) * input.row(i) */ - void MulGradWeight(size_t num_classes, const int64_t* codes, - const framework::Tensor& tmat, framework::Tensor& weight, + void MulGradWeight(const framework::Tensor& tmat, framework::Tensor& weight, const framework::Tensor& input); /* For j < code_length input.row(i) += tmat(i, j) * weight.row(index(i, j)) */ - void MulGradError(size_t num_classes, const int64_t* codes, - const framework::Tensor& tmat, + void MulGradError(const framework::Tensor& tmat, const framework::Tensor& weight, framework::Tensor& input); + size_t num_classes_; + const int64_t* ids_; }; } // namespace math } // namespace operators diff --git a/python/paddle/v2/fluid/tests/op_test.py b/python/paddle/v2/fluid/tests/op_test.py index acc42fd3b3..b77d2b1268 100644 --- a/python/paddle/v2/fluid/tests/op_test.py +++ b/python/paddle/v2/fluid/tests/op_test.py @@ -49,6 +49,7 @@ def create_op(scope, op_type, inputs, outputs, attrs): for attr_name in Operator.get_op_attr_names(op_type): if attr_name in attrs: kwargs[attr_name] = attrs[attr_name] + return Operator(op_type, **kwargs) @@ -104,8 +105,6 @@ def get_numeric_gradient(scope, tensor_to_check_dtype = np.float32 elif tensor_to_check_dtype == core.DataType.FP64: tensor_to_check_dtype = np.float64 - elif tensor_to_check_dtype == core.DataType.INT64: - tensor_to_check_dtype = np.int64 else: raise ValueError("Not supported data type " + str( tensor_to_check_dtype)) @@ -115,8 +114,6 @@ def get_numeric_gradient(scope, def __get_elem__(tensor, i): if tensor_to_check_dtype == np.float32: return tensor.get_float_element(i) - elif tensor_to_check_dtype == np.int64: - return tensor.get_int64_element(i) else: return tensor.get_double_element(i) @@ -356,11 +353,13 @@ class OpTest(unittest.TestCase): op_attrs = self.attrs if hasattr(self, "attrs") else dict() self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, op_attrs) + if no_grad_set is None: no_grad_set = set() if not type(output_names) is list: output_names = [output_names] + 
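# Editor's sketch, hedged and illustrative only: get_numeric_gradient above
# perturbs one element at a time through the get_/set_*_element bindings and
# takes a central difference. Conceptually (the helper name and the delta
# value are assumptions, not taken from the patch):
#
#   def numeric_grad_of_element(run_op, tensor, i, delta=0.005):
#       orig = tensor.get_float_element(i)
#       tensor.set_float_element(i, orig + delta)
#       y_pos = run_op()                   # scalar summary of the outputs
#       tensor.set_float_element(i, orig - delta)
#       y_neg = run_op()
#       tensor.set_float_element(i, orig)  # restore the perturbed element
#       return (y_pos - y_neg) / (2 * delta)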
numeric_grads = user_defined_grads or [
             get_numeric_gradient(
                 self.scope,
@@ -456,7 +455,9 @@ class OpTest(unittest.TestCase):
         # infer variable type and infer shape in compile-time
         op.desc.infer_var_type(block.desc)
         op.desc.infer_shape(block.desc)
+
         mean_inputs = map(block.var, output_names)
+
         if len(mean_inputs) == 1:
             loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1])
             op = block.append_op(

diff --git a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py
index b6d961b631..41e95e4363 100644
--- a/python/paddle/v2/fluid/tests/test_hsigmoid_op.py
+++ b/python/paddle/v2/fluid/tests/test_hsigmoid_op.py
@@ -1,6 +1,71 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import math
+
+
+def find_latest_set(num):
+    return 1 + int(math.floor(math.log(num, 2)))
+
+
+class CodeTable(object):
+    def __init__(self, num_classes, code):
+        self.c = num_classes + code
+
+    def cal_index(self, bit):
+        return (self.c >> (bit + 1)) - 1
+
+    def get_length(self):
+        return find_latest_set(self.c) - 1
+
+    def cal_bit(self, bit):
+        return self.c & (1 << bit)
+
+
+def hsigmoid(x, w, ids, bias, num_classes):
+    # code length = find_latest_set(num_classes - 1)
+    # initialize pre out with dims={batch_size, code_length}
+    batch_size = x.shape[0]
+    code_length = find_latest_set(num_classes - 1)
+    code_table = [0 for _ in range(code_length)]
+    pre_output = np.zeros((batch_size, code_length))
+    pre_sum = np.zeros((batch_size, 1))
+    out = np.zeros((batch_size, 1)).astype("float32")
+    # pre_out += code(bias)
+    for i in xrange(batch_size):
+        code_table = CodeTable(num_classes, ids[i])
+        length = code_table.get_length()
+        for j in xrange(length):
+            idx = code_table.cal_index(j)
+            pre_output[i][j] += bias[0][idx]
+    # pre_out += code(w) * x
+    for i in xrange(batch_size):
+        for j in xrange(batch_size):
+            code_table = CodeTable(num_classes, ids[j])
+            length = code_table.get_length()
+            for k in xrange(length):
+                idx = code_table.cal_index(k)
+                sum = 0.0
+                for l in xrange(x.shape[1]):
+                    sum += w[i][idx][l] * x[j][l]
+                pre_output[j][k] += sum
+    # clip to [-40.0, 40.0]; np.clip is not in-place, so assign the result
+    pre_output = np.clip(pre_output, -40.0, 40.0)
+    # out(i, 0) = \sum_j bit(i, j) * preout(i, j)
+    for i in xrange(batch_size):
+        code_table = CodeTable(num_classes, ids[i])
+        length = code_table.get_length()
+        sum = 0.0
+        for j in xrange(length):
+            if code_table.cal_bit(j):
+                sum += pre_output[i][j]
+        out[i] = -1.0 * sum
+    # soft relu
+    pre_output = np.clip(pre_output, -40.0, 40.0)
+    pre_output = np.log(1 + np.exp(pre_output))
+    pre_sum = pre_output.sum(1).reshape((batch_size, 1))
+    out += pre_sum
+    return out


 class TestHSigmoidOp(OpTest):
@@ -16,9 +81,8 @@ class TestHSigmoidOp(OpTest):
         bias = np.random.random((1, num_classes - 1)).astype("float32")
         self.inputs = {'X': x, 'W': w, 'Ids': ids, 'Bias': bias}
         self.attrs = {'num_classes': num_classes}
-        self.outputs = {
-            'Out': np.random.random((batch_size, 1)).astype("float32")
-        }
+        out = hsigmoid(x, w, ids, bias, num_classes)
+        self.outputs = {'Out': out}

     def test_check_output(self):
         self.check_output()
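
# Editor's addendum, a hedged worked example of the CodeTable above (values
# checked by hand against its arithmetic, not taken from the patch): for
# num_classes = 6 a label c is encoded through the integer c + num_classes,
# cal_index(j) is the internal node visited at depth j, and cal_bit(j) is
# the branch taken there, matching SimpleCodeTable on the C++ side:
#
#   ct = CodeTable(6, 2)                      # c = 8
#   ct.get_length()                           # -> 3
#   [ct.cal_index(j) for j in range(3)]       # -> [3, 1, 0]
#   [bool(ct.cal_bit(j)) for j in range(3)]   # -> [False, False, False]
#
#   ct = CodeTable(6, 5)                      # c = 11
#   [ct.cal_index(j) for j in range(3)]       # -> [4, 1, 0]
#   [bool(ct.cal_bit(j)) for j in range(3)]   # -> [True, True, False]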