Add new norm api, support frobenius norm and p-order vector norm. (#23716)
* Add new norm api, support frobenius norm and p-order vector norm. test=develop * combine test files, add more attr checks. test=develop
parent 15ce8e21d8
commit 03e737aca7
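The user-facing entry point added by this change is paddle.norm, which the tests below exercise with p='fro' for the new frobenius_norm operator and a numeric p for the new p_norm operator. A minimal static-graph usage sketch in the same style as those tests (the shape and tensor names here are illustrative, not taken from the patch):

import numpy as np
import paddle
import paddle.fluid as fluid

with fluid.program_guard(fluid.Program()):
    data = fluid.data(name="X", shape=[3, 4], dtype="float32")
    # p=2 vector norm along axis 1; p='fro' with axis=[-2, -1] would give
    # the Frobenius norm instead
    out = paddle.norm(input=data, p=2, axis=1)
    exe = fluid.Executor(fluid.CPUPlace())
    result, = exe.run(feed={"X": np.random.rand(3, 4).astype("float32")},
                      fetch_list=[out])
    print(result.shape)  # (3,)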
@ -0,0 +1,139 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/p_norm_op.h"
#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

class PnormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) A tensor of rank >= axis.");
    AddAttr<float>("porder",
                   "The porder is the p order vector norm to calculate.")
        .SetDefault(2.0f);
    AddAttr<int>("axis",
                 "The axis on which to apply normalization. If axis < 0, "
                 "the dimension to pnorm is rank(X) + axis. -1 is "
                 "the last dimension.")
        .SetDefault(-1);
    AddAttr<float>("epsilon",
                   "(float, default 1e-12) The epsilon value is used "
                   "to avoid division by zero.")
        .SetDefault(1.0e-12f);
    AddAttr<bool>(
        "keepdim",
        "(bool, default false) Whether to keep the dimensions as the input")
        .SetDefault(false);
    AddOutput(
        "Out",
        "(Tensor) Output tensor for the `(sum(x.pow(p)) + epsilon).pow(1/p)`");
    AddComment(R"DOC(

Given a tensor, apply p-normalization along the provided axis.

$$
pnorm = \left( \sum_i {|x_i|^p} \right)^{1/p}
$$

where $\sum_i{|x_i|^p}$ is calculated along the `axis` dimension.

)DOC");
  }
};

class PnormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "p_norm");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "p_norm");
    auto porder = ctx->Attrs().Get<float>("porder");
    PADDLE_ENFORCE_NE(porder, 0,
                      platform::errors::InvalidArgument(
                          "The porder of p_norm is not supported for "
                          "porder == 0, INFINITY or -INFINITY now."));
    PADDLE_ENFORCE_NE(porder, INFINITY,
                      platform::errors::InvalidArgument(
                          "The porder of p_norm is not supported for "
                          "porder == 0, INFINITY or -INFINITY now."));
    PADDLE_ENFORCE_NE(porder, -INFINITY,
                      platform::errors::InvalidArgument(
                          "The porder of p_norm is not supported for "
                          "porder == 0, INFINITY or -INFINITY now."));
    auto xdim = ctx->GetInputDim("X");
    int axis = ctx->Attrs().Get<int>("axis");
    bool keepdim = ctx->Attrs().Get<bool>("keepdim");
    if (axis < 0) axis = xdim.size() + axis;
    std::vector<int> reduce_dims;
    for (int i = 0; i < xdim.size(); ++i) {
      if (i != axis) reduce_dims.emplace_back(xdim[i]);
    }
    xdim[axis] = 1;
    if (keepdim) {
      ctx->SetOutputDim("Out", xdim);
    } else {
      ctx->SetOutputDim("Out", framework::make_ddim(reduce_dims));
    }
  }
};

class PnormOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "p_norm");
    OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "p_norm");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "p_norm");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                   "X@GRAD", "p_norm");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};

template <typename T>
class PnormOpGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("p_norm_grad");
    op->SetAttrMap(this->Attrs());
    op->SetInput("X", this->Input("X"));
    op->SetInput("Out", this->Output("Out"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(p_norm, ops::PnormOp, ops::PnormOpMaker,
                  ops::PnormOpGradOpMaker<paddle::framework::OpDesc>,
                  ops::PnormOpGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(p_norm_grad, ops::PnormOpGrad);
REGISTER_OP_CPU_KERNEL(p_norm, ops::PnormKernel<CPU, float>,
                       ops::PnormKernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(p_norm_grad, ops::PnormGradKernel<CPU, float>,
                       ops::PnormGradKernel<CPU, double>);
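The operator definition above reduces `|x|^p` along one `axis`, then takes the `1/p` power; `keepdim` only decides whether the reduced axis survives as size 1. A minimal NumPy sketch of that forward rule and of the shapes InferShape produces (illustration only, not the registered kernel; the helper name is made up here):

import numpy as np

def p_norm_ref(x, axis=-1, porder=2.0, keepdim=False):
    # sum(|x|^p) along `axis`, then the 1/p power, as in the DOC formula
    s = np.sum(np.abs(x) ** porder, axis=axis, keepdims=keepdim)
    return s ** (1.0 / porder)

x = np.random.rand(2, 3, 4)
print(p_norm_ref(x, axis=1).shape)                # (2, 4): axis squeezed out
print(p_norm_ref(x, axis=1, keepdim=True).shape)  # (2, 1, 4): axis kept as 1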
@ -0,0 +1,180 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/p_norm_op.h"

namespace paddle {
namespace operators {

template <typename T>
__device__ __forceinline__ int sgn(T val) {
  return (T(0) < val) - (val < T(0));
}

__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
__device__ __forceinline__ double inline_abs(double x) { return abs(x); }

__device__ __forceinline__ int inline_sign(float x) { return sgn<float>(x); }
__device__ __forceinline__ int inline_sign(double x) { return sgn<double>(x); }

__device__ __forceinline__ float inline_pow(float base, float exponent) {
  return pow(base, exponent);
}
__device__ __forceinline__ double inline_pow(double base, double exponent) {
  return pow(base, exponent);
}

template <typename T, int BlockDim>
__global__ void Pnorm(const T* x, const int pre,
                      const int axis_n,  // dim in axis
                      const int post, float porder, T* out_norm) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int num = pre * post;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    int base = (i / post) * post * axis_n + (i % post);

    T sum = 0.0;
    __shared__ T norm;
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const T x_ij = x[base + j * post];
      sum += inline_pow(inline_abs(x_ij), porder);
    }
    T reduce_result = BlockReduce(temp_storage).Sum(sum);

    if (threadIdx.x == 0) {
      norm = inline_pow(reduce_result, 1.0f / porder);
      out_norm[i] = norm;
    }
    __syncthreads();
  }
}

template <typename DeviceContext, typename T>
class PnormCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* out_norm = ctx.Output<framework::Tensor>("Out");
    const T* x = in_x->data<T>();
    T* norm = out_norm->mutable_data<T>(ctx.GetPlace());

    auto xdim = in_x->dims();
    auto ndim = out_norm->dims();
    float porder = ctx.Attr<float>("porder");
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis = xdim.size() + axis;
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post);

    auto& dev_ctx = ctx.cuda_device_context();

    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    int grid = std::min(max_blocks, pre * post);
    Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
                                                          porder, norm);
  }
};

template <typename T, int BlockDim>
__global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
                              const float porder, const int pre,
                              const int axis_n, const int post, const T eps,
                              T* x_grad) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage_sum;
  // dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
  int num = pre * post;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    T sum = 0.0;
    __shared__ T row_sum;
    __shared__ T row_sqrt_norm;
    __shared__ T row_norm;

    auto base = (i / post) * post * axis_n + (i % post);

    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      int index = base + j * post;
      sum += x[index] * y_grad[index];
    }
    T reduce_result = BlockReduce(temp_storage_sum).Sum(sum);

    if (threadIdx.x == 0) {
      row_sum = reduce_result;
      row_sqrt_norm = x_norm[i];
      row_norm = row_sqrt_norm * row_sqrt_norm;
    }
    __syncthreads();

    const T pnorm_i = x_norm[i];
    const T yout_i = y_grad[i];

    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      int index = base + j * post;
      const T x_ij = inline_abs(x[index]);
      const T dy_ij = y_grad[index];
      x_grad[index] = inline_pow(x_ij, porder - 1.0f) /
                      (inline_pow(pnorm_i, porder - 1.0f) + eps) * yout_i *
                      inline_sign(x[index]);
    }
  }
}

template <typename DeviceContext, typename T, typename AttrType = T>
class PnormGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* in_norm = ctx.Input<framework::Tensor>("Out");
    auto* in_norm_dy =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    T* dx = out_dx->mutable_data<T>(ctx.GetPlace());
    const T* x = in_x->data<T>();
    const T* x_norm = in_norm->data<T>();
    const T* norm_dy = in_norm_dy->data<T>();

    auto xdim = in_x->dims();
    float porder = ctx.Attr<float>("porder");
    T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis = xdim.size() + axis;
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post);

    auto& dev_ctx = ctx.cuda_device_context();

    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    int grid = std::min(max_blocks, pre * post);
    PnormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
        x, x_norm, norm_dy, porder, pre, n, post, eps, dx);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;

REGISTER_OP_CUDA_KERNEL(p_norm, ops::PnormCUDAKernel<CUDA, float>,
                        ops::PnormCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(p_norm_grad, ops::PnormGradCUDAKernel<CUDA, float>,
                        ops::PnormGradCUDAKernel<CUDA, double>);
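The PnormGradient kernel above applies, element by element, the rule noted in its comment: dx = |x|^(p-1) / (norm^(p-1) + eps) * dy * sign(x), with the reduced-axis norm and upstream gradient broadcast back over that axis. A hedged NumPy sketch of the same rule for a single reduction axis (the helper name and shapes are illustrative only):

import numpy as np

def p_norm_grad_ref(x, norm, dy, axis, porder=2.0, eps=1e-12):
    # broadcast the reduced norm and upstream grad back over `axis`
    norm_b = np.expand_dims(norm, axis)
    dy_b = np.expand_dims(dy, axis)
    return (np.abs(x) ** (porder - 1.0) /
            (norm_b ** (porder - 1.0) + eps)) * dy_b * np.sign(x)

x = np.random.randn(3, 5)
norm = np.sum(np.abs(x) ** 2.0, axis=1) ** 0.5   # forward result for axis=1
dx = p_norm_grad_ref(x, norm, np.ones_like(norm), axis=1)
print(dx.shape)  # (3, 5), same shape as x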
@ -0,0 +1,112 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
                    int* post) {
  *pre = 1;
  *post = 1;
  *n = dim[axis];
  for (int i = 0; i < axis; ++i) {
    (*pre) *= dim[i];
  }
  for (int i = axis + 1; i < dim.size(); ++i) {
    (*post) *= dim[i];
  }
}

template <typename DeviceContext, typename T>
class PnormKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* out_norm = ctx.Output<framework::Tensor>("Out");
    out_norm->mutable_data<T>(ctx.GetPlace());

    auto xdim = in_x->dims();
    float porder = ctx.Attr<float>("porder");
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis = xdim.size() + axis;
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post);

    auto* place = ctx.template device_context<DeviceContext>().eigen_device();

    Eigen::DSizes<int, 3> shape(pre, n, post);
    Eigen::DSizes<int, 2> norm_shape(pre, post);

    auto x_e = framework::EigenVector<T>::Flatten(*in_x);
    auto norm_e = framework::EigenVector<T>::Flatten(*out_norm);

    auto x = x_e.reshape(shape);
    auto norm = norm_e.reshape(norm_shape);

    Eigen::DSizes<int, 1> rdim(1);
    auto xp = (x.abs()).pow(porder);
    auto sum = xp.sum(rdim);
    norm.device(*place) = sum.pow(1.0f / porder);
  }
};

template <typename DeviceContext, typename T, typename AttrType = T>
class PnormGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* in_norm = ctx.Input<framework::Tensor>("Out");
    auto* in_norm_dy =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    out_dx->mutable_data<T>(ctx.GetPlace());

    T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
    auto xdim = in_x->dims();
    float porder = ctx.Attr<float>("porder");

    int axis = ctx.Attr<int>("axis");
    if (axis < 0) axis = xdim.size() + axis;
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post);
    Eigen::DSizes<int, 3> shape(pre, n, post);
    Eigen::DSizes<int, 3> rshape(pre, 1, post);

    auto* place = ctx.template device_context<DeviceContext>().eigen_device();

    auto x_e = framework::EigenVector<T>::Flatten(*in_x);
    auto dx_e = framework::EigenVector<T>::Flatten(*out_dx);
    auto norm_e = framework::EigenVector<T>::Flatten(*in_norm);
    auto norm_dy_e = framework::EigenVector<T>::Flatten(*in_norm_dy);

    auto x = x_e.reshape(shape);
    auto dx = dx_e.reshape(shape);
    auto norm = norm_e.reshape(rshape);
    auto norm_dy = norm_dy_e.reshape(rshape);

    Eigen::DSizes<int, 1> rdim(1);
    Eigen::DSizes<int, 3> bcast(1, n, 1);

    dx.device(*place) = (x.abs()).pow(porder - 1.0f);
    dx.device(*place) =
        dx / ((norm.broadcast(bcast)).pow(porder - 1.0f) + x.constant(eps));
    dx.device(*place) = dx * norm_dy.broadcast(bcast) * x.sign();
  }
};
}  // namespace operators
}  // namespace paddle
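GetDims above factors the input shape around the reduction axis into a (pre, n, post) triple, so the Eigen kernels in this header and the CUDA kernels can view any tensor as pre x n x post and reduce the middle dimension. A small NumPy illustration of that decomposition (for intuition only; not part of the operator):

import numpy as np

def get_dims(shape, axis):
    # pre = product of dims before axis, n = dims[axis], post = product after
    pre = int(np.prod(shape[:axis]))
    n = shape[axis]
    post = int(np.prod(shape[axis + 1:]))
    return pre, n, post

x = np.random.rand(2, 3, 4, 5)
pre, n, post = get_dims(x.shape, axis=2)          # -> (6, 4, 5)
view = x.reshape(pre, n, post)
norm = (np.abs(view) ** 2.0).sum(axis=1) ** 0.5   # reduce the middle dim
print(pre, n, post, norm.shape)                   # 6 4 5 (6, 5)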
@ -0,0 +1,65 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/reduce_ops/frobenius_norm_op.h"
#include <memory>
#include <string>

namespace paddle {
namespace operators {

template <typename T>
class FrobeniusNormOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("frobenius_norm_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Out", this->Output("Out"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

class FrobeniusNormOpMaker : public ops::ReduceOpMaker {
 protected:
  virtual std::string GetName() const { return "frobenius_norm"; }
  virtual std::string GetOpType() const { return "Reduce frobenius_norm"; }
};

REGISTER_OPERATOR(frobenius_norm, ops::ReduceOp, FrobeniusNormOpMaker,
                  ops::FrobeniusNormOpGradMaker<paddle::framework::OpDesc>,
                  ops::FrobeniusNormOpGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(frobenius_norm_grad, ops::ReduceGradOp);

REGISTER_OP_CPU_KERNEL(frobenius_norm,
                       ops::ReduceKernel<paddle::platform::CPUDeviceContext,
                                         float, ops::FrobeniusNormFunctor>,
                       ops::ReduceKernel<paddle::platform::CPUDeviceContext,
                                         double, ops::FrobeniusNormFunctor>);

template <typename T>
using CPUFrobeniusNormGradKernel =
    ops::FrobeniusNormGradKernel<paddle::platform::CPUDeviceContext, T,
                                 ops::FrobeniusNormGradFunctor>;

REGISTER_OP_CPU_KERNEL(frobenius_norm_grad, CPUFrobeniusNormGradKernel<float>,
                       CPUFrobeniusNormGradKernel<double>);
@ -0,0 +1,32 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/reduce_ops/frobenius_norm_op.h"

template <typename T>
using CUDAFrobeniusNormKernel =
    ops::ReduceKernel<paddle::platform::CUDADeviceContext, T,
                      ops::FrobeniusNormFunctor>;

REGISTER_OP_CUDA_KERNEL(frobenius_norm, CUDAFrobeniusNormKernel<float>,
                        CUDAFrobeniusNormKernel<double>);

template <typename T>
using CUDAFrobeniusNormGradKernel =
    ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
                          ops::FrobeniusNormGradFunctor>;

REGISTER_OP_CUDA_KERNEL(frobenius_norm_grad, CUDAFrobeniusNormGradKernel<float>,
                        CUDAFrobeniusNormGradKernel<double>);
@ -0,0 +1,54 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "paddle/fluid/operators/reduce_ops/reduce_op.h"

namespace paddle {
namespace operators {

// \partial \| X \|_F = \frac{X}{ \| X \|_F }
template <typename DeviceContext, typename T, typename Functor>
class FrobeniusNormGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // default use Eigen broadcast
    ReduceGradKernel<DeviceContext, T, Functor, false> kernel;
    kernel.Compute(context);
  }
};

struct FrobeniusNormFunctor {
  template <typename DeviceContext, typename X, typename Y, typename Dim>
  void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
    y->device(place) = ((x->square()).sum(dim)).sqrt();
  }
};

struct FrobeniusNormGradFunctor {
  template <typename DeviceContext, typename X, typename Y, typename DX,
            typename DY, typename Dim>
  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
                  const Dim& dim, int size) {
    dx->device(place) = y->broadcast(dim);
    dx->device(place) = *dx + dx->constant(1e-12f);
    dx->device(place) = (*x / *dx) * (dy->broadcast(dim));
  }
};

}  // namespace operators
}  // namespace paddle
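The two functors above spell out the Frobenius norm as sqrt(sum(x^2)) over the reduced dims and its gradient as x / (||X||_F + eps) * dy, matching the \partial ||X||_F = X / ||X||_F comment. A NumPy sketch of both (illustrative only; the real kernels run through the generic reduce-op machinery):

import numpy as np

def frobenius_norm_ref(x, dims):
    # sqrt of the sum of squares over `dims`
    return np.sqrt(np.sum(np.square(x), axis=dims))

def frobenius_norm_grad_ref(x, norm, dy, dims, eps=1e-12):
    # dX = X / (||X||_F + eps) * dY, with norm and dY broadcast back
    # (np.expand_dims with a tuple axis needs NumPy >= 1.18)
    norm_b = np.expand_dims(norm, dims)
    dy_b = np.expand_dims(dy, dims)
    return x / (norm_b + eps) * dy_b

x = np.random.rand(2, 3, 4)
norm = frobenius_norm_ref(x, dims=(1, 2))         # shape (2,)
dx = frobenius_norm_grad_ref(x, norm, np.ones_like(norm), dims=(1, 2))
print(norm.shape, dx.shape)                       # (2,) (2, 3, 4)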
@ -0,0 +1,210 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid


def p_norm(x, axis, porder, keepdims=False):
    if axis is None: axis = -1
    xp = np.power(np.abs(x), porder)
    s = np.sum(xp, axis=axis, keepdims=keepdims)
    r = np.power(s, 1.0 / porder)
    return r


def frobenius_norm(x, axis=None, keepdims=False):
    if isinstance(axis, list): axis = tuple(axis)
    if axis is None: axis = (-2, -1)
    r = np.linalg.norm(x, ord='fro', axis=axis, keepdims=keepdims)
    return r


class TestFrobeniusNormOp(OpTest):
    def setUp(self):
        self.op_type = "frobenius_norm"
        self.init_test_case()
        x = (np.random.random(self.shape) + 1.0).astype(self.dtype)
        norm = frobenius_norm(x, self.axis, self.keepdim)
        self.reduce_all = (len(self.axis) == len(self.shape))
        self.inputs = {'X': x}
        self.attrs = {
            'dim': list(self.axis),
            'keep_dim': self.keepdim,
            'reduce_all': self.reduce_all
        }
        self.outputs = {'Out': norm}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

    def init_test_case(self):
        self.shape = [2, 3, 4, 5]
        self.axis = (1, 2)
        self.keepdim = False
        self.dtype = "float64"


class TestFrobeniusNormOp2(TestFrobeniusNormOp):
    def init_test_case(self):
        self.shape = [5, 5, 5]
        self.axis = (0, 1)
        self.keepdim = True
        self.dtype = "float32"

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestPnormOp(OpTest):
    def setUp(self):
        self.op_type = "p_norm"
        self.init_test_case()
        x = (np.random.random(self.shape) + 0.5).astype(self.dtype)
        norm = p_norm(x, self.axis, self.porder, self.keepdim)
        self.inputs = {'X': x}
        self.attrs = {
            'epsilon': self.epsilon,
            'axis': self.axis,
            'keepdim': self.keepdim,
            'porder': float(self.porder)
        }
        self.outputs = {'Out': norm}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

    def init_test_case(self):
        self.shape = [2, 3, 4, 5]
        self.axis = 1
        self.epsilon = 1e-12
        self.porder = 2.0
        self.keepdim = False
        self.dtype = "float64"


class TestPnormOp2(TestPnormOp):
    def init_test_case(self):
        self.shape = [3, 20, 3]
        self.axis = 2
        self.epsilon = 1e-12
        self.porder = 2.0
        self.keepdim = True
        self.dtype = "float32"

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


def run_out(self, p, axis, shape_x, shape_y, dtype):
    with fluid.program_guard(fluid.Program()):
        data1 = fluid.data(name="X", shape=shape_x, dtype=dtype)
        data2 = fluid.data(name="Y", shape=shape_y, dtype=dtype)
        out = paddle.norm(input=data1, p=p, axis=axis, out=data2)
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        result = exe.run(feed={"X": np.random.rand(*shape_x).astype(dtype)},
                         fetch_list=[data2, out])
        self.assertEqual((result[0] == result[1]).all(), True)


def run_fro(self, p, axis, shape_x, dtype):
    with fluid.program_guard(fluid.Program()):
        data = fluid.data(name="X", shape=shape_x, dtype=dtype)
        out = paddle.norm(input=data, p=p, axis=axis)
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
        expected_result = frobenius_norm(np_input, axis=axis)
        result, = exe.run(feed={"X": np_input}, fetch_list=[out])
    self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)


def run_pnorm(self, p, axis, shape_x, dtype):
    with fluid.program_guard(fluid.Program()):
        data = fluid.data(name="X", shape=shape_x, dtype=dtype)
        out = paddle.norm(input=data, p=p, axis=axis)
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
        expected_result = p_norm(np_input, porder=p, axis=axis).astype(dtype)
        result, = exe.run(feed={"X": np_input}, fetch_list=[out])
    self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)


class API_NormTest(unittest.TestCase):
    def test_output_result(self):
        run_out(self, p=2, axis=1, shape_x=[3, 4], shape_y=[3], dtype="float32")
        run_out(
            self,
            p='fro',
            axis=None,
            shape_x=[3, 4],
            shape_y=[1],
            dtype="float32")

    def test_basic(self):
        run_fro(self, p='fro', axis=None, shape_x=[3, 3, 4], dtype="float32")
        run_fro(self, p='fro', axis=[0, 1], shape_x=[3, 3, 4], dtype="float64")
        run_pnorm(self, p=2, axis=None, shape_x=[3, 4], dtype="float32")
        run_pnorm(self, p=2, axis=1, shape_x=[3, 4], dtype="float64")

    def test_name(self):
        with fluid.program_guard(fluid.Program()):
            x = fluid.data(name="x", shape=[10, 10], dtype="float32")
            y_1 = paddle.norm(x, p='fro', name='frobenius_name')
            y_2 = paddle.norm(x, p=2, name='pnorm_name')
            self.assertEqual(('frobenius_name' in y_1.name), True)
            self.assertEqual(('pnorm_name' in y_2.name), True)

    def test_errors(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):

            def err_dtype(p, shape_x, xdtype, out=None):
                data = fluid.data(shape=shape_x, dtype=xdtype)
                paddle.norm(data, p=p, out=out)

            self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "int64")
            out = fluid.data(name="out", shape=[1], dtype="int64")
            self.assertRaises(TypeError, err_dtype, "fro", [2, 2], "float64",
                              out)
            self.assertRaises(TypeError, err_dtype, 2, [10], "int64")
            self.assertRaises(TypeError, err_dtype, 2, [10], "float64", out)

            data = fluid.data(name="data_2d", shape=[2, 2], dtype="float64")
            self.assertRaises(ValueError, paddle.norm, data, p="unsupport norm")
            self.assertRaises(ValueError, paddle.norm, data, p=[1])
            self.assertRaises(ValueError, paddle.norm, data, p=[1], axis=-1)
            self.assertRaises(
                ValueError, paddle.norm, data, p='unspport', axis=[-2, -1])
            data = fluid.data(name="data_3d", shape=[2, 2, 2], dtype="float64")
            self.assertRaises(
                ValueError, paddle.norm, data, p='unspport', axis=[-2, -1])
            self.assertRaises(
                ValueError, paddle.norm, data, p='unspport', axis=[-3, -2, -1])


if __name__ == '__main__':
    unittest.main()