Merge branch 'fix-optimizer-accumulator' of ssh://github.com/jacquesqiao/Paddle into distribute-transpiler-handle-adam-accumulator
commit
39d88ebc02
@ -0,0 +1,167 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Organize the classes into a binary tree. At each node, a sigmoid function
|
||||||
|
* is used to calculate the probability of belonging to the right branch.
|
||||||
|
* This idea is from "F. Morin, Y. Bengio (AISTATS 05):
|
||||||
|
* Hierarchical Probabilistic Neural Network Language Model."
|
||||||
|
*
|
||||||
|
* Here we use a simple way of making the binary tree.
|
||||||
|
* Assuming the number of classes C = 6,
|
||||||
|
* The classes are organized as a binary tree in the following way:
|
||||||
|
*
|
||||||
|
* @code{.py}
|
||||||
|
* *-*-*- 2
|
||||||
|
* | | |- 3
|
||||||
|
* | |
|
||||||
|
* | |-*- 4
|
||||||
|
* | |- 5
|
||||||
|
* |
|
||||||
|
* |-*- 0
|
||||||
|
* |- 1
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* where * indicates an internal node, and each leaf node represents a class.
|
||||||
|
* - Node 0 ... C-2 are internal nodes.
|
||||||
|
* - Node C-1 ... 2C-2 are leaf nodes.
|
||||||
|
* - Class c is represented by leaf node \f$c+C-1\f$.
|
||||||
|
*
|
||||||
|
* We assign an id for each node:
|
||||||
|
* - the id of root be 0.
|
||||||
|
* - the left child of a node i is 2*i+1.
|
||||||
|
* - the right child of a node i is 2*i+2.
|
||||||
|
*
|
||||||
|
* It's easy to see that:
|
||||||
|
* - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
|
||||||
|
* - the j-th level ancestor of node i is
|
||||||
|
* \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
|
||||||
|
* - A node i is a left child of its parent if \f$(i-1)\%2==0\f$.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
|
||||||
|
"Output(PreOut) should not be null.");
|
||||||
|
const int64_t batch_size = ctx->GetInputDim("X")[0];
|
||||||
|
std::vector<int64_t> output_shape({batch_size, 1});
|
||||||
|
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const override {
|
||||||
|
return framework::OpKernelType(
|
||||||
|
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
|
||||||
|
ctx.GetPlace());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename AttrType>
|
||||||
|
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
void Make() override {
|
||||||
|
AddInput("X",
|
||||||
|
"(Tensor, required) The input tensor with shape [N, D], "
|
||||||
|
"where N is the size of mini-batch, and D is the feature size.");
|
||||||
|
AddInput("W",
|
||||||
|
"(Tensor, required), The parameters of hierarchical "
|
||||||
|
"sigmoid operator, each of them is a 2-D tensor, the shape is"
|
||||||
|
"[num_classes - 1, D].");
|
||||||
|
AddInput("Label",
|
||||||
|
"(Tensor, required), The labels of training data. It's a"
|
||||||
|
"tensor with shape [N, 1].");
|
||||||
|
AddInput("Bias",
|
||||||
|
"(Tensor, optional), The bias is a tensor with shape"
|
||||||
|
"[1, num_classes - 1].");
|
||||||
|
AddOutput("Out",
|
||||||
|
"(Tensor, required) The output of hierarchical sigmoid operator."
|
||||||
|
"The shape is [N, 1].");
|
||||||
|
AddOutput("PreOut",
|
||||||
|
"(Tensor, required) A intermedia 2-D tensor with shape "
|
||||||
|
"[batch_size, code_length], where code_length represents the "
|
||||||
|
"maximum path length from root to leaf nodes.")
|
||||||
|
.AsIntermediate();
|
||||||
|
AddAttr<AttrType>("num_classes", "(int, required), The number of classes")
|
||||||
|
.SetDefault(2);
|
||||||
|
AddComment(R"DOC(
|
||||||
|
The hierarchical sigmoid operator organize the classes into a binary tree.
|
||||||
|
At each node, a sigmoid function is used to calculate the probability of
|
||||||
|
belonging to the right branch. This idea is from
|
||||||
|
"F. Morin, Y. Bengio (AISTATS 05):
|
||||||
|
Hierarchical Probabilistic Neural Network Language Model."
|
||||||
|
)DOC");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("PreOut"),
|
||||||
|
"Input(Preout) should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
|
||||||
|
"Output(W@Grad should not be null.)");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")));
|
||||||
|
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
|
||||||
|
ctx->SetOutputDim(framework::GradVarName("Bias"),
|
||||||
|
ctx->GetInputDim("Bias"));
|
||||||
|
}
|
||||||
|
ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
|
||||||
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const override {
|
||||||
|
return framework::OpKernelType(
|
||||||
|
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
|
||||||
|
ctx.GetPlace());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
|
||||||
|
ops::HierarchicalSigmoidOpMaker<int>,
|
||||||
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
||||||
|
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
hierarchical_sigmoid,
|
||||||
|
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
double>);
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
hierarchical_sigmoid_grad,
|
||||||
|
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
float>,
|
||||||
|
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
|
||||||
|
double>);
|
@ -0,0 +1,135 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
#include "paddle/fluid/operators/clip_op.h"
|
||||||
|
#include "paddle/fluid/operators/math/math_function.h"
|
||||||
|
#include "paddle/fluid/operators/math/matrix_bit_code.h"
|
||||||
|
#include "paddle/fluid/platform/transform.h"
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
template <typename T, int MajorType = Eigen::RowMajor,
|
||||||
|
typename IndexType = Eigen::DenseIndex>
|
||||||
|
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
||||||
|
using platform::Transform;
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* in = ctx.Input<framework::Tensor>("X");
|
||||||
|
auto* w = ctx.Input<framework::Tensor>("W");
|
||||||
|
auto* label = ctx.Input<framework::Tensor>("Label");
|
||||||
|
auto* bias = ctx.Input<framework::Tensor>("Bias");
|
||||||
|
auto* out = ctx.Output<framework::Tensor>("Out");
|
||||||
|
auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
|
||||||
|
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
|
||||||
|
int64_t code_length = math::FindLastSet(num_classes - 1);
|
||||||
|
int64_t batch_size = in->dims()[0];
|
||||||
|
framework::Tensor sum;
|
||||||
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||||
|
auto* pre_out_data = pre_out->mutable_data<T>(
|
||||||
|
framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
|
||||||
|
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
|
||||||
|
// Not all class(leaf) nodes' path lengths equal code_length, thus init as
|
||||||
|
// 0s can avoid out of path's loss.
|
||||||
|
math::SetConstant<DeviceContext, T> zero;
|
||||||
|
zero(dev_ctx, pre_out, static_cast<T>(0.0));
|
||||||
|
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
|
||||||
|
math::RowwiseSum<DeviceContext, T> row_sum;
|
||||||
|
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
|
||||||
|
|
||||||
|
std::vector<int64_t> sum_dims({batch_size, 1UL});
|
||||||
|
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
|
||||||
|
auto sum_mat = EigenMatrix<T>::From(sum);
|
||||||
|
out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
auto out_mat = framework::EigenVector<T>::Flatten(*out);
|
||||||
|
if (bias) {
|
||||||
|
bit_code.Add(pre_out, *bias);
|
||||||
|
}
|
||||||
|
bit_code.Mul(pre_out, *w, *in);
|
||||||
|
// clip to [-40, 40]
|
||||||
|
Transform<DeviceContext> trans;
|
||||||
|
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
|
||||||
|
pre_out_data + pre_out->numel(), pre_out_data,
|
||||||
|
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
|
||||||
|
bit_code.Sum(*pre_out, out, static_cast<T>(-1));
|
||||||
|
// use softrelu to calculate cross entropy
|
||||||
|
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
|
||||||
|
row_sum(dev_ctx, *pre_out, &sum);
|
||||||
|
// TODO(guosheng): Subtract the out of path's loss, since not all
|
||||||
|
// class(leaf) nodes' path lengths equal code_length. But it won't break the
|
||||||
|
// gradient check since both have the out of path's loss and will cancel out
|
||||||
|
// each other.
|
||||||
|
out_mat.device(place) = sum_mat + out_mat;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* in = ctx.Input<framework::Tensor>("X");
|
||||||
|
auto* w = ctx.Input<framework::Tensor>("W");
|
||||||
|
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
||||||
|
auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
|
||||||
|
auto* bias_grad =
|
||||||
|
ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
|
||||||
|
auto* label = ctx.Input<framework::Tensor>("Label");
|
||||||
|
auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
|
||||||
|
auto* out_grad =
|
||||||
|
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
||||||
|
framework::Tensor pre_out_grad;
|
||||||
|
|
||||||
|
pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
|
||||||
|
in_grad->mutable_data<T>(ctx.GetPlace());
|
||||||
|
w_grad->mutable_data<T>(ctx.GetPlace());
|
||||||
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||||
|
math::SetConstant<DeviceContext, T> zero;
|
||||||
|
zero(dev_ctx, in_grad, static_cast<T>(0.0));
|
||||||
|
zero(dev_ctx, w_grad, static_cast<T>(0.0));
|
||||||
|
|
||||||
|
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
|
||||||
|
math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
|
||||||
|
|
||||||
|
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
|
||||||
|
auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
|
||||||
|
auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
|
||||||
|
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
|
||||||
|
Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
|
||||||
|
|
||||||
|
// softrelu derivative
|
||||||
|
pre_out_grad_mat.device(place) =
|
||||||
|
static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
|
||||||
|
bit_code.Sub(&pre_out_grad); // the gradient of clip(w * x + b)
|
||||||
|
pre_out_grad_mat.device(place) =
|
||||||
|
pre_out_grad_mat * out_grad_mat.broadcast(bcast);
|
||||||
|
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
|
||||||
|
// be consistent with the clipping in forward.
|
||||||
|
if (bias_grad) {
|
||||||
|
bias_grad->mutable_data<T>(ctx.GetPlace());
|
||||||
|
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
|
||||||
|
bit_code.AddGrad(pre_out_grad, bias_grad);
|
||||||
|
}
|
||||||
|
bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
|
||||||
|
bit_code.MulGradError(pre_out_grad, *w, in_grad);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,176 @@
|
|||||||
|
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/math/matrix_bit_code.h"
|
||||||
|
#include <iostream>
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace math {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
|
||||||
|
const framework::Tensor& vec) {
|
||||||
|
SimpleCodeTable code_table(num_classes_);
|
||||||
|
size_t batch_size = tmat->dims()[0];
|
||||||
|
size_t width = tmat->dims()[1];
|
||||||
|
for (size_t i = 0; i < batch_size; ++i) {
|
||||||
|
auto code = code_table(static_cast<size_t>(ids_[i]));
|
||||||
|
int code_length = code.get_length();
|
||||||
|
for (int j = 0; j < code_length; ++j) {
|
||||||
|
size_t index = code.calc_index(j);
|
||||||
|
tmat->data<T>()[i * width + j] += vec.data<T>()[index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
|
||||||
|
framework::Tensor* vec) {
|
||||||
|
SimpleCodeTable code_table(num_classes_);
|
||||||
|
size_t batch_size = tmat.dims()[0];
|
||||||
|
size_t width = tmat.dims()[1];
|
||||||
|
for (size_t i = 0; i < batch_size; ++i) {
|
||||||
|
auto code = code_table(static_cast<size_t>(ids_[i]));
|
||||||
|
int code_length = code.get_length();
|
||||||
|
for (int j = 0; j < code_length; ++j) {
|
||||||
|
size_t index = code.calc_index(j);
|
||||||
|
vec->data<T>()[index] += tmat.data<T>()[i * width + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
|
||||||
|
framework::Tensor* sum, T scale_sum) {
|
||||||
|
SimpleCodeTable code_table(num_classes_);
|
||||||
|
size_t num_samples = tmat.dims()[0];
|
||||||
|
size_t o_width = tmat.dims()[1];
|
||||||
|
for (size_t i = 0; i < num_samples; ++i) {
|
||||||
|
T sm = static_cast<T>(0.0);
|
||||||
|
auto code = code_table(static_cast<size_t>(ids_[i]));
|
||||||
|
int code_length = code.get_length();
|
||||||
|
for (int j = 0; j < code_length; ++j) {
|
||||||
|
if (code.calc_bit(j)) {
|
||||||
|
// calc_bit starts from right most bit, while data in tmat[i] is in the
|
||||||
|
// reverse order.
|
||||||
|
sm += tmat.data<T>()[i * o_width + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sum->data<T>()[i] = scale_sum * sm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
|
||||||
|
const framework::Tensor& weight,
|
||||||
|
const framework::Tensor& input) {
|
||||||
|
SimpleCodeTable code_table(num_classes_);
|
||||||
|
size_t num_samples = tmat->dims()[0];
|
||||||
|
size_t tmat_width = tmat->dims()[1];
|
||||||
|
size_t input_width = input.dims()[1];
|
||||||
|
size_t weight_width = weight.dims()[1];
|
||||||
|
auto tmat_value = tmat->data<T>();
|
||||||
|
auto weight_value = weight.data<T>();
|
||||||
|
auto input_value = input.data<T>();
|
||||||
|
for (size_t i = 0; i < num_samples; ++i) {
|
||||||
|
auto code = code_table(static_cast<size_t>(ids_[i]));
|
||||||
|
int code_length = code.get_length();
|
||||||
|
for (int j = 0; j < code_length; ++j) {
|
||||||
|
size_t index = code.calc_index(j);
|
||||||
|
T sum = static_cast<T>(0.0);
|
||||||
|
for (size_t k = 0; k < input_width; ++k) {
|
||||||
|
sum += weight_value[weight_width * index + k] *
|
||||||
|
input_value[input_width * i + k];
|
||||||
|
}
|
||||||
|
tmat_value[i * tmat_width + j] += sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
|
||||||
|
framework::Tensor* weight,
|
||||||
|
const framework::Tensor& input) {
|
||||||
|
SimpleCodeTable code_table(num_classes_);
|
||||||
|
size_t num_samples = tmat.dims()[0];
|
||||||
|
size_t input_width = input.dims()[1];
|
||||||
|
size_t tmat_width = tmat.dims()[1];
|
||||||
|
size_t weight_width = weight->dims()[1];
|
||||||
|
auto tmat_value = tmat.data<T>();
|
||||||
|
auto weight_value = weight->data<T>();
|
||||||
|
auto input_value = input.data<T>();
|
||||||
|
for (size_t i = 0; i < num_samples; ++i) {
|
||||||
|
auto code = code_table(static_cast<size_t>(ids_[i]));
|
||||||
|
int code_length = code.get_length();
|
||||||
|
for (int j = 0; j < code_length; ++j) {
|
||||||
|
size_t index = code.calc_index(j);
|
||||||
|
|
||||||
|
for (size_t k = 0; k < input_width; ++k) {
|
||||||
|
weight_value[weight_width * index + k] +=
|
||||||
|
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
|
||||||
|
const framework::Tensor& weight,
|
||||||
|
framework::Tensor* input) {
|
||||||
|
SimpleCodeTable code_table(num_classes_);
|
||||||
|
size_t num_samples = tmat.dims()[0];
|
||||||
|
size_t tmat_width = tmat.dims()[1];
|
||||||
|
size_t input_width = input->dims()[1];
|
||||||
|
size_t weight_width = weight.dims()[1];
|
||||||
|
auto tmat_value = tmat.data<T>();
|
||||||
|
auto weight_value = weight.data<T>();
|
||||||
|
auto input_value = input->data<T>();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_samples; ++i) {
|
||||||
|
auto code = code_table(static_cast<size_t>(ids_[i]));
|
||||||
|
int code_length = code.get_length();
|
||||||
|
for (int j = 0; j < code_length; ++j) {
|
||||||
|
size_t index = code.calc_index(j);
|
||||||
|
|
||||||
|
for (size_t k = 0; k < input_width; ++k) {
|
||||||
|
input_value[input_width * i + k] +=
|
||||||
|
tmat_value[i * tmat_width + j] *
|
||||||
|
weight_value[weight_width * index + k];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
|
||||||
|
SimpleCodeTable code_table(num_classes_);
|
||||||
|
size_t num_samples = tmat->dims()[0];
|
||||||
|
size_t o_width = tmat->dims()[1];
|
||||||
|
for (size_t i = 0; i < num_samples; ++i) {
|
||||||
|
auto code = code_table(static_cast<size_t>(ids_[i]));
|
||||||
|
int code_length = code.get_length();
|
||||||
|
for (int j = 0; j < code_length; ++j) {
|
||||||
|
if (code.calc_bit(j)) {
|
||||||
|
tmat->data<T>()[i * o_width + j] -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template class MatrixBitCodeFunctor<float>;
|
||||||
|
template class MatrixBitCodeFunctor<double>;
|
||||||
|
|
||||||
|
} // namespace math
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,143 @@
|
|||||||
|
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "paddle/fluid/framework/eigen.h"
|
||||||
|
#include "paddle/fluid/framework/tensor.h"
|
||||||
|
#include "paddle/fluid/platform/device_context.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace math {
|
||||||
|
/**
|
||||||
|
* SimpleCodeTable class should support 3 functions:
|
||||||
|
*
|
||||||
|
* size_t size()
|
||||||
|
* return the number of ids
|
||||||
|
*
|
||||||
|
* int get_max_code_length()
|
||||||
|
* return the maximal code length
|
||||||
|
*
|
||||||
|
* SimpleCode operator()(size_t i)
|
||||||
|
* return the i-th code. Code class is described below.
|
||||||
|
*
|
||||||
|
* SimpleCode class should support 3 functions:
|
||||||
|
*
|
||||||
|
* int get_length()
|
||||||
|
* return the length of the code
|
||||||
|
*
|
||||||
|
* size_t calc_index(int bit)
|
||||||
|
* bit ranges from 0 to get_length() - 1
|
||||||
|
* return the index for the (1+bit) level parent
|
||||||
|
*
|
||||||
|
* bool calc_bit(int bit)
|
||||||
|
* return true if the bit level parent is the right child of (1+bit) level
|
||||||
|
* parent
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
 * Return the 1-based index of the highest bit set in x, or 0 when x == 0.
 *
 * For x > 0:
 * \f[
 *   FindLastSet(x) = 1 + \left\lfloor \log_{2} x \right\rfloor
 * \f]
 *
 * The count-leading-zeros builtin is picked according to the underlying type
 * of size_t. The body is kept as a single return statement.
 */
inline constexpr size_t FindLastSet(size_t x) {
  return x == 0
             ? 0
             : (std::is_same<size_t, unsigned int>::value
                    ? 8 * sizeof(x) - __builtin_clz(x)
                    : (std::is_same<size_t, unsigned long>::value  // NOLINT
                           ? 8 * sizeof(x) - __builtin_clzl(x)
                           : 8 * sizeof(x) - __builtin_clzll(x)));
}
|
||||||
|
|
||||||
|
struct SimpleCode {
|
||||||
|
SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
|
||||||
|
/**
|
||||||
|
* Here the id of root shoud be 1 rather than 0, thus the encoding of class c
|
||||||
|
* is `c + num_classes` and all siblings can get the same weight indice using
|
||||||
|
* prefixes.
|
||||||
|
* Weight index is the prefixes of encoding, thus leave out the right most
|
||||||
|
* bit in calc_index.
|
||||||
|
* Binary classification path is the suffixes of encoding, thus leave out the
|
||||||
|
* left most bit in calc_bit.
|
||||||
|
*/
|
||||||
|
inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
|
||||||
|
inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
|
||||||
|
inline int get_length() const { return FindLastSet(c_) - 1; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t c_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SimpleCodeTable {
|
||||||
|
explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
|
||||||
|
SimpleCode operator()(size_t code) const {
|
||||||
|
return SimpleCode(code, num_classes_);
|
||||||
|
}
|
||||||
|
size_t size() const { return num_classes_; }
|
||||||
|
int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t num_classes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class MatrixBitCodeFunctor {
|
||||||
|
public:
|
||||||
|
explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
|
||||||
|
: num_classes_(num_classes), ids_(ids) {}
|
||||||
|
/* For j < code_length
|
||||||
|
tmat(i, j) += vec(0, index(i, j))
|
||||||
|
*/
|
||||||
|
void Add(framework::Tensor* tmat, const framework::Tensor& vec);
|
||||||
|
|
||||||
|
/* For j < code_length
|
||||||
|
vec(0, index(i, j)) += tmat(i, j)
|
||||||
|
*/
|
||||||
|
void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
|
||||||
|
|
||||||
|
/* For j < code_length
|
||||||
|
sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
|
||||||
|
*/
|
||||||
|
void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);
|
||||||
|
|
||||||
|
/* For j < code_length
|
||||||
|
tmat(i, j) -= bit(i, j)
|
||||||
|
*/
|
||||||
|
void Sub(framework::Tensor* tmat);
|
||||||
|
/* For j < code_length
|
||||||
|
input.row(i) += tmat(i, j) * weight.row(index(i, j))
|
||||||
|
*/
|
||||||
|
void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
|
||||||
|
const framework::Tensor& input);
|
||||||
|
|
||||||
|
/* For index(i, j) >= 0:
|
||||||
|
weight.row(index(i, j)) += tmat(i, j) * input.row(i)
|
||||||
|
*/
|
||||||
|
void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
|
||||||
|
const framework::Tensor& input);
|
||||||
|
/* For j < code_length
|
||||||
|
input.row(i) += tmat(i, j) * weight.row(index(i, j))
|
||||||
|
*/
|
||||||
|
void MulGradError(const framework::Tensor& tmat,
|
||||||
|
const framework::Tensor& weight, framework::Tensor* input);
|
||||||
|
|
||||||
|
size_t num_classes_;
|
||||||
|
const int64_t* ids_;
|
||||||
|
};
|
||||||
|
} // namespace math
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -1,75 +0,0 @@
|
|||||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import paddle.fluid as fluid
|
|
||||||
import unittest
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
|
|
||||||
class TestCheckpoint(unittest.TestCase):
    """Round trip through the fluid.io checkpoint helpers.

    Saves checkpoints, looks up the latest serial, reloads the trainer
    arguments and the program, then removes the checkpoint directory.
    """

    def setUp(self):
        # mktemp() only returns a path name; it does not create the directory.
        self.dirname = tempfile.mktemp()
        self.max_num_checkpoints = 3
        self.epoch_interval = 1
        self.step_interval = 1
        self.trainer_id = 0
        self.chief = self.trainer_id == 0
        self.place = fluid.CPUPlace()
        self.epoch_id = 100
        self.step_id = 20

    def test_checkpoint(self):
        self.save_checkpoint()
        serial = fluid.io.get_latest_checkpoint_serial(self.dirname)
        self.assertTrue(serial >= 0)

        # The trainer arguments written by save_checkpoint must come back
        # unchanged (they are compared after conversion to int).
        trainer_args = ["epoch_id", "step_id"]
        epoch_id, step_id = fluid.io.load_trainer_args(
            self.dirname, serial, self.trainer_id, trainer_args)
        self.assertEqual(self.step_id, int(step_id))
        self.assertEqual(self.epoch_id, int(epoch_id))

        program = fluid.Program()
        with fluid.program_guard(program):
            exe = fluid.Executor(self.place)
            fluid.io.load_checkpoint(exe, self.dirname, serial, program)

        fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
        self.assertFalse(os.path.isdir(self.dirname))

    def save_checkpoint(self):
        """Write ten checkpoints of a one-variable program into self.dirname."""
        config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
                                        self.epoch_interval, self.step_interval)

        trainer_args = {}
        trainer_args["epoch_id"] = self.epoch_id
        trainer_args["step_id"] = self.step_id

        program = fluid.Program()
        with fluid.program_guard(program):
            # NOTE(review): "psersistable" looks like a typo of "persistable";
            # it is kept as is because correcting it may change which
            # variables get saved. Confirm before changing.
            program.global_block().create_var(
                name="scale_0",
                psersistable=True,
                dtype="float32",
                shape=[32, 32])

            # NOTE(review): the original indentation of the lines below was
            # lost; they are assumed to sit inside the program_guard block.
            exe = fluid.Executor(self.place)
            # range() instead of the Python-2-only xrange().
            for _ in range(10):
                fluid.io.save_checkpoint(exe, config.checkpoint_dir,
                                         self.trainer_id, trainer_args, program,
                                         config.max_num_checkpoints)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
@ -0,0 +1,99 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
|
||||||
|
def find_latest_set(num):
    """Return the 1-based index of the highest set bit of a positive integer.

    For num > 0 this equals 1 + floor(log2(num)). It is computed with integer
    arithmetic so the result does not depend on floating-point rounding of the
    logarithm. `num` is converted with int() first, so integer-like values
    that int() accepts can be passed. Returns 0 for num == 0.
    """
    return int(num).bit_length()
|
||||||
|
|
||||||
|
|
||||||
|
class CodeTable(object):
    """Reference encoding of one class label as a path in the binary tree.

    The encoding is stored in `c` as num_classes + code.
    """

    def __init__(self, num_classes, code):
        self.c = num_classes + code

    def cal_index(self, bit):
        """Node index at path position `bit`: the encoding without its
        lowest (bit + 1) bits, minus one."""
        prefix = self.c >> (bit + 1)
        return prefix - 1

    def get_length(self):
        """Number of path positions: bit length of the encoding minus one."""
        total_bits = find_latest_set(self.c)
        return total_bits - 1

    def cal_bit(self, bit):
        """Masked value of bit `bit` of the encoding (non-zero when set)."""
        mask = 1 << bit
        return self.c & mask
|
||||||
|
|
||||||
|
|
||||||
|
def hsigmoid(x, w, label, bias, num_classes):
    """NumPy reference implementation of the hierarchical sigmoid forward pass.

    For each sample, the entries of pre_output along the label's code path are
    set to bias + w . x, clipped to [-40, 40] and passed through softrelu.
    `out` is the row sum of the softrelu values minus the sum of the clipped
    values whose code bit is set.

    Returns the tuple (pre_output, out).
    """
    batch_size = x.shape[0]
    code_length = find_latest_set(num_classes - 1)
    pre_output = np.zeros((batch_size, code_length))
    out = np.zeros((batch_size, 1)).astype("float32")

    # One code per sample; CodeTable only stores the encoding, so the objects
    # can be shared between the passes below.
    codes = [CodeTable(num_classes, label[row]) for row in range(batch_size)]

    # Pre-activation on the path: first the bias term, then the dot product.
    for row, code in enumerate(codes):
        for pos in range(code.get_length()):
            node = code.cal_index(pos)
            pre_output[row][pos] += bias[0][node]
            pre_output[row][pos] += np.dot(w[node], x[row])

    # clip[-40.0, 40.0]
    pre_output = np.clip(pre_output, -40.0, 40.0)

    # out(i, 0) = -\sum_j bit(i, j) * preout(i, j)
    for row, code in enumerate(codes):
        acc = 0.0
        for pos in range(code.get_length()):
            if code.cal_bit(pos):
                acc += pre_output[row][pos]
        out[row] = -1.0 * acc

    # soft relu
    pre_output = np.log(1 + np.exp(pre_output))
    out += pre_output.sum(1).reshape((batch_size, 1))
    return pre_output, out
|
||||||
|
|
||||||
|
|
||||||
|
class TestHSigmoidOp(OpTest):
    """Checks the hierarchical_sigmoid operator against the NumPy reference
    implementation `hsigmoid`, for both outputs and gradients."""

    def setUp(self):
        self.op_type = "hierarchical_sigmoid"
        num_classes = 6
        feature_size = 8
        batch_size = 4
        x = np.random.random((batch_size, feature_size)).astype("float32")
        w = np.random.random((num_classes - 1, feature_size)).astype("float32")
        label = np.random.randint(0, num_classes, (batch_size, 1))
        bias = np.random.random((1, num_classes - 1)).astype("float32")
        self.attrs = {'num_classes': num_classes}
        self.inputs = {'X': x, 'W': w, 'Label': label, 'Bias': bias}
        pre_output, out = hsigmoid(x, w, label, bias, num_classes)
        self.outputs = {'PreOut': pre_output, 'Out': out}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # set(['Label']) rather than set('Label'): the latter builds a set of
        # the single characters 'L', 'a', 'b', 'e', 'l'.
        self.check_grad(
            ['Bias', 'X', 'W'], ['Out'], no_grad_set=set(['Label']))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue