Merge branch 'fix-optimizer-accumulator' of ssh://github.com/jacquesqiao/Paddle into distribute-transpiler-handle-adam-accumulator
	
		
	
				
					
				
			
						commit
						39d88ebc02
					
				| @ -0,0 +1,167 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/fluid/operators/hierarchical_sigmoid_op.h" | ||||
| #include <vector> | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| 
 | ||||
| /**
 | ||||
|  * Organize the classes into a binary tree. At each node, a sigmoid function | ||||
|  * is used to calculate the probability of belonging to the right branch. | ||||
|  * This idea is from "F. Morin, Y. Bengio (AISTATS 05): | ||||
|  * Hierarchical Probabilistic Neural Network Language Model." | ||||
|  * | ||||
|  * Here we uses a simple way of making the binary tree. | ||||
|  * Assuming the number of classes C = 6, | ||||
|  * The classes are organized as a binary tree in the following way: | ||||
|  * | ||||
|  * @code{.py} | ||||
|  * *-*-*- 2 | ||||
|  * | | |- 3 | ||||
|  * | | | ||||
|  * | |-*- 4 | ||||
|  * |   |- 5 | ||||
|  * | | ||||
|  * |-*- 0 | ||||
|  *   |- 1 | ||||
|  * @endcode | ||||
|  * | ||||
|  * where * indicates an internal node, and each leaf node represents a class. | ||||
|  * - Node 0 ... C-2 are internal nodes. | ||||
|  * - Node C-1 ... 2C-2 are leaf nodes. | ||||
|  * - Class c is represented by leaf node \f$c+C-1\f$. | ||||
|  * | ||||
|  * We assign an id for each node: | ||||
|  * - the id of root be 0. | ||||
|  * - the left child of a node i is 2*i+1. | ||||
|  * - the right child of a node i is 2*i+2. | ||||
|  * | ||||
|  * It's easy to see that: | ||||
|  * - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$. | ||||
|  * - the j-th level ancestor of node i is | ||||
|  * \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$. | ||||
|  * - A node i is a left child of its parent if \f$(i-1)\%2==0\f$. | ||||
|  * | ||||
|  */ | ||||
| 
 | ||||
| class HierarchicalSigmoidOp : public framework::OperatorWithKernel { | ||||
|  public: | ||||
|   using framework::OperatorWithKernel::OperatorWithKernel; | ||||
|   void InferShape(framework::InferShapeContext* ctx) const override { | ||||
|     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasOutput("PreOut"), | ||||
|                    "Output(PreOut) should not be null."); | ||||
|     const int64_t batch_size = ctx->GetInputDim("X")[0]; | ||||
|     std::vector<int64_t> output_shape({batch_size, 1}); | ||||
|     ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   framework::OpKernelType GetExpectedKernelType( | ||||
|       const framework::ExecutionContext& ctx) const override { | ||||
|     return framework::OpKernelType( | ||||
|         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), | ||||
|         ctx.GetPlace()); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| template <typename AttrType> | ||||
| class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { | ||||
|  public: | ||||
|   void Make() override { | ||||
|     AddInput("X", | ||||
|              "(Tensor, required) The input tensor with shape [N, D], " | ||||
|              "where N is the size of mini-batch, and D is the feature size."); | ||||
|     AddInput("W", | ||||
|              "(Tensor, required), The parameters of hierarchical " | ||||
|              "sigmoid operator, each of them is a 2-D tensor, the shape is" | ||||
|              "[num_classes - 1, D]."); | ||||
|     AddInput("Label", | ||||
|              "(Tensor, required), The labels of training data. It's a" | ||||
|              "tensor with shape [N, 1]."); | ||||
|     AddInput("Bias", | ||||
|              "(Tensor, optional), The bias is a tensor with shape" | ||||
|              "[1, num_classes - 1]."); | ||||
|     AddOutput("Out", | ||||
|               "(Tensor, required) The output of hierarchical sigmoid operator." | ||||
|               "The shape is [N, 1]."); | ||||
|     AddOutput("PreOut", | ||||
|               "(Tensor, required) A intermedia 2-D tensor with shape " | ||||
|               "[batch_size, code_length], where code_length represents the " | ||||
|               "maximum path length from root to leaf nodes.") | ||||
|         .AsIntermediate(); | ||||
|     AddAttr<AttrType>("num_classes", "(int, required), The number of classes") | ||||
|         .SetDefault(2); | ||||
|     AddComment(R"DOC( | ||||
| The hierarchical sigmoid operator organize the classes into a binary tree. | ||||
| At each node, a sigmoid function is used to calculate the probability of | ||||
| belonging to the right branch. This idea is from | ||||
| "F. Morin, Y. Bengio (AISTATS 05): | ||||
| Hierarchical Probabilistic Neural Network Language Model." | ||||
|       )DOC"); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { | ||||
|  public: | ||||
|   using framework::OperatorWithKernel::OperatorWithKernel; | ||||
|   void InferShape(framework::InferShapeContext* ctx) const override { | ||||
|     PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasInput("PreOut"), | ||||
|                    "Input(Preout) should not be null."); | ||||
|     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")), | ||||
|                    "Output(W@Grad should not be null.)"); | ||||
|     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X"))); | ||||
|     if (ctx->HasOutput(framework::GradVarName("Bias"))) { | ||||
|       ctx->SetOutputDim(framework::GradVarName("Bias"), | ||||
|                         ctx->GetInputDim("Bias")); | ||||
|     } | ||||
|     ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); | ||||
|     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); | ||||
|   } | ||||
| 
 | ||||
|  protected: | ||||
|   framework::OpKernelType GetExpectedKernelType( | ||||
|       const framework::ExecutionContext& ctx) const override { | ||||
|     return framework::OpKernelType( | ||||
|         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), | ||||
|         ctx.GetPlace()); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| 
 | ||||
| namespace ops = paddle::operators; | ||||
| REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp, | ||||
|                   ops::HierarchicalSigmoidOpMaker<int>, | ||||
|                   paddle::framework::DefaultGradOpDescMaker<true>); | ||||
| REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp); | ||||
| REGISTER_OP_CPU_KERNEL( | ||||
|     hierarchical_sigmoid, | ||||
|     ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>, | ||||
|     ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, | ||||
|                                      double>); | ||||
| REGISTER_OP_CPU_KERNEL( | ||||
|     hierarchical_sigmoid_grad, | ||||
|     ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext, | ||||
|                                          float>, | ||||
|     ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext, | ||||
|                                          double>); | ||||
| @ -0,0 +1,135 @@ | ||||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| #include <iostream> | ||||
| #include <vector> | ||||
| #include "paddle/fluid/framework/op_registry.h" | ||||
| #include "paddle/fluid/operators/clip_op.h" | ||||
| #include "paddle/fluid/operators/math/math_function.h" | ||||
| #include "paddle/fluid/operators/math/matrix_bit_code.h" | ||||
| #include "paddle/fluid/platform/transform.h" | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| 
 | ||||
| template <typename T, int MajorType = Eigen::RowMajor, | ||||
|           typename IndexType = Eigen::DenseIndex> | ||||
| using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; | ||||
| using platform::Transform; | ||||
| 
 | ||||
| template <typename DeviceContext, typename T> | ||||
| class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { | ||||
|  public: | ||||
|   void Compute(const framework::ExecutionContext& ctx) const override { | ||||
|     auto* in = ctx.Input<framework::Tensor>("X"); | ||||
|     auto* w = ctx.Input<framework::Tensor>("W"); | ||||
|     auto* label = ctx.Input<framework::Tensor>("Label"); | ||||
|     auto* bias = ctx.Input<framework::Tensor>("Bias"); | ||||
|     auto* out = ctx.Output<framework::Tensor>("Out"); | ||||
|     auto* pre_out = ctx.Output<framework::Tensor>("PreOut"); | ||||
|     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes")); | ||||
|     int64_t code_length = math::FindLastSet(num_classes - 1); | ||||
|     int64_t batch_size = in->dims()[0]; | ||||
|     framework::Tensor sum; | ||||
|     auto& dev_ctx = ctx.template device_context<DeviceContext>(); | ||||
|     auto* pre_out_data = pre_out->mutable_data<T>( | ||||
|         framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); | ||||
|     auto pre_out_mat = EigenMatrix<T>::From(*pre_out); | ||||
|     // Not all class(leaf) nodes' path lengths equal code_length, thus init as
 | ||||
|     // 0s can avoid out of path's loss.
 | ||||
|     math::SetConstant<DeviceContext, T> zero; | ||||
|     zero(dev_ctx, pre_out, static_cast<T>(0.0)); | ||||
|     auto& place = *ctx.template device_context<DeviceContext>().eigen_device(); | ||||
|     math::RowwiseSum<DeviceContext, T> row_sum; | ||||
|     math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>()); | ||||
| 
 | ||||
|     std::vector<int64_t> sum_dims({batch_size, 1UL}); | ||||
|     sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace()); | ||||
|     auto sum_mat = EigenMatrix<T>::From(sum); | ||||
|     out->mutable_data<T>(ctx.GetPlace()); | ||||
|     auto out_mat = framework::EigenVector<T>::Flatten(*out); | ||||
|     if (bias) { | ||||
|       bit_code.Add(pre_out, *bias); | ||||
|     } | ||||
|     bit_code.Mul(pre_out, *w, *in); | ||||
|     // clip to [-40, 40]
 | ||||
|     Transform<DeviceContext> trans; | ||||
|     trans(ctx.template device_context<DeviceContext>(), pre_out_data, | ||||
|           pre_out_data + pre_out->numel(), pre_out_data, | ||||
|           ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0))); | ||||
|     bit_code.Sum(*pre_out, out, static_cast<T>(-1)); | ||||
|     // use softrelu to calculate cross entropy
 | ||||
|     pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log(); | ||||
|     row_sum(dev_ctx, *pre_out, &sum); | ||||
|     // TODO(guosheng): Subtract the out of path's loss, since not all
 | ||||
|     // class(leaf) nodes' path lengths equal code_length. But it won't break the
 | ||||
|     // gradient check since both have the out of path's loss and will cancel out
 | ||||
|     // each other.
 | ||||
|     out_mat.device(place) = sum_mat + out_mat; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| template <typename DeviceContext, typename T> | ||||
| class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { | ||||
|  public: | ||||
|   void Compute(const framework::ExecutionContext& ctx) const override { | ||||
|     auto* in = ctx.Input<framework::Tensor>("X"); | ||||
|     auto* w = ctx.Input<framework::Tensor>("W"); | ||||
|     auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); | ||||
|     auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W")); | ||||
|     auto* bias_grad = | ||||
|         ctx.Output<framework::Tensor>(framework::GradVarName("Bias")); | ||||
|     auto* label = ctx.Input<framework::Tensor>("Label"); | ||||
|     auto* pre_out = ctx.Input<framework::Tensor>("PreOut"); | ||||
|     auto* out_grad = | ||||
|         ctx.Input<framework::Tensor>(framework::GradVarName("Out")); | ||||
|     framework::Tensor pre_out_grad; | ||||
| 
 | ||||
|     pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace()); | ||||
|     in_grad->mutable_data<T>(ctx.GetPlace()); | ||||
|     w_grad->mutable_data<T>(ctx.GetPlace()); | ||||
|     auto& dev_ctx = ctx.template device_context<DeviceContext>(); | ||||
|     math::SetConstant<DeviceContext, T> zero; | ||||
|     zero(dev_ctx, in_grad, static_cast<T>(0.0)); | ||||
|     zero(dev_ctx, w_grad, static_cast<T>(0.0)); | ||||
| 
 | ||||
|     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes")); | ||||
|     math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>()); | ||||
| 
 | ||||
|     auto& place = *ctx.template device_context<DeviceContext>().eigen_device(); | ||||
|     auto pre_out_mat = EigenMatrix<T>::From(*pre_out); | ||||
|     auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad); | ||||
|     auto out_grad_mat = EigenMatrix<T>::From(*out_grad); | ||||
|     Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}}); | ||||
| 
 | ||||
|     // softrelu derivative
 | ||||
|     pre_out_grad_mat.device(place) = | ||||
|         static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp(); | ||||
|     bit_code.Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
 | ||||
|     pre_out_grad_mat.device(place) = | ||||
|         pre_out_grad_mat * out_grad_mat.broadcast(bcast); | ||||
|     // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
 | ||||
|     // be consistent with the clipping in forward.
 | ||||
|     if (bias_grad) { | ||||
|       bias_grad->mutable_data<T>(ctx.GetPlace()); | ||||
|       zero(dev_ctx, bias_grad, static_cast<T>(0.0)); | ||||
|       bit_code.AddGrad(pre_out_grad, bias_grad); | ||||
|     } | ||||
|     bit_code.MulGradWeight(pre_out_grad, w_grad, *in); | ||||
|     bit_code.MulGradError(pre_out_grad, *w, in_grad); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,176 @@ | ||||
| /* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #include "paddle/fluid/operators/math/matrix_bit_code.h" | ||||
| #include <iostream> | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace math { | ||||
| 
 | ||||
| template <typename T> | ||||
| void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat, | ||||
|                                   const framework::Tensor& vec) { | ||||
|   SimpleCodeTable code_table(num_classes_); | ||||
|   size_t batch_size = tmat->dims()[0]; | ||||
|   size_t width = tmat->dims()[1]; | ||||
|   for (size_t i = 0; i < batch_size; ++i) { | ||||
|     auto code = code_table(static_cast<size_t>(ids_[i])); | ||||
|     int code_length = code.get_length(); | ||||
|     for (int j = 0; j < code_length; ++j) { | ||||
|       size_t index = code.calc_index(j); | ||||
|       tmat->data<T>()[i * width + j] += vec.data<T>()[index]; | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat, | ||||
|                                       framework::Tensor* vec) { | ||||
|   SimpleCodeTable code_table(num_classes_); | ||||
|   size_t batch_size = tmat.dims()[0]; | ||||
|   size_t width = tmat.dims()[1]; | ||||
|   for (size_t i = 0; i < batch_size; ++i) { | ||||
|     auto code = code_table(static_cast<size_t>(ids_[i])); | ||||
|     int code_length = code.get_length(); | ||||
|     for (int j = 0; j < code_length; ++j) { | ||||
|       size_t index = code.calc_index(j); | ||||
|       vec->data<T>()[index] += tmat.data<T>()[i * width + j]; | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat, | ||||
|                                   framework::Tensor* sum, T scale_sum) { | ||||
|   SimpleCodeTable code_table(num_classes_); | ||||
|   size_t num_samples = tmat.dims()[0]; | ||||
|   size_t o_width = tmat.dims()[1]; | ||||
|   for (size_t i = 0; i < num_samples; ++i) { | ||||
|     T sm = static_cast<T>(0.0); | ||||
|     auto code = code_table(static_cast<size_t>(ids_[i])); | ||||
|     int code_length = code.get_length(); | ||||
|     for (int j = 0; j < code_length; ++j) { | ||||
|       if (code.calc_bit(j)) { | ||||
|         // calc_bit starts from right most bit, while data in tmat[i] is in the
 | ||||
|         // reverse order.
 | ||||
|         sm += tmat.data<T>()[i * o_width + j]; | ||||
|       } | ||||
|     } | ||||
|     sum->data<T>()[i] = scale_sum * sm; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat, | ||||
|                                   const framework::Tensor& weight, | ||||
|                                   const framework::Tensor& input) { | ||||
|   SimpleCodeTable code_table(num_classes_); | ||||
|   size_t num_samples = tmat->dims()[0]; | ||||
|   size_t tmat_width = tmat->dims()[1]; | ||||
|   size_t input_width = input.dims()[1]; | ||||
|   size_t weight_width = weight.dims()[1]; | ||||
|   auto tmat_value = tmat->data<T>(); | ||||
|   auto weight_value = weight.data<T>(); | ||||
|   auto input_value = input.data<T>(); | ||||
|   for (size_t i = 0; i < num_samples; ++i) { | ||||
|     auto code = code_table(static_cast<size_t>(ids_[i])); | ||||
|     int code_length = code.get_length(); | ||||
|     for (int j = 0; j < code_length; ++j) { | ||||
|       size_t index = code.calc_index(j); | ||||
|       T sum = static_cast<T>(0.0); | ||||
|       for (size_t k = 0; k < input_width; ++k) { | ||||
|         sum += weight_value[weight_width * index + k] * | ||||
|                input_value[input_width * i + k]; | ||||
|       } | ||||
|       tmat_value[i * tmat_width + j] += sum; | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat, | ||||
|                                             framework::Tensor* weight, | ||||
|                                             const framework::Tensor& input) { | ||||
|   SimpleCodeTable code_table(num_classes_); | ||||
|   size_t num_samples = tmat.dims()[0]; | ||||
|   size_t input_width = input.dims()[1]; | ||||
|   size_t tmat_width = tmat.dims()[1]; | ||||
|   size_t weight_width = weight->dims()[1]; | ||||
|   auto tmat_value = tmat.data<T>(); | ||||
|   auto weight_value = weight->data<T>(); | ||||
|   auto input_value = input.data<T>(); | ||||
|   for (size_t i = 0; i < num_samples; ++i) { | ||||
|     auto code = code_table(static_cast<size_t>(ids_[i])); | ||||
|     int code_length = code.get_length(); | ||||
|     for (int j = 0; j < code_length; ++j) { | ||||
|       size_t index = code.calc_index(j); | ||||
| 
 | ||||
|       for (size_t k = 0; k < input_width; ++k) { | ||||
|         weight_value[weight_width * index + k] += | ||||
|             tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat, | ||||
|                                            const framework::Tensor& weight, | ||||
|                                            framework::Tensor* input) { | ||||
|   SimpleCodeTable code_table(num_classes_); | ||||
|   size_t num_samples = tmat.dims()[0]; | ||||
|   size_t tmat_width = tmat.dims()[1]; | ||||
|   size_t input_width = input->dims()[1]; | ||||
|   size_t weight_width = weight.dims()[1]; | ||||
|   auto tmat_value = tmat.data<T>(); | ||||
|   auto weight_value = weight.data<T>(); | ||||
|   auto input_value = input->data<T>(); | ||||
| 
 | ||||
|   for (size_t i = 0; i < num_samples; ++i) { | ||||
|     auto code = code_table(static_cast<size_t>(ids_[i])); | ||||
|     int code_length = code.get_length(); | ||||
|     for (int j = 0; j < code_length; ++j) { | ||||
|       size_t index = code.calc_index(j); | ||||
| 
 | ||||
|       for (size_t k = 0; k < input_width; ++k) { | ||||
|         input_value[input_width * i + k] += | ||||
|             tmat_value[i * tmat_width + j] * | ||||
|             weight_value[weight_width * index + k]; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) { | ||||
|   SimpleCodeTable code_table(num_classes_); | ||||
|   size_t num_samples = tmat->dims()[0]; | ||||
|   size_t o_width = tmat->dims()[1]; | ||||
|   for (size_t i = 0; i < num_samples; ++i) { | ||||
|     auto code = code_table(static_cast<size_t>(ids_[i])); | ||||
|     int code_length = code.get_length(); | ||||
|     for (int j = 0; j < code_length; ++j) { | ||||
|       if (code.calc_bit(j)) { | ||||
|         tmat->data<T>()[i * o_width + j] -= 1; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template class MatrixBitCodeFunctor<float>; | ||||
| template class MatrixBitCodeFunctor<double>; | ||||
| 
 | ||||
| }  // namespace math
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
| @ -0,0 +1,143 @@ | ||||
| /* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. */ | ||||
| 
 | ||||
| #pragma once | ||||
| #include "paddle/fluid/framework/eigen.h" | ||||
| #include "paddle/fluid/framework/tensor.h" | ||||
| #include "paddle/fluid/platform/device_context.h" | ||||
| 
 | ||||
| namespace paddle { | ||||
| namespace operators { | ||||
| namespace math { | ||||
| /**
 | ||||
|  * SimpleCodeTable class should support 3 functions: | ||||
|  * | ||||
|  * size_t size() | ||||
|  *   return the number of ids | ||||
|  * | ||||
|  * int get_max_code_length() | ||||
|  *   return the maximal code length | ||||
|  * | ||||
|  * SimpleCode operator()(size_t i) | ||||
|  *   return the i-th code. Code class is descriebed below. | ||||
|  * | ||||
|  * SimpleCode class should support 3 functions: | ||||
|  * | ||||
|  * int get_length() | ||||
|  *   return the length of the code | ||||
|  * | ||||
|  * size_t cal_index(int bit) | ||||
|  *   bit ranges from 0 to get_length() - 1 | ||||
|  *   return the index for the (1+bit) level parent | ||||
|  * | ||||
|  * bool calc_bit(int bit) | ||||
|  *   return true if the bit level parent is the right child of (1+bit) level | ||||
|  *   parent | ||||
|  * | ||||
|  */ | ||||
| 
 | ||||
| /**
 | ||||
|  * return the 1-based index of the highest bit set | ||||
|  * | ||||
|  * for x > 0: | ||||
|  * \f[ | ||||
|  *    FindLastSet(x) = 1 + \floor*{\log_{2}x} | ||||
|  * \f] | ||||
|  */ | ||||
| inline constexpr size_t FindLastSet(size_t x) { | ||||
|   return std::is_same<size_t, unsigned int>::value | ||||
|              ? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0) | ||||
|              : (std::is_same<size_t, unsigned long>::value  // NOLINT
 | ||||
|                     ? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0) | ||||
|                     : (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0)); | ||||
| } | ||||
| 
 | ||||
| struct SimpleCode { | ||||
|   SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {} | ||||
|   /**
 | ||||
|    * Here the id of root shoud be 1 rather than 0, thus the encoding of class c | ||||
|    * is `c + num_classes` and all siblings can get the same weight indice using | ||||
|    * prefixes. | ||||
|    * Weight index is the prefixes of encoding, thus leave out the right most | ||||
|    * bit in calc_index. | ||||
|    * Binary classification path is the suffixes of encoding, thus leave out the | ||||
|    * left most bit in calc_bit. | ||||
|    */ | ||||
|   inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; } | ||||
|   inline bool calc_bit(int bit) const { return c_ & (1 << bit); } | ||||
|   inline int get_length() const { return FindLastSet(c_) - 1; } | ||||
| 
 | ||||
|  private: | ||||
|   size_t c_; | ||||
| }; | ||||
| 
 | ||||
| struct SimpleCodeTable { | ||||
|   explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {} | ||||
|   SimpleCode operator()(size_t code) const { | ||||
|     return SimpleCode(code, num_classes_); | ||||
|   } | ||||
|   size_t size() const { return num_classes_; } | ||||
|   int get_max_code_length() const { return FindLastSet(num_classes_ - 1); } | ||||
| 
 | ||||
|  private: | ||||
|   size_t num_classes_; | ||||
| }; | ||||
| 
 | ||||
| template <typename T> | ||||
| class MatrixBitCodeFunctor { | ||||
|  public: | ||||
|   explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids) | ||||
|       : num_classes_(num_classes), ids_(ids) {} | ||||
|   /* For j < code_length
 | ||||
|        tmat(i, j) += vec(0, index(i, j)) | ||||
|   */ | ||||
|   void Add(framework::Tensor* tmat, const framework::Tensor& vec); | ||||
| 
 | ||||
|   /* For j < code_length
 | ||||
|        vec(0, index(i, j)) += tmat(i, j) | ||||
|   */ | ||||
|   void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec); | ||||
| 
 | ||||
|   /* For j < code_length
 | ||||
|     sum(i, 0) = \sum_j bit(i, j) * tmat(i, j) | ||||
|   */ | ||||
|   void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum); | ||||
| 
 | ||||
|   /* For j < code_length
 | ||||
|        tmat(i, j) -= bit(i, j) | ||||
|   */ | ||||
|   void Sub(framework::Tensor* tmat); | ||||
|   /* For j < code_length
 | ||||
|        input.row(i) += tmat(i, j) * weight.row(index(i, j)) | ||||
|   */ | ||||
|   void Mul(framework::Tensor* tmat, const framework::Tensor& weight, | ||||
|            const framework::Tensor& input); | ||||
| 
 | ||||
|   /* For index(i, j) >= 0:
 | ||||
|       weight.row(index(i, j)) += tmat(i, j) * input.row(i) | ||||
|   */ | ||||
|   void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, | ||||
|                      const framework::Tensor& input); | ||||
|   /* For j < code_length
 | ||||
|     input.row(i) += tmat(i, j) * weight.row(index(i, j)) | ||||
|   */ | ||||
|   void MulGradError(const framework::Tensor& tmat, | ||||
|                     const framework::Tensor& weight, framework::Tensor* input); | ||||
| 
 | ||||
|   size_t num_classes_; | ||||
|   const int64_t* ids_; | ||||
| }; | ||||
| }  // namespace math
 | ||||
| }  // namespace operators
 | ||||
| }  // namespace paddle
 | ||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								| @ -1,75 +0,0 @@ | ||||
| #   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import paddle.fluid as fluid | ||||
| import unittest | ||||
| import os | ||||
| import tempfile | ||||
| 
 | ||||
| 
 | ||||
| class TestCheckpoint(unittest.TestCase): | ||||
|     def setUp(self): | ||||
|         self.dirname = tempfile.mktemp() | ||||
|         self.max_num_checkpoints = 3 | ||||
|         self.epoch_interval = 1 | ||||
|         self.step_interval = 1 | ||||
|         self.trainer_id = 0 | ||||
|         self.chief = self.trainer_id == 0 | ||||
|         self.place = fluid.CPUPlace() | ||||
|         self.epoch_id = 100 | ||||
|         self.step_id = 20 | ||||
| 
 | ||||
|     def test_checkpoint(self): | ||||
|         self.save_checkpoint() | ||||
|         serial = fluid.io.get_latest_checkpoint_serial(self.dirname) | ||||
|         self.assertTrue(serial >= 0) | ||||
|         trainer_args = ["epoch_id", "step_id"] | ||||
|         epoch_id, step_id = fluid.io.load_trainer_args( | ||||
|             self.dirname, serial, self.trainer_id, trainer_args) | ||||
|         self.assertEqual(self.step_id, int(step_id)) | ||||
|         self.assertEqual(self.epoch_id, int(epoch_id)) | ||||
| 
 | ||||
|         program = fluid.Program() | ||||
|         with fluid.program_guard(program): | ||||
|             exe = fluid.Executor(self.place) | ||||
|             fluid.io.load_checkpoint(exe, self.dirname, serial, program) | ||||
| 
 | ||||
|         fluid.io.clean_checkpoint(self.dirname, delete_dir=True) | ||||
|         self.assertFalse(os.path.isdir(self.dirname)) | ||||
| 
 | ||||
|     def save_checkpoint(self): | ||||
|         config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints, | ||||
|                                         self.epoch_interval, self.step_interval) | ||||
| 
 | ||||
|         trainer_args = {} | ||||
|         trainer_args["epoch_id"] = self.epoch_id | ||||
|         trainer_args["step_id"] = self.step_id | ||||
| 
 | ||||
|         program = fluid.Program() | ||||
|         with fluid.program_guard(program): | ||||
|             program.global_block().create_var( | ||||
|                 name="scale_0", | ||||
|                 psersistable=True, | ||||
|                 dtype="float32", | ||||
|                 shape=[32, 32]) | ||||
| 
 | ||||
|             exe = fluid.Executor(self.place) | ||||
|             for i in xrange(10): | ||||
|                 fluid.io.save_checkpoint(exe, config.checkpoint_dir, | ||||
|                                          self.trainer_id, trainer_args, program, | ||||
|                                          config.max_num_checkpoints) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
| @ -0,0 +1,99 @@ | ||||
| #   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import unittest | ||||
| import numpy as np | ||||
| import math | ||||
| from op_test import OpTest | ||||
| 
 | ||||
| 
 | ||||
| def find_latest_set(num): | ||||
|     return 1 + int(math.floor(math.log(num, 2))) | ||||
| 
 | ||||
| 
 | ||||
| class CodeTable(object): | ||||
|     def __init__(self, num_classes, code): | ||||
|         self.c = num_classes + code | ||||
| 
 | ||||
|     def cal_index(self, bit): | ||||
|         return (self.c >> (bit + 1)) - 1 | ||||
| 
 | ||||
|     def get_length(self): | ||||
|         return find_latest_set(self.c) - 1 | ||||
| 
 | ||||
|     def cal_bit(self, bit): | ||||
|         return self.c & (1 << bit) | ||||
| 
 | ||||
| 
 | ||||
| def hsigmoid(x, w, label, bias, num_classes): | ||||
|     batch_size = x.shape[0] | ||||
|     code_length = find_latest_set(num_classes - 1) | ||||
|     code_table = [0 for _ in range(code_length)] | ||||
|     pre_output = np.zeros((batch_size, code_length)) | ||||
|     pre_sum = np.zeros((batch_size, 1)) | ||||
|     out = np.zeros((batch_size, 1)).astype("float32") | ||||
|     for i in range(batch_size): | ||||
|         code_table = CodeTable(num_classes, label[i]) | ||||
|         length = code_table.get_length() | ||||
|         for j in range(length): | ||||
|             idx = code_table.cal_index(j) | ||||
|             pre_output[i][j] += bias[0][idx] | ||||
|     for i in range(batch_size): | ||||
|         code_table = CodeTable(num_classes, label[i]) | ||||
|         length = code_table.get_length() | ||||
|         for j in range(length): | ||||
|             idx = code_table.cal_index(j) | ||||
|             pre_output[i][j] += np.dot(w[idx], x[i]) | ||||
|     # clip[-40.0, 40.0] | ||||
|     pre_output = np.clip(pre_output, -40.0, 40.0) | ||||
|     # out(i, 0) = \sum_j  bit(i, j) * preout(i, j) | ||||
|     for i in range(batch_size): | ||||
|         code_table = CodeTable(num_classes, label[i]) | ||||
|         length = code_table.get_length() | ||||
|         sum = 0.0 | ||||
|         for j in range(length): | ||||
|             if code_table.cal_bit(j): | ||||
|                 sum += pre_output[i][j] | ||||
|         out[i] = -1.0 * sum | ||||
|     # soft relu | ||||
|     pre_output = np.log(1 + np.exp(pre_output)) | ||||
|     pre_sum = pre_output.sum(1).reshape((batch_size, 1)) | ||||
|     out += pre_sum | ||||
|     return pre_output, out | ||||
| 
 | ||||
| 
 | ||||
| class TestHSigmoidOp(OpTest): | ||||
|     def setUp(self): | ||||
|         self.op_type = "hierarchical_sigmoid" | ||||
|         num_classes = 6 | ||||
|         feature_size = 8 | ||||
|         batch_size = 4 | ||||
|         x = np.random.random((batch_size, feature_size)).astype("float32") | ||||
|         w = np.random.random((num_classes - 1, feature_size)).astype("float32") | ||||
|         label = np.random.randint(0, num_classes, (batch_size, 1)) | ||||
|         bias = np.random.random((1, num_classes - 1)).astype("float32") | ||||
|         self.attrs = {'num_classes': num_classes} | ||||
|         self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias} | ||||
|         pre_output, out = hsigmoid(x, w, label, bias, num_classes) | ||||
|         self.outputs = {'PreOut': pre_output, 'Out': out} | ||||
| 
 | ||||
|     def test_check_output(self): | ||||
|         self.check_output() | ||||
| 
 | ||||
|     def test_check_grad(self): | ||||
|         self.check_grad(['Bias', 'X', 'W'], ['Out'], no_grad_set=set('Label')) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
											
												
													File diff suppressed because it is too large
													Load Diff
												
											
										
									
								
					Loading…
					
					
				
		Reference in new issue