Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_is_taged
commit f8e83196ff
@@ -0,0 +1,167 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>

namespace paddle {
namespace operators {

/**
 * Organize the classes into a binary tree. At each node, a sigmoid function
 * is used to calculate the probability of belonging to the right branch.
 * This idea is from "F. Morin, Y. Bengio (AISTATS 05):
 * Hierarchical Probabilistic Neural Network Language Model."
 *
 * Here we use a simple way of making the binary tree.
 * Assuming the number of classes C = 6,
 * the classes are organized as a binary tree in the following way:
 *
 * @code{.py}
 * *-*-*- 2
 * | | |- 3
 * | |
 * | |-*- 4
 * |   |- 5
 * |
 * |-*- 0
 *   |- 1
 * @endcode
 *
 * where * indicates an internal node, and each leaf node represents a class.
 * - Node 0 ... C-2 are internal nodes.
 * - Node C-1 ... 2C-2 are leaf nodes.
 * - Class c is represented by leaf node \f$c+C-1\f$.
 *
 * We assign an id to each node:
 * - the id of the root is 0.
 * - the left child of node i is 2*i+1.
 * - the right child of node i is 2*i+2.
 *
 * It is easy to see that:
 * - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
 * - the j-th level ancestor of node i is
 *   \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
 * - a node i is a left child of its parent if \f$(i-1)\%2==0\f$.
 */
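As a quick check of the node-id arithmetic in the comment above, here is a minimal standalone sketch. It compiles on its own; the helper names are illustrative, not part of the operator:

#include <cassert>

// Node ids follow the comment above: root = 0, children of node i are
// 2*i+1 and 2*i+2, so the parent of node i (> 0) is (i - 1) / 2 with
// integer division.
inline int Parent(int i) { return (i - 1) / 2; }
inline int LeftChild(int i) { return 2 * i + 1; }
inline int RightChild(int i) { return 2 * i + 2; }
// j-th level ancestor: floor((i + 1) / 2^(j + 1)) - 1, with 2^(j+1) = 2 << j.
inline int Ancestor(int i, int j) { return (i + 1) / (2 << j) - 1; }

int main() {
  assert(Parent(RightChild(1)) == 1);
  assert(Ancestor(4, 0) == Parent(4));  // the 0-th level ancestor is the parent
  assert(Ancestor(4, 1) == 0);          // the grandparent of node 4 is the root
  assert((LeftChild(1) - 1) % 2 == 0);  // left children satisfy (i-1)%2 == 0
  return 0;
}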
class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
                   "Output(PreOut) should not be null.");
    const int64_t batch_size = ctx->GetInputDim("X")[0];
    std::vector<int64_t> output_shape({batch_size, 1});
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.GetPlace());
  }
};

template <typename AttrType>
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, required) The input tensor with shape [N, D], "
             "where N is the size of the mini-batch and D is the feature "
             "size.");
    AddInput("W",
             "(Tensor, required) The parameters of the hierarchical "
             "sigmoid operator, a 2-D tensor with shape "
             "[num_classes - 1, D].");
    AddInput("Label",
             "(Tensor, required) The labels of the training data, a "
             "tensor with shape [N, 1].");
    AddInput("Bias",
             "(Tensor, optional) The bias, a tensor with shape "
             "[1, num_classes - 1].");
    AddOutput("Out",
              "(Tensor, required) The output of the hierarchical sigmoid "
              "operator, with shape [N, 1].");
    AddOutput("PreOut",
              "(Tensor, required) An intermediate 2-D tensor with shape "
              "[batch_size, code_length], where code_length is the "
              "maximum path length from the root to a leaf node.")
        .AsIntermediate();
    AddAttr<AttrType>("num_classes", "(int, required) The number of classes.")
        .SetDefault(2);
    AddComment(R"DOC(
The hierarchical sigmoid operator organizes the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
  }
};

class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                   "Input(PreOut) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
                   "Output(W@GRAD) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should not be null.");
    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
      ctx->SetOutputDim(framework::GradVarName("Bias"),
                        ctx->GetInputDim("Bias"));
    }
    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                  ops::HierarchicalSigmoidOpMaker<int>,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(
    hierarchical_sigmoid,
    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
                                     double>);
REGISTER_OP_CPU_KERNEL(
    hierarchical_sigmoid_grad,
    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
                                         float>,
    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
                                         double>);
@@ -0,0 +1,135 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;

template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* bias = ctx.Input<framework::Tensor>("Bias");
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    int64_t code_length = math::FindLastSet(num_classes - 1);
    int64_t batch_size = in->dims()[0];
    framework::Tensor sum;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto* pre_out_data = pre_out->mutable_data<T>(
        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    // Not every class (leaf) node has a path of length code_length, so
    // initializing PreOut with zeros keeps the out-of-path positions from
    // contributing to the loss.
    math::SetConstant<DeviceContext, T> zero;
    zero(dev_ctx, pre_out, static_cast<T>(0.0));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::RowwiseSum<DeviceContext, T> row_sum;
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

    std::vector<int64_t> sum_dims({batch_size, 1UL});
    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
    auto sum_mat = EigenMatrix<T>::From(sum);
    out->mutable_data<T>(ctx.GetPlace());
    auto out_mat = framework::EigenVector<T>::Flatten(*out);
    if (bias) {
      bit_code.Add(pre_out, *bias);
    }
    bit_code.Mul(pre_out, *w, *in);
    // Clip PreOut to [-40, 40] to avoid overflow in exp() below.
    Transform<DeviceContext> trans;
    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
          pre_out_data + pre_out->numel(), pre_out_data,
          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
    bit_code.Sum(*pre_out, out, static_cast<T>(-1));
    // Use softrelu, log(1 + exp(x)), to compute the cross entropy.
    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
    row_sum(dev_ctx, *pre_out, &sum);
    // TODO(guosheng): Subtract the out-of-path loss, since not all
    // class (leaf) nodes' path lengths equal code_length. This does not break
    // the gradient check, because forward and backward both include the
    // out-of-path loss and it cancels out.
    out_mat.device(place) = sum_mat + out_mat;
  }
};
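To make the assembled loss concrete, here is a framework-free sketch of what the forward pass above computes per sample. The helper name and plain-double signature are ours, not kernel code:

#include <cmath>
#include <vector>

// preout[j] = clipped w.row(index(i, j)) . x (+ bias) for node j on the
// sample's path; bit[j] = 1 when the path takes the right child at step j.
// The kernel computes sum_j softrelu(preout[j]) - sum_{bit[j]=1} preout[j],
// which equals the sum of the sigmoid cross entropies along the path.
double HsigmoidLoss(const std::vector<double>& preout,
                    const std::vector<int>& bit) {
  double loss = 0.0;
  for (size_t j = 0; j < preout.size(); ++j) {
    loss += std::log(1.0 + std::exp(preout[j]));  // row_sum of softrelu(PreOut)
    if (bit[j]) loss -= preout[j];                // bit_code.Sum(..., -1)
  }
  return loss;
}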
template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
    auto* bias_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    framework::Tensor pre_out_grad;

    pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
    in_grad->mutable_data<T>(ctx.GetPlace());
    w_grad->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    zero(dev_ctx, in_grad, static_cast<T>(0.0));
    zero(dev_ctx, w_grad, static_cast<T>(0.0));

    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
    auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});

    // Derivative of softrelu: PreOut stores log(1 + exp(p)), so
    // 1 - 1 / exp(PreOut) = 1 - 1 / (1 + exp(p)) = sigmoid(p).
    pre_out_grad_mat.device(place) =
        static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
    bit_code.Sub(&pre_out_grad);  // sigmoid(p) - bit: the gradient w.r.t.
                                  // clip(w * x + b)
    pre_out_grad_mat.device(place) =
        pre_out_grad_mat * out_grad_mat.broadcast(bcast);
    // TODO(guosheng): multiply pre_out_grad by the subgradient of the clipping
    // to be consistent with the clipping in the forward pass.
    if (bias_grad) {
      bias_grad->mutable_data<T>(ctx.GetPlace());
      zero(dev_ctx, bias_grad, static_cast<T>(0.0));
      bit_code.AddGrad(pre_out_grad, bias_grad);
    }
    bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
    bit_code.MulGradError(pre_out_grad, *w, in_grad);
  }
};
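As a cross-check against the forward sketch above, a minimal framework-free sketch of the per-node gradient the backward pass writes into pre_out_grad (assumed helper, not kernel code):

#include <cmath>

// softrelu_p = log(1 + exp(p)) as saved in PreOut; bit is the path bit;
// dout is the incoming gradient of the per-sample loss.
// d loss / d p = sigmoid(p) - bit, matching the two Eigen expressions and
// bit_code.Sub(...) in the kernel above.
double HsigmoidPreOutGrad(double softrelu_p, int bit, double dout) {
  double sigmoid_p = 1.0 - 1.0 / std::exp(softrelu_p);
  return (sigmoid_p - bit) * dout;
}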
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,176 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
namespace paddle {
namespace operators {
namespace math {

template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
                                  const framework::Tensor& vec) {
  SimpleCodeTable code_table(num_classes_);
  size_t batch_size = tmat->dims()[0];
  size_t width = tmat->dims()[1];
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
      tmat->data<T>()[i * width + j] += vec.data<T>()[index];
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
                                      framework::Tensor* vec) {
  SimpleCodeTable code_table(num_classes_);
  size_t batch_size = tmat.dims()[0];
  size_t width = tmat.dims()[1];
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
      vec->data<T>()[index] += tmat.data<T>()[i * width + j];
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
                                  framework::Tensor* sum, T scale_sum) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat.dims()[0];
  size_t o_width = tmat.dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
    T sm = static_cast<T>(0.0);
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
        // calc_bit starts from the right-most bit of the code, which matches
        // the order in which the columns of tmat[i] are filled.
        sm += tmat.data<T>()[i * o_width + j];
      }
    }
    sum->data<T>()[i] = scale_sum * sm;
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
                                  const framework::Tensor& weight,
                                  const framework::Tensor& input) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat->dims()[0];
  size_t tmat_width = tmat->dims()[1];
  size_t input_width = input.dims()[1];
  size_t weight_width = weight.dims()[1];
  auto tmat_value = tmat->data<T>();
  auto weight_value = weight.data<T>();
  auto input_value = input.data<T>();
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
      T sum = static_cast<T>(0.0);
      for (size_t k = 0; k < input_width; ++k) {
        sum += weight_value[weight_width * index + k] *
               input_value[input_width * i + k];
      }
      tmat_value[i * tmat_width + j] += sum;
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                            framework::Tensor* weight,
                                            const framework::Tensor& input) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat.dims()[0];
  size_t input_width = input.dims()[1];
  size_t tmat_width = tmat.dims()[1];
  size_t weight_width = weight->dims()[1];
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight->data<T>();
  auto input_value = input.data<T>();
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);

      for (size_t k = 0; k < input_width; ++k) {
        weight_value[weight_width * index + k] +=
            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
      }
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
                                           const framework::Tensor& weight,
                                           framework::Tensor* input) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat.dims()[0];
  size_t tmat_width = tmat.dims()[1];
  size_t input_width = input->dims()[1];
  size_t weight_width = weight.dims()[1];
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight.data<T>();
  auto input_value = input->data<T>();

  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);

      for (size_t k = 0; k < input_width; ++k) {
        input_value[input_width * i + k] +=
            tmat_value[i * tmat_width + j] *
            weight_value[weight_width * index + k];
      }
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat->dims()[0];
  size_t o_width = tmat->dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
        tmat->data<T>()[i * o_width + j] -= 1;
      }
    }
  }
}

template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,143 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <type_traits>  // for std::is_same used in FindLastSet below
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {
namespace math {
/**
 * The SimpleCodeTable class should support 3 functions:
 *
 * size_t size()
 *   return the number of ids
 *
 * int get_max_code_length()
 *   return the maximal code length
 *
 * SimpleCode operator()(size_t i)
 *   return the i-th code. The Code class is described below.
 *
 * The SimpleCode class should support 3 functions:
 *
 * int get_length()
 *   return the length of the code
 *
 * size_t calc_index(int bit)
 *   bit ranges from 0 to get_length() - 1
 *   return the index of the (1+bit)-th level parent
 *
 * bool calc_bit(int bit)
 *   return true if the bit-th level parent is the right child of the
 *   (1+bit)-th level parent
 *
 */

/**
 * Return the 1-based index of the highest bit set.
 *
 * For x > 0:
 * \f[
 *   FindLastSet(x) = 1 + \left\lfloor \log_{2} x \right\rfloor
 * \f]
 */
inline constexpr size_t FindLastSet(size_t x) {
  return std::is_same<size_t, unsigned int>::value
             ? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
             : (std::is_same<size_t, unsigned long>::value  // NOLINT
                    ? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
                    : (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
struct SimpleCode {
  SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
  /**
   * Here the id of the root should be 1 rather than 0, so the encoding of
   * class c is `c + num_classes`, and all siblings can get the same weight
   * indices using prefixes.
   * The weight index is a prefix of the encoding, so calc_index leaves out
   * the right-most bit.
   * The binary classification path is a suffix of the encoding, so calc_bit
   * leaves out the left-most bit.
   */
  inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
  inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
  inline int get_length() const { return FindLastSet(c_) - 1; }

 private:
  size_t c_;
};
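A worked example of the encoding, under the tree layout described in hierarchical_sigmoid_op.cc. This is a standalone check assuming this header is included, not repository code:

#include <cassert>

int main() {
  // With num_classes = 6, class 2 is encoded as c_ = 2 + 6 = 8 = 0b1000.
  SimpleCode code(2, 6);
  assert(code.get_length() == 3);   // FindLastSet(8) - 1: three binary steps
  assert(code.calc_index(2) == 0);  // (8 >> 3) - 1: internal node by the root
  assert(code.calc_index(1) == 1);  // (8 >> 2) - 1
  assert(code.calc_index(0) == 3);  // (8 >> 1) - 1: internal node by the leaf
  assert(!code.calc_bit(0));        // bits 0..2 of 0b1000 are all 0:
  assert(!code.calc_bit(1));        // the path always takes the left branch
  assert(!code.calc_bit(2));
  return 0;
}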
struct SimpleCodeTable {
  explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
  SimpleCode operator()(size_t code) const {
    return SimpleCode(code, num_classes_);
  }
  size_t size() const { return num_classes_; }
  int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }

 private:
  size_t num_classes_;
};

template <typename T>
class MatrixBitCodeFunctor {
 public:
  explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
      : num_classes_(num_classes), ids_(ids) {}
  /* For j < code_length
       tmat(i, j) += vec(0, index(i, j))
  */
  void Add(framework::Tensor* tmat, const framework::Tensor& vec);

  /* For j < code_length
       vec(0, index(i, j)) += tmat(i, j)
  */
  void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);

  /* For j < code_length
       sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
  */
  void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);

  /* For j < code_length
       tmat(i, j) -= bit(i, j)
  */
  void Sub(framework::Tensor* tmat);
  /* For j < code_length
       tmat(i, j) += dot(weight.row(index(i, j)), input.row(i))
  */
  void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
           const framework::Tensor& input);

  /* For j < code_length
       weight.row(index(i, j)) += tmat(i, j) * input.row(i)
  */
  void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
                     const framework::Tensor& input);
  /* For j < code_length
       input.row(i) += tmat(i, j) * weight.row(index(i, j))
  */
  void MulGradError(const framework::Tensor& tmat,
                    const framework::Tensor& weight, framework::Tensor* input);

  size_t num_classes_;
  const int64_t* ids_;
};
}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,202 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class SqueezeOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SqueezeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SqueezeOp should not be null.");

    const auto &x_dims = ctx->GetInputDim("X");
    // Check the rank of the input tensor (Eigen limits tensors to rank 6).
    PADDLE_ENFORCE(x_dims.size() <= 6,
                   "Invalid dimensions, the rank of Input(X) "
                   "should be in the range of [1, 6] (Eigen limit).");

    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
    for (int a : axes) {
      PADDLE_ENFORCE_LT(a, x_dims.size(),
                        "The squeeze axis should be less than the input "
                        "tensor's rank.");
    }

    auto out_dims = GetOutputShape(axes, x_dims);
    ctx->SetOutputDim("Out", out_dims);
    if (x_dims[0] == out_dims[0]) {
      // Only pass LoD when the first dimensions of output and Input(X)
      // are the same.
      ctx->ShareLoD("X", "Out");
    }
  }

  static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
                                        const framework::DDim &in_dims) {
    size_t num_squeeze_dims = squeeze_dims.size();
    int cnt_squeezed_dims = 0;
    bool should_squeeze[9] = {false};

    // Determine the rank of the output tensor after the squeeze:
    // mark and count the dimensions that need to be squeezed.
    if (num_squeeze_dims == 0) {
      for (int idx = 0; idx < in_dims.size(); ++idx) {
        if (in_dims[idx] == 1) {
          should_squeeze[idx] = true;
          ++cnt_squeezed_dims;
        }
      }
    } else {
      for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
        int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
                                            : squeeze_dims[idx];
        // Check the current index; the upper bound has already been covered
        // by the rank check above.
        PADDLE_ENFORCE(current >= 0,
                       "Invalid axis, the negative axis is out of range.");
        PADDLE_ENFORCE(in_dims[current] == 1,
                       "Invalid axis index, the dimension that will be "
                       "squeezed should be equal to 1.");

        if (!(should_squeeze[current])) {
          ++cnt_squeezed_dims;
        }
        should_squeeze[current] = true;
      }
    }

    // Make the output dimensions.
    std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
    for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
      if (!should_squeeze[in_idx]) {
        output_shape[out_idx++] = in_dims[in_idx];
      }
    }

    return framework::make_ddim(output_shape);
  }
};
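To trace GetOutputShape without the framework, here is a minimal mirror of the same marking logic (a hypothetical helper using std::vector in place of DDim, not part of the operator):

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors SqueezeOpInferShape::GetOutputShape for in-range axes.
std::vector<int64_t> SqueezeShape(const std::vector<int>& axes,
                                  const std::vector<int64_t>& in) {
  std::vector<bool> drop(in.size(), false);
  if (axes.empty()) {  // no axes given: drop every dimension equal to 1
    for (size_t i = 0; i < in.size(); ++i) drop[i] = (in[i] == 1);
  } else {  // drop only the requested axes, which must hold a 1
    for (int a : axes) {
      int cur = a < 0 ? a + static_cast<int>(in.size()) : a;
      assert(cur >= 0 && in[cur] == 1);
      drop[cur] = true;
    }
  }
  std::vector<int64_t> out;
  for (size_t i = 0; i < in.size(); ++i) {
    if (!drop[i]) out.push_back(in[i]);
  }
  return out;
}
// SqueezeShape({0}, {1, 3, 1, 5}) -> {3, 1, 5}
// SqueezeShape({},  {1, 3, 1, 5}) -> {3, 5}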
class SqueezeOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto &axes = Attr<std::vector<int>>("axes");
    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
    auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);

    framework::AttributeMap attrs;
    attrs["shape"] = framework::vectorize2int(out_dims);
    attrs["inplace"] = Attr<bool>("inplace");
    // Invoke the Reshape op.
    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {Input("X")}}, {"Shape", {}}},
        {{"Out", {Output("Out")}}}, attrs);
    reshape_op->Run(scope, place);
  }
};

class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor). The input tensor of the squeeze operator.");
    AddOutput("Out", "(Tensor). The output tensor of the squeeze operator.");
    AddAttr<std::vector<int>>("axes",
                              "(std::vector<int>). List of integers "
                              "indicating the dimensions to squeeze.")
        .SetDefault({});
    AddAttr<bool>("inplace",
                  "(default: false) Squeeze the source tensor's shape without "
                  "memory copy. When Attr(inplace) is set to true, the output "
                  "tensor shares memory with Input(X); otherwise, a new output "
                  "tensor is created, and its data are copied from Input(X).")
        .SetDefault(false);
    AddComment(R"DOC(
Squeeze Operator.

Remove single-dimensional entries from the shape of a tensor.
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with a shape entry not equal to one, an error is raised.

Examples:
Case 1:
  Given
    X.shape = (1, 3, 1, 5)
  and
    axes = [0]
  we get:
    Out.shape = (3, 1, 5)

Case 2:
  Given
    X.shape = (1, 3, 1, 5)
  and
    axes = []
  we get:
    Out.shape = (3, 5)
)DOC");
  }
};

class SqueezeGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    context->SetOutputDim(framework::GradVarName("X"),
                          context->GetInputDim("X"));
    context->ShareLoD("X", framework::GradVarName("X"));
  }
};

class SqueezeGradOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto dx_name = Output(framework::GradVarName("X"));
    auto dout_name = Input(framework::GradVarName("Out"));
    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
    framework::AttributeMap attrs;
    attrs["shape"] = framework::vectorize2int(x_dims);
    attrs["inplace"] = Attr<bool>("inplace");

    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
        attrs);
    reshape_op->Run(scope, place);
  }
};

}  // namespace operators
}  // namespace paddle

// Tell the linker to pull in the reshape op.
USE_OP(reshape);

namespace ops = paddle::operators;
REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
                  ops::SqueezeOpInferShape,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
@@ -0,0 +1,191 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class UnsqueezeOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of UnsqueezeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of UnsqueezeOp should not be null.");

    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
    const auto &x_dims = ctx->GetInputDim("X");
    // Validity check: the rank of the input tensor (Eigen limit is 6).
    PADDLE_ENFORCE(x_dims.size() <= 6,
                   "Invalid dimensions, the rank of Input(X) "
                   "should be in the range of [1, 6] (Eigen limit).");
    auto out_dims = GetOutputShape(axes, x_dims);
    ctx->SetOutputDim("Out", out_dims);
    if (x_dims[0] == out_dims[0]) {
      // Only pass LoD when the first dimensions of output and Input(X)
      // are the same.
      ctx->ShareLoD("X", "Out");
    }
  }

  static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
                                        const framework::DDim &in_dims) {
    int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
    int cur_output_size = in_dims.size();
    std::vector<int64_t> output_shape(output_size, 0);

    // Validity check: the output rank range.
    PADDLE_ENFORCE(output_size <= 6,
                   "The output tensor's rank should be less than 6.");

    for (int axis : unsqz_dims) {
      int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
      // Validity check: the axis bound.
      PADDLE_ENFORCE(
          cur >= 0 && cur <= cur_output_size,
          "The unsqueeze dims must be within the range of the current rank.");
      // Move the old axes and insert the new axis.
      for (int i = cur_output_size; i >= cur; --i) {
        if (output_shape[i] == 1) {
          // Move the axis.
          output_shape[i + 1] = 1;
          output_shape[i] = 0;
        }
      }
      output_shape[cur] = 1;
      // Grow the current output size.
      cur_output_size++;
    }

    // Make the output shape.
    for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
      if (output_shape[out_idx] == 0) {
        output_shape[out_idx] = in_dims[in_idx++];
      }
    }

    return framework::make_ddim(output_shape);
  }
};
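The mark-and-shift loop above is equivalent, for the in-range axes it accepts, to inserting a 1 at each requested position in order. A framework-free mirror (hypothetical helper, not part of the operator) makes that easier to see:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors UnsqueezeOpInferShape::GetOutputShape by direct insertion.
std::vector<int64_t> UnsqueezeShape(const std::vector<int>& axes,
                                    const std::vector<int64_t>& in) {
  std::vector<int64_t> out(in.begin(), in.end());
  for (int a : axes) {
    int cur = a < 0 ? a + static_cast<int>(out.size()) + 1 : a;
    assert(cur >= 0 && cur <= static_cast<int>(out.size()));
    out.insert(out.begin() + cur, 1);  // axes are positions in the output
  }
  return out;
}
// UnsqueezeShape({0, 4}, {3, 4, 5}) -> {1, 3, 4, 5, 1}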
class UnsqueezeOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto &axes = Attr<std::vector<int>>("axes");
    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
    auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);

    framework::AttributeMap attrs;
    attrs["shape"] = framework::vectorize2int(out_dims);
    attrs["inplace"] = Attr<bool>("inplace");
    // Invoke the Reshape op.
    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {Input("X")}}, {"Shape", {}}},
        {{"Out", {Output("Out")}}}, attrs);
    reshape_op->Run(scope, place);
  }
};

class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor). The input tensor of the unsqueeze operator.");
    AddOutput("Out", "(Tensor). The output tensor of the unsqueeze operator.");
    AddAttr<std::vector<int>>("axes",
                              "(std::vector<int>). List of integers "
                              "indicating the dimensions to be inserted.")
        .AddCustomChecker([](const std::vector<int> &axes) {
          PADDLE_ENFORCE(!axes.empty(),
                         "Invalid axes, the unsqueeze axes must not be empty.");
          // Validity check: the number of axes (Eigen limit is 6).
          PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
                         "Invalid dimensions, dynamic dimensions should be "
                         "within [1, 6] dimensions (Eigen limit).");
          // Validity check: the range of the unsqueeze axes.
          for (int axis : axes) {
            PADDLE_ENFORCE(axis < 6,
                           "Invalid dimensions, the input axis should be "
                           "within [1, 6] dimensions (Eigen limit).");
          }
        });
    AddAttr<bool>(
        "inplace",
        "(default: false) Unsqueeze the source tensor's shape without "
        "memory copy. When Attr(inplace) is set to true, the output "
        "tensor shares memory with Input(X); otherwise, a new output "
        "tensor is created, and its data are copied from Input(X).")
        .SetDefault(false);
    AddComment(R"DOC(
Unsqueeze Operator.

Insert single-dimensional entries into the shape of a tensor.
Takes one required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.

For example:
  Given a tensor with shape [3, 4, 5],
  Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1].
)DOC");
  }
};

class UnsqueezeGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    ctx->ShareLoD("X", framework::GradVarName("X"));
  }
};

class UnsqueezeGradOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto dx_name = Output(framework::GradVarName("X"));
    auto dout_name = Input(framework::GradVarName("Out"));
    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();

    framework::AttributeMap attrs;
    attrs["shape"] = framework::vectorize2int(x_dims);
    attrs["inplace"] = Attr<bool>("inplace");

    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
        attrs);
    reshape_op->Run(scope, place);
  }
};

}  // namespace operators
}  // namespace paddle

// Tell the linker to pull in the reshape op.
USE_OP(reshape);

namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
                  ops::UnsqueezeOpInferShape,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
                  ops::UnsqueezeGradInferShape);
@@ -0,0 +1,38 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import sys

__all__ = ['deprecated']


def deprecated(since, instead, extra_message=""):
    """Mark an API as deprecated: warn on every call and record the
    deprecation notice in the wrapped function's docstring."""

    def decorator(func):
        err_msg = "API {0} is deprecated since {1}. Please use {2} instead.".format(
            func.__name__, since, instead)
        if len(extra_message) != 0:
            err_msg += "\n"
            err_msg += extra_message

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Write to stderr directly so the warning also works on Python 3
            # (the original `print >> sys.stderr, err_msg` is Python 2 only).
            sys.stderr.write(err_msg + "\n")
            return func(*args, **kwargs)

        # functools.wraps copies __doc__, which may be None; guard against
        # appending to None.
        wrapper.__doc__ = (wrapper.__doc__ or "") + "\n    " + err_msg
        return wrapper

    return decorator
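A brief usage sketch (the API names below are illustrative, not from the repository):

@deprecated(since="0.14.0", instead="fluid.layers.some_new_api")
def some_old_api(x):
    """Return x unchanged."""
    return x

some_old_api(1)  # writes the deprecation message to stderr, then returns 1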