Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/hide_api_cont
commit 4ff1bde5fb
@@ -0,0 +1,167 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/hierarchical_sigmoid_op.h"
#include <vector>

namespace paddle {
namespace operators {

/**
 * Organize the classes into a binary tree. At each node, a sigmoid function
 * is used to calculate the probability of belonging to the right branch.
 * This idea is from "F. Morin, Y. Bengio (AISTATS 05):
 * Hierarchical Probabilistic Neural Network Language Model."
 *
 * Here we use a simple way of making the binary tree.
 * Assuming the number of classes C = 6,
 * the classes are organized as a binary tree in the following way:
 *
 * @code{.py}
 * *-*-*- 2
 * | | |- 3
 * | |
 * | |-*- 4
 * |   |- 5
 * |
 * |-*- 0
 *   |- 1
 * @endcode
 *
 * where * indicates an internal node, and each leaf node represents a class.
 * - Node 0 ... C-2 are internal nodes.
 * - Node C-1 ... 2C-2 are leaf nodes.
 * - Class c is represented by leaf node \f$c+C-1\f$.
 *
 * We assign an id for each node:
 * - the id of the root is 0.
 * - the left child of a node i is 2*i+1.
 * - the right child of a node i is 2*i+2.
 *
 * It's easy to see that:
 * - the parent of node i is \f$\left\lfloor(i-1)/2\right\rfloor\f$.
 * - the j-th level ancestor of node i is
 *   \f$\left\lfloor(i+1)/2^{j+1}\right\rfloor - 1\f$.
 * - a node i is a left child of its parent if \f$(i-1)\%2==0\f$.
 */
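// A minimal illustrative sketch (not used by the operator) that checks the
// node-id arithmetic described above at compile time, assuming the 0-based
// ids from the comment: left(i) = 2*i+1, right(i) = 2*i+2,
// parent(i) = (i-1)/2, and a node is a left child iff (i-1) % 2 == 0.
static_assert((2 * 3 + 1 - 1) / 2 == 3, "parent of left child of 3 is 3");
static_assert((2 * 3 + 2 - 1) / 2 == 3, "parent of right child of 3 is 3");
static_assert(((2 * 3 + 1) - 1) % 2 == 0, "2*i+1 is always a left child");
static_assert(((2 * 3 + 2) - 1) % 2 != 0, "2*i+2 is always a right child");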

class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
                   "Output(PreOut) should not be null.");
    const int64_t batch_size = ctx->GetInputDim("X")[0];
    std::vector<int64_t> output_shape({batch_size, 1});
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.GetPlace());
  }
};

template <typename AttrType>
class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor, required) The input tensor with shape [N, D], "
             "where N is the size of mini-batch, and D is the feature size.");
    AddInput("W",
             "(Tensor, required) The parameters of the hierarchical "
             "sigmoid operator; each of them is a 2-D tensor with shape "
             "[num_classes - 1, D].");
    AddInput("Label",
             "(Tensor, required) The labels of the training data. It's a "
             "tensor with shape [N, 1].");
    AddInput("Bias",
             "(Tensor, optional) The bias is a tensor with shape "
             "[1, num_classes - 1].");
    AddOutput("Out",
              "(Tensor, required) The output of the hierarchical sigmoid "
              "operator. The shape is [N, 1].");
    AddOutput("PreOut",
              "(Tensor, required) An intermediate 2-D tensor with shape "
              "[batch_size, code_length], where code_length represents the "
              "maximum path length from root to leaf nodes.")
        .AsIntermediate();
    AddAttr<AttrType>("num_classes", "(int, required) The number of classes.")
        .SetDefault(2);
    AddComment(R"DOC(
The hierarchical sigmoid operator organizes the classes into a binary tree.
At each node, a sigmoid function is used to calculate the probability of
belonging to the right branch. This idea is from
"F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
)DOC");
  }
};

class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("PreOut"),
                   "Input(PreOut) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
                   "Output(W@Grad) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@Grad) should not be null.");
    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
      ctx->SetOutputDim(framework::GradVarName("Bias"),
                        ctx->GetInputDim("Bias"));
    }
    ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W"));
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
                  ops::HierarchicalSigmoidOpMaker<int>,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
REGISTER_OP_CPU_KERNEL(
    hierarchical_sigmoid,
    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
                                     double>);
REGISTER_OP_CPU_KERNEL(
    hierarchical_sigmoid_grad,
    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
                                         float>,
    ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
                                         double>);
@@ -0,0 +1,135 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform;

template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* bias = ctx.Input<framework::Tensor>("Bias");
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto* pre_out = ctx.Output<framework::Tensor>("PreOut");
    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    int64_t code_length = math::FindLastSet(num_classes - 1);
    int64_t batch_size = in->dims()[0];
    framework::Tensor sum;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto* pre_out_data = pre_out->mutable_data<T>(
        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    // Not all class (leaf) nodes' path lengths equal code_length, so
    // initialize PreOut with zeros to keep the entries beyond each sample's
    // path well defined.
    math::SetConstant<DeviceContext, T> zero;
    zero(dev_ctx, pre_out, static_cast<T>(0.0));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::RowwiseSum<DeviceContext, T> row_sum;
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

    std::vector<int64_t> sum_dims({batch_size, 1UL});
    sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
    auto sum_mat = EigenMatrix<T>::From(sum);
    out->mutable_data<T>(ctx.GetPlace());
    auto out_mat = framework::EigenVector<T>::Flatten(*out);
    if (bias) {
      bit_code.Add(pre_out, *bias);
    }
    bit_code.Mul(pre_out, *w, *in);
    // clip to [-40, 40]
    Transform<DeviceContext> trans;
    trans(ctx.template device_context<DeviceContext>(), pre_out_data,
          pre_out_data + pre_out->numel(), pre_out_data,
          ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
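    // What follows computes, per sample i (a sketch of the math; z_ij is the
    // clipped pre_out(i, j) and b_ij the bit given by calc_bit on the
    // sample's path, ignoring the out-of-path entries noted in the TODO):
    //   Out(i) = sum_j [log(1 + exp(z_ij)) - b_ij * z_ij],
    // i.e. the summed binary cross entropy of sigmoid(z_ij) against b_ij.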
    bit_code.Sum(*pre_out, out, static_cast<T>(-1));
    // use softrelu to calculate cross entropy
    pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
    row_sum(dev_ctx, *pre_out, &sum);
    // TODO(guosheng): Subtract the out-of-path loss, since not all
    // class (leaf) nodes' path lengths equal code_length. It won't break the
    // gradient check, since both sides include the same out-of-path loss and
    // it cancels out.
    out_mat.device(place) = sum_mat + out_mat;
  }
};

template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* w = ctx.Input<framework::Tensor>("W");
    auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* w_grad = ctx.Output<framework::Tensor>(framework::GradVarName("W"));
    auto* bias_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
    auto* label = ctx.Input<framework::Tensor>("Label");
    auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    framework::Tensor pre_out_grad;

    pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
    in_grad->mutable_data<T>(ctx.GetPlace());
    w_grad->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    zero(dev_ctx, in_grad, static_cast<T>(0.0));
    zero(dev_ctx, w_grad, static_cast<T>(0.0));

    size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());

    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
    auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
    auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});

    // softrelu derivative
    pre_out_grad_mat.device(place) =
        static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
    bit_code.Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
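    // PreOut saved from the forward pass holds softrelu(z) = log(1 + exp(z)),
    // so 1 - 1 / exp(PreOut) equals sigmoid(z); after Sub, pre_out_grad holds
    // sigmoid(z_ij) - b_ij, the derivative of the per-node cross entropy with
    // respect to z_ij (an explanatory sketch, not extra computation).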
    pre_out_grad_mat.device(place) =
        pre_out_grad_mat * out_grad_mat.broadcast(bcast);
    // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
    // be consistent with the clipping in forward.
    if (bias_grad) {
      bias_grad->mutable_data<T>(ctx.GetPlace());
      zero(dev_ctx, bias_grad, static_cast<T>(0.0));
      bit_code.AddGrad(pre_out_grad, bias_grad);
    }
    bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
    bit_code.MulGradError(pre_out_grad, *w, in_grad);
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,176 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include <iostream>
namespace paddle {
namespace operators {
namespace math {

template <typename T>
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
                                  const framework::Tensor& vec) {
  SimpleCodeTable code_table(num_classes_);
  size_t batch_size = tmat->dims()[0];
  size_t width = tmat->dims()[1];
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
      tmat->data<T>()[i * width + j] += vec.data<T>()[index];
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
                                      framework::Tensor* vec) {
  SimpleCodeTable code_table(num_classes_);
  size_t batch_size = tmat.dims()[0];
  size_t width = tmat.dims()[1];
  for (size_t i = 0; i < batch_size; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
      vec->data<T>()[index] += tmat.data<T>()[i * width + j];
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
                                  framework::Tensor* sum, T scale_sum) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat.dims()[0];
  size_t o_width = tmat.dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
    T sm = static_cast<T>(0.0);
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
        // calc_bit starts from the rightmost bit, while the data in tmat[i]
        // is stored in the reverse order.
        sm += tmat.data<T>()[i * o_width + j];
      }
    }
    sum->data<T>()[i] = scale_sum * sm;
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
                                  const framework::Tensor& weight,
                                  const framework::Tensor& input) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat->dims()[0];
  size_t tmat_width = tmat->dims()[1];
  size_t input_width = input.dims()[1];
  size_t weight_width = weight.dims()[1];
  auto tmat_value = tmat->data<T>();
  auto weight_value = weight.data<T>();
  auto input_value = input.data<T>();
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);
      T sum = static_cast<T>(0.0);
      for (size_t k = 0; k < input_width; ++k) {
        sum += weight_value[weight_width * index + k] *
               input_value[input_width * i + k];
      }
      tmat_value[i * tmat_width + j] += sum;
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
                                            framework::Tensor* weight,
                                            const framework::Tensor& input) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat.dims()[0];
  size_t input_width = input.dims()[1];
  size_t tmat_width = tmat.dims()[1];
  size_t weight_width = weight->dims()[1];
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight->data<T>();
  auto input_value = input.data<T>();
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);

      for (size_t k = 0; k < input_width; ++k) {
        weight_value[weight_width * index + k] +=
            tmat_value[i * tmat_width + j] * input_value[input_width * i + k];
      }
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
                                           const framework::Tensor& weight,
                                           framework::Tensor* input) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat.dims()[0];
  size_t tmat_width = tmat.dims()[1];
  size_t input_width = input->dims()[1];
  size_t weight_width = weight.dims()[1];
  auto tmat_value = tmat.data<T>();
  auto weight_value = weight.data<T>();
  auto input_value = input->data<T>();

  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      size_t index = code.calc_index(j);

      for (size_t k = 0; k < input_width; ++k) {
        input_value[input_width * i + k] +=
            tmat_value[i * tmat_width + j] *
            weight_value[weight_width * index + k];
      }
    }
  }
}

template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
  SimpleCodeTable code_table(num_classes_);
  size_t num_samples = tmat->dims()[0];
  size_t o_width = tmat->dims()[1];
  for (size_t i = 0; i < num_samples; ++i) {
    auto code = code_table(static_cast<size_t>(ids_[i]));
    int code_length = code.get_length();
    for (int j = 0; j < code_length; ++j) {
      if (code.calc_bit(j)) {
        tmat->data<T>()[i * o_width + j] -= 1;
      }
    }
  }
}

template class MatrixBitCodeFunctor<float>;
template class MatrixBitCodeFunctor<double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,143 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {
namespace math {
/**
 * SimpleCodeTable class should support 3 functions:
 *
 * size_t size()
 *   return the number of ids
 *
 * int get_max_code_length()
 *   return the maximal code length
 *
 * SimpleCode operator()(size_t i)
 *   return the i-th code. The Code class is described below.
 *
 * SimpleCode class should support 3 functions:
 *
 * int get_length()
 *   return the length of the code
 *
 * size_t calc_index(int bit)
 *   bit ranges from 0 to get_length() - 1
 *   return the index of the (1+bit)-level parent
 *
 * bool calc_bit(int bit)
 *   return true if the bit-level parent is the right child of the
 *   (1+bit)-level parent
 *
 */

/**
 * return the 1-based index of the highest bit set
 *
 * for x > 0:
 * \f[
 *   FindLastSet(x) = 1 + \lfloor \log_{2}x \rfloor
 * \f]
 */
inline constexpr size_t FindLastSet(size_t x) {
  return std::is_same<size_t, unsigned int>::value
             ? (x ? 8 * sizeof(x) - __builtin_clz(x) : 0)
             : (std::is_same<size_t, unsigned long>::value  // NOLINT
                    ? (x ? 8 * sizeof(x) - __builtin_clzl(x) : 0)
                    : (x ? 8 * sizeof(x) - __builtin_clzll(x) : 0));
}
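// A few compile-time checks (illustrative only) of the values FindLastSet
// returns, matching 1 + floor(log2(x)) for x > 0 and 0 for x == 0:
static_assert(FindLastSet(0) == 0, "no bit is set in 0");
static_assert(FindLastSet(1) == 1, "0b1    -> highest set bit is bit 1");
static_assert(FindLastSet(6) == 3, "0b110  -> highest set bit is bit 3");
static_assert(FindLastSet(8) == 4, "0b1000 -> highest set bit is bit 4");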

struct SimpleCode {
  SimpleCode(size_t code, size_t num_classes) : c_(code + num_classes) {}
  /**
   * Here the id of the root should be 1 rather than 0, so the encoding of
   * class c is `c + num_classes`, and all siblings can get the same weight
   * index using prefixes.
   * The weight index is a prefix of the encoding, so calc_index leaves out
   * the rightmost bit.
   * The binary classification path is a suffix of the encoding, so calc_bit
   * leaves out the leftmost bit.
   */
  inline size_t calc_index(int bit) const { return (c_ >> (bit + 1)) - 1; }
  inline bool calc_bit(int bit) const { return c_ & (1 << bit); }
  inline int get_length() const { return FindLastSet(c_) - 1; }

 private:
  size_t c_;
};
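// A worked example (illustrative only): with num_classes = 6 and class c = 5,
// c_ = 11 = 0b1011, so get_length() = FindLastSet(11) - 1 = 3, and increasing
// j walks from the leaf's parent up to the root:
//   j = 0: calc_index(0) = (11 >> 1) - 1 = 4, calc_bit(0) = true
//   j = 1: calc_index(1) = (11 >> 2) - 1 = 1, calc_bit(1) = true
//   j = 2: calc_index(2) = (11 >> 3) - 1 = 0, calc_bit(2) = false
// i.e. the path for this sample uses weight rows 4, 1 and 0.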

struct SimpleCodeTable {
  explicit SimpleCodeTable(size_t num_classes) : num_classes_(num_classes) {}
  SimpleCode operator()(size_t code) const {
    return SimpleCode(code, num_classes_);
  }
  size_t size() const { return num_classes_; }
  int get_max_code_length() const { return FindLastSet(num_classes_ - 1); }

 private:
  size_t num_classes_;
};

template <typename T>
class MatrixBitCodeFunctor {
 public:
  explicit MatrixBitCodeFunctor(size_t num_classes, const int64_t* ids)
      : num_classes_(num_classes), ids_(ids) {}
  /* For j < code_length
       tmat(i, j) += vec(0, index(i, j))
  */
  void Add(framework::Tensor* tmat, const framework::Tensor& vec);

  /* For j < code_length
       vec(0, index(i, j)) += tmat(i, j)
  */
  void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);

  /* For j < code_length
       sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
  */
  void Sum(const framework::Tensor& tmat, framework::Tensor* sum, T scale_sum);

  /* For j < code_length
       tmat(i, j) -= bit(i, j)
  */
  void Sub(framework::Tensor* tmat);
  /* For j < code_length
       tmat(i, j) += weight.row(index(i, j)) * input.row(i)
  */
  void Mul(framework::Tensor* tmat, const framework::Tensor& weight,
           const framework::Tensor& input);

  /* For index(i, j) >= 0:
       weight.row(index(i, j)) += tmat(i, j) * input.row(i)
  */
  void MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight,
                     const framework::Tensor& input);
  /* For j < code_length
       input.row(i) += tmat(i, j) * weight.row(index(i, j))
  */
  void MulGradError(const framework::Tensor& tmat,
                    const framework::Tensor& weight, framework::Tensor* input);

  size_t num_classes_;
  const int64_t* ids_;
};
}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,202 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class SqueezeOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SqueezeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SqueezeOp should not be null.");

    const auto &x_dims = ctx->GetInputDim("X");
    // Check the rank of the input tensor (Eigen limits it to at most 6).
    PADDLE_ENFORCE(x_dims.size() <= 6,
                   "Invalid dimensions, the rank of Input(X) "
                   "should be in the range of [1, 6] (Eigen limit).");

    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
    for (int a : axes) {
      PADDLE_ENFORCE_LT(a, x_dims.size(),
                        "The squeeze axis should be less than the input "
                        "tensor's rank.");
    }

    auto out_dims = GetOutputShape(axes, x_dims);
    ctx->SetOutputDim("Out", out_dims);
    if (x_dims[0] == out_dims[0]) {
      // Only pass LoD when the first dimension of output and Input(X)
      // are the same.
      ctx->ShareLoD("X", "Out");
    }
  }

  static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
                                        const framework::DDim &in_dims) {
    size_t num_squeeze_dims = squeeze_dims.size();
    int cnt_squeezed_dims = 0;
    bool should_squeeze[9] = {false};

    // Determine the number of dimensions of the output tensor after squeeze:
    // mark and count the dimensions that need to be squeezed.
    if (num_squeeze_dims == 0) {
      for (int idx = 0; idx < in_dims.size(); ++idx) {
        if (in_dims[idx] == 1) {
          should_squeeze[idx] = true;
          ++cnt_squeezed_dims;
        }
      }
    } else {
      for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
        int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
                                            : squeeze_dims[idx];
        // Check the current index; the upper bound has already been checked
        // above.
        PADDLE_ENFORCE(current >= 0,
                       "Invalid axis, the negative axis is out of range.");
        PADDLE_ENFORCE(in_dims[current] == 1,
                       "Invalid axis index, the dimension to be squeezed "
                       "should be equal to 1.");

        if (!(should_squeeze[current])) {
          ++cnt_squeezed_dims;
        }
        should_squeeze[current] = true;
      }
    }
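    // At this point, for the DOC example below (in_dims = (1, 3, 1, 5),
    // axes = {0}), should_squeeze is {true, false, false, false} and
    // cnt_squeezed_dims is 1, so the output shape becomes (3, 1, 5).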

    // Make output dimensions
    std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
    for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
      if (!should_squeeze[in_idx]) {
        output_shape[out_idx++] = in_dims[in_idx];
      }
    }

    return framework::make_ddim(output_shape);
  }
};

class SqueezeOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto &axes = Attr<std::vector<int>>("axes");
    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
    auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);

    framework::AttributeMap attrs;
    attrs["shape"] = framework::vectorize2int(out_dims);
    attrs["inplace"] = Attr<bool>("inplace");
    // Invoke the Reshape op.
    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {Input("X")}}, {"Shape", {}}},
        {{"Out", {Output("Out")}}}, attrs);
    reshape_op->Run(scope, place);
  }
};

class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor). The input tensor of the squeeze operator.");
    AddOutput("Out", "(Tensor). The output tensor of the squeeze operator.");
    AddAttr<std::vector<int>>("axes",
                              "(std::vector<int>). List of integers "
                              "indicating the dimensions to squeeze.")
        .SetDefault({});
    AddAttr<bool>("inplace",
                  "(default: false) Squeeze the source tensor's shape without "
                  "memory copy. When Attr(inplace) is set true, the output "
                  "tensor shares memory with Input(X); otherwise, a new "
                  "output tensor is created, and its data are copied from "
                  "Input(X).")
        .SetDefault(false);
    AddComment(R"DOC(
Squeeze Operator.

Remove single-dimensional entries from the shape of a tensor.
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with shape entry not equal to one, an error is raised.

Examples:
Case 1:
  Given
    X.shape = (1, 3, 1, 5)
  and
    axes = [0]
  we get:
    Out.shape = (3, 1, 5)

Case 2:
  Given
    X.shape = (1, 3, 1, 5)
  and
    axes = []
  we get:
    Out.shape = (3, 5)
)DOC");
  }
};

class SqueezeGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    context->SetOutputDim(framework::GradVarName("X"),
                          context->GetInputDim("X"));
    context->ShareLoD("X", framework::GradVarName("X"));
  }
};

class SqueezeGradOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto dx_name = Output(framework::GradVarName("X"));
    auto dout_name = Input(framework::GradVarName("Out"));
    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
    framework::AttributeMap attrs;
    attrs["shape"] = framework::vectorize2int(x_dims);
    attrs["inplace"] = Attr<bool>("inplace");

    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
        attrs);
    reshape_op->Run(scope, place);
  }
};

}  // namespace operators
}  // namespace paddle

// Tell the linker to use the reshape op.
USE_OP(reshape);

namespace ops = paddle::operators;
REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
                  ops::SqueezeOpInferShape,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);
@@ -0,0 +1,191 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class UnsqueezeOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of UnsqueezeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of UnsqueezeOp should not be null.");

    const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
    const auto &x_dims = ctx->GetInputDim("X");
    // Validity check: input tensor rank (<= 6).
    PADDLE_ENFORCE(x_dims.size() <= 6,
                   "Invalid dimensions, the rank of Input(X) "
                   "should be in the range of [1, 6] (Eigen limit).");
    auto out_dims = GetOutputShape(axes, x_dims);
    ctx->SetOutputDim("Out", out_dims);
    if (x_dims[0] == out_dims[0]) {
      // Only pass LoD when the first dimension of output and Input(X)
      // are the same.
      ctx->ShareLoD("X", "Out");
    }
  }

  static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
                                        const framework::DDim &in_dims) {
    int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
    int cur_output_size = in_dims.size();
    std::vector<int64_t> output_shape(output_size, 0);

    // Validity check: rank range.
    PADDLE_ENFORCE(output_size <= 6,
                   "The output tensor's rank should be less than or equal "
                   "to 6.");

    for (int axis : unsqz_dims) {
      int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
      // Validity check: the axis bound.
      PADDLE_ENFORCE(
          cur >= 0 && cur <= cur_output_size,
          "The unsqueeze dims must be within the range of the current rank.");
      // Move the old axes and insert the new axis.
      for (int i = cur_output_size; i >= cur; --i) {
        if (output_shape[i] == 1) {
          // Move axis
          output_shape[i + 1] = 1;
          output_shape[i] = 0;
        }
      }
      output_shape[cur] = 1;
      // Add the output size.
      cur_output_size++;
    }

    // Make the output shape.
    for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
      if (output_shape[out_idx] == 0) {
        output_shape[out_idx] = in_dims[in_idx++];
      }
    }

    return framework::make_ddim(output_shape);
  }
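  // For the DOC example below (in_dims = [3, 4, 5], axes = [0, 4]), the loop
  // above first marks output_shape as [1, 0, 0, 0, 1] and the final pass
  // fills the zeros from in_dims, giving [1, 3, 4, 5, 1].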
};

class UnsqueezeOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto &axes = Attr<std::vector<int>>("axes");
    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
    auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);

    framework::AttributeMap attrs;
    attrs["shape"] = framework::vectorize2int(out_dims);
    attrs["inplace"] = Attr<bool>("inplace");
    // Invoke the Reshape op.
    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {Input("X")}}, {"Shape", {}}},
        {{"Out", {Output("Out")}}}, attrs);
    reshape_op->Run(scope, place);
  }
};

class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor). The input tensor of the unsqueeze operator.");
    AddOutput("Out", "(Tensor). The output tensor of the unsqueeze operator.");
    AddAttr<std::vector<int>>("axes",
                              "(std::vector<int>). List of integers "
                              "indicating the dimensions to be inserted.")
        .AddCustomChecker([](const std::vector<int> &axes) {
          PADDLE_ENFORCE(!axes.empty(),
                         "Invalid axes, the unsqueeze axes are empty.");
          // Validity check: number of axes (< 6).
          PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
                         "Invalid dimensions, dynamic dimensions should be "
                         "within [1, 6] dimensions (Eigen limit).");
          // Validity check: the range of the unsqueeze axes.
          for (int axis : axes) {
            PADDLE_ENFORCE(axis < 6,
                           "Invalid dimensions, input axis should be "
                           "within [1, 6] dimensions (Eigen limit).");
          }
        });
    AddAttr<bool>(
        "inplace",
        "(default: false) Unsqueeze the source tensor's shape without "
        "memory copy. When Attr(inplace) is set true, the output "
        "tensor shares memory with Input(X); otherwise, a new output "
        "tensor is created, and its data are copied from Input(X).")
        .SetDefault(false);
    AddComment(R"DOC(
Unsqueeze Operator.

Insert single-dimensional entries into the shape of a tensor.
Takes one required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.

For example:
  Given a tensor with shape [3, 4, 5],
  then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1].
)DOC");
  }
};

class UnsqueezeGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    ctx->ShareLoD("X", framework::GradVarName("X"));
  }
};

class UnsqueezeGradOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto dx_name = Output(framework::GradVarName("X"));
    auto dout_name = Input(framework::GradVarName("Out"));
    auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();

    framework::AttributeMap attrs;
    attrs["shape"] = framework::vectorize2int(x_dims);
    attrs["inplace"] = Attr<bool>("inplace");

    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
        attrs);
    reshape_op->Run(scope, place);
  }
};

}  // namespace operators
}  // namespace paddle

// Tell the linker to use the reshape op.
USE_OP(reshape);

namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
                  ops::UnsqueezeOpInferShape,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
                  ops::UnsqueezeGradInferShape);