From 6cbeafb6c06da95f45c73d684423a127de95908a Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Wed, 19 Aug 2020 18:26:43 +0800 Subject: [PATCH] add zero norm, inf norm support for p_norm op (#26364) * add zero norm, inf norm support for p_norm op * fix the invalid argument check, fix the dtype problem in test case. --- paddle/fluid/operators/p_norm_op.cc | 79 +++++++----- paddle/fluid/operators/p_norm_op.cu | 118 ++++++++++++++++-- paddle/fluid/operators/p_norm_op.h | 34 +++-- .../fluid/tests/unittests/test_norm_all.py | 79 +++++++++++- 4 files changed, 257 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/operators/p_norm_op.cc b/paddle/fluid/operators/p_norm_op.cc index 057a7a38e3..aa39821051 100644 --- a/paddle/fluid/operators/p_norm_op.cc +++ b/paddle/fluid/operators/p_norm_op.cc @@ -25,34 +25,49 @@ class PnormOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "(Tensor) A tensor of rank >= axis."); AddAttr("porder", - "The porder is the p order vector norm to calculate.") + "(float, default 2) The porder is the p order vector norm " + "to calculate. Available for porder=0, inf, -inf and any " + "real number.") .SetDefault(2.0f); AddAttr("axis", - "The axis on which to apply normalization. If axis < 0, " + "The axis on which to apply norm operation. If axis < 0, " "the dimension to pnorm is rank(X) + axis. -1 is " "the last dimension.") .SetDefault(-1); AddAttr("epsilon", - "(float, default 1e-10) The epsilon value is used " + "(float, default 1e-12) The epsilon value is used " "to avoid division by zero.") .SetDefault(1.0e-12f); AddAttr( "keepdim", - "(bool, default false) Whether to keep the dimensions as the input") + "(bool, default false) Whether to keep the dimensions as the input.") .SetDefault(false); - AddOutput( - "Out", - "(Tensor) Output tensor for the `(sum(x.pow(p)) + epsion).pow(1/p)`"); + AddOutput("Out", "(Tensor) Output result tensor of p-norm"); AddComment(R"DOC( +Pnorm Operator. +Given a tensor X, compute Lp-norm of X. -Given a tensor, apply 2-normalization along the provided axis. +When p = 0, defining $0^0 = 0$, the zero-norm of X is simply the number of non-zero elements of X. +$$ +||X||_{0} = \lim_{p \rightarrow 0} \sum_i |x_i|^p +$$ + +When p = inf, the inf-norm of X is the maximum element of X. +$$ +||X||_\infty = \max_i |x_i| +$$ + +When p = -inf, the negative-inf-norm of X is the minimum element of X. +$$ +||X||_{-\infty} = \min_i |x_i| +$$ +Otherwise, the p-norm of X follows the formula, $$ -pnorm = \(\sum_i {abs\(x_i\)^p} \)^{1/p} +||X||_{p} = (\sum_i |x_i|^p)^{1/p} $$ +where, $\sum_i $ is calculated along the `axis` dimension. -where, $\sum_i{x_i^p}$ is calculated along the `axis` dimension. - )DOC"); } }; @@ -63,31 +78,33 @@ class PnormOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "p_norm"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "p_norm"); - auto porder = ctx->Attrs().Get("porder"); - PADDLE_ENFORCE_NE(porder, INFINITY, - platform::errors::Unimplemented( - "The input porder of p_norm is not support for " - "porder == 0, INFINITY, -INFINITY now.")); - PADDLE_ENFORCE_NE(porder, -INFINITY, - platform::errors::Unimplemented( - "The input porder of p_norm is not support for " - "porder == 0, INFINITY, -INFINITY now.")); - PADDLE_ENFORCE_GT(porder, 0.0f, - platform::errors::InvalidArgument( - "The input porder of p_norm is not support for " - "porder <= 0, But received porder=%f.", - porder)); - auto xdim = ctx->GetInputDim("X"); + auto x_dim = ctx->GetInputDim("X"); + auto x_rank = x_dim.size(); int axis = ctx->Attrs().Get("axis"); bool keepdim = ctx->Attrs().Get("keepdim"); - if (axis < 0) axis = xdim.size() + axis; + + PADDLE_ENFORCE_GE(axis, -x_rank, + platform::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is " + "the rank of Input(X). But received axis: %d, R: %d. " + "Current Input(X)'s shape is=[%s].", + axis, x_rank, x_dim)); + PADDLE_ENFORCE_LT(axis, x_rank, + platform::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is " + "the rank of Input(X). But received axis: %d, R: %d. " + "Current Input(X)'s shape is=[%s].", + axis, x_rank, x_dim)); + + if (axis < 0) axis = x_dim.size() + axis; std::vector reduce_dims; - for (int i = 0; i < xdim.size(); ++i) { - if (i != axis) reduce_dims.emplace_back(xdim[i]); + for (int i = 0; i < x_dim.size(); ++i) { + if (i != axis) reduce_dims.emplace_back(x_dim[i]); } - xdim[axis] = 1; + x_dim[axis] = 1; + if (keepdim) { - ctx->SetOutputDim("Out", xdim); + ctx->SetOutputDim("Out", x_dim); } else { ctx->SetOutputDim("Out", framework::make_ddim(reduce_dims)); } diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index d9ac98ff88..63f2a1c56c 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -49,20 +49,70 @@ __global__ void Pnorm(const T* x, const int pre, for (int i = blockIdx.x; i < num; i += gridDim.x) { int base = (i / post) * post * axis_n + (i % post); - T sum = 0.0; - __shared__ T norm; for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { const T x_ij = x[base + j * post]; sum += inline_pow(inline_abs(x_ij), porder_t); } T reduce_result = BlockReduce(temp_storage).Sum(sum); + if (threadIdx.x == 0) out_norm[i] = inline_pow(reduce_result, porder_inv); + } +} - if (threadIdx.x == 0) { - norm = inline_pow(reduce_result, porder_inv); - out_norm[i] = norm; +template +__global__ void ZeorNorm(const T* x, const int pre, + const int axis_n, // dim in axis + const int post, T* out_norm) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + int num = pre * post; + for (int i = blockIdx.x; i < num; i += gridDim.x) { + int base = (i / post) * post * axis_n + (i % post); + T sum = 0.0; + for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { + const T x_ij = x[base + j * post]; + sum += static_cast(x_ij != 0); } - __syncthreads(); + T reduce_result = BlockReduce(temp_storage).Sum(sum); + if (threadIdx.x == 0) out_norm[i] = reduce_result; + } +} + +template +__global__ void InfNorm(const T* x, const int pre, + const int axis_n, // dim in axis + const int post, T* out_norm) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + int num = pre * post; + for (int i = blockIdx.x; i < num; i += gridDim.x) { + int base = (i / post) * post * axis_n + (i % post); + T cur_max = inline_abs(x[base]); + for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { + T x_ij_abs = inline_abs(x[base + j * post]); + if (cur_max < x_ij_abs) cur_max = x_ij_abs; + } + T reduce_result = BlockReduce(temp_storage).Reduce(cur_max, cub::Max()); + if (threadIdx.x == 0) out_norm[i] = reduce_result; + } +} + +template +__global__ void NegInfNorm(const T* x, const int pre, + const int axis_n, // dim in axis + const int post, T* out_norm) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + int num = pre * post; + for (int i = blockIdx.x; i < num; i += gridDim.x) { + int base = (i / post) * post * axis_n + (i % post); + T cur_min = inline_abs(x[base]); + for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { + T x_ij_abs = inline_abs(x[base + j * post]); + if (cur_min > x_ij_abs) cur_min = x_ij_abs; + } + T reduce_result = BlockReduce(temp_storage).Reduce(cur_min, cub::Min()); + if (threadIdx.x == 0) out_norm[i] = reduce_result; } } @@ -89,8 +139,19 @@ class PnormCUDAKernel : public framework::OpKernel { int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); int grid = std::min(max_blocks, pre * post); - Pnorm<<>>(x, pre, n, post, - porder, norm); + if (porder == 0) { + ZeorNorm<<>>(x, pre, n, post, + norm); + } else if (porder == INFINITY) { + InfNorm<<>>(x, pre, n, post, + norm); + } else if (porder == -INFINITY) { + NegInfNorm<<>>(x, pre, n, + post, norm); + } else { + Pnorm<<>>(x, pre, n, post, + porder, norm); + } } }; @@ -112,7 +173,6 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad, pnorm_i = x_norm[i]; yout_i = y_grad[i]; } - __syncthreads(); for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { @@ -125,6 +185,33 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad, } } +template +__global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad, + const int pre, const int axis_n, const int post, + T* x_grad) { + int num = pre * post; + for (int i = blockIdx.x; i < num; i += gridDim.x) { + __shared__ T pnorm_i; + __shared__ T yout_i; + auto base = (i / post) * post * axis_n + (i % post); + if (threadIdx.x == 0) { + pnorm_i = x_norm[i]; + yout_i = y_grad[i]; + } + __syncthreads(); + + for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { + int index = base + j * post; + const T x_ij = inline_abs(x[index]); + if (x_ij == pnorm_i) { + x_grad[index] = inline_sign(x[index]) * yout_i; + } else { + x_grad[index] = static_cast(0); + } + } + } +} + template class PnormGradCUDAKernel : public framework::OpKernel { public: @@ -153,8 +240,17 @@ class PnormGradCUDAKernel : public framework::OpKernel { int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); int grid = std::min(max_blocks, pre * post); - PnormGradient<<>>( - x, x_norm, norm_dy, porder, pre, n, post, eps, dx); + if (porder == 0) { + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + set_zero(dev_ctx, out_dx, static_cast(0)); + } else if (porder == INFINITY || porder == -INFINITY) { + InfNormGradient<<>>( + x, x_norm, norm_dy, pre, n, post, dx); + } else { + PnormGradient<<>>( + x, x_norm, norm_dy, porder, pre, n, post, eps, dx); + } } }; diff --git a/paddle/fluid/operators/p_norm_op.h b/paddle/fluid/operators/p_norm_op.h index c5bdfe3527..7620d1421e 100644 --- a/paddle/fluid/operators/p_norm_op.h +++ b/paddle/fluid/operators/p_norm_op.h @@ -58,10 +58,20 @@ class PnormKernel : public framework::OpKernel { auto x = x_e.reshape(shape); auto norm = norm_e.reshape(norm_shape); + // p=0 means number of non-zero elements of (x) + // p=inf means the maximum of |x| + // p=-inf means the minimum of |x| + // otherwise, Lp-norm = pow(sum(pow(|x|, p)), 1/p) Eigen::DSizes rdim(1); - auto xp = (x.abs()).pow(porder); - auto sum = xp.sum(rdim); - norm.device(*place) = sum.pow(1.0f / porder); + if (porder == 0) { + norm.device(*place) = (x != x.constant(0)).template cast().sum(rdim); + } else if (porder == INFINITY) { + norm.device(*place) = x.abs().maximum(rdim); + } else if (porder == -INFINITY) { + norm.device(*place) = x.abs().minimum(rdim); + } else { + norm.device(*place) = x.abs().pow(porder).sum(rdim).pow(1.0f / porder); + } } }; @@ -102,10 +112,20 @@ class PnormGradKernel : public framework::OpKernel { Eigen::DSizes rdim(1); Eigen::DSizes bcast(1, n, 1); - dx.device(*place) = (x.abs()).pow(porder - 1.0f); - dx.device(*place) = - dx / ((norm.broadcast(bcast)).pow(porder - 1.0f) + x.constant(eps)); - dx.device(*place) = dx * norm_dy.broadcast(bcast) * x.sign(); + if (porder == 0) { + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + set_zero(dev_ctx, out_dx, static_cast(0)); + } else if (porder == INFINITY || porder == -INFINITY) { + dx.device(*place) = + (x.abs() == norm.broadcast(bcast)).template cast() * x.sign() * + norm_dy.broadcast(bcast); + } else { + dx.device(*place) = + (x.abs()).pow(porder - 1.0f) / + ((norm.broadcast(bcast)).pow(porder - 1.0f) + x.constant(eps)); + dx.device(*place) = dx * norm_dy.broadcast(bcast) * x.sign(); + } } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/test_norm_all.py b/python/paddle/fluid/tests/unittests/test_norm_all.py index e6b7a3e760..0d083038c6 100644 --- a/python/paddle/fluid/tests/unittests/test_norm_all.py +++ b/python/paddle/fluid/tests/unittests/test_norm_all.py @@ -23,16 +23,16 @@ import paddle.fluid as fluid def p_norm(x, axis, porder, keepdims=False): if axis is None: axis = -1 - xp = np.power(np.abs(x), porder) - s = np.sum(xp, axis=axis, keepdims=keepdims) - r = np.power(s, 1.0 / porder) + r = np.linalg.norm( + x, ord=porder, axis=axis, keepdims=keepdims).astype(x.dtype) return r def frobenius_norm(x, axis=None, keepdims=False): if isinstance(axis, list): axis = tuple(axis) if axis is None: axis = (-2, -1) - r = np.linalg.norm(x, ord='fro', axis=axis, keepdims=keepdims) + r = np.linalg.norm( + x, ord='fro', axis=axis, keepdims=keepdims).astype(x.dtype) return r @@ -89,6 +89,7 @@ class TestPnormOp(OpTest): 'porder': float(self.porder) } self.outputs = {'Out': norm} + self.gradient = self.calc_gradient() def test_check_output(self): self.check_output() @@ -104,6 +105,34 @@ class TestPnormOp(OpTest): self.keepdim = False self.dtype = "float64" + def calc_gradient(self): + self.attrs = { + 'epsilon': self.epsilon, + 'axis': self.axis, + 'keepdim': self.keepdim, + 'porder': float(self.porder) + } + x = self.inputs["X"] + porder = self.attrs["porder"] + axis = self.attrs["axis"] + if porder == 0: + grad = np.zeros(x.shape).astype(x.dtype) + elif porder in [float("inf"), float("-inf")]: + norm = p_norm(x, axis=axis, porder=porder, keepdims=True) + x_abs = np.abs(x) + grad = np.sign(x) + grad[x_abs != norm] = 0.0 + else: + norm = p_norm(x, axis=axis, porder=porder, keepdims=True) + grad = np.power(norm, 1 - porder) * np.power( + np.abs(x), porder - 1) * np.sign(x) + + numel = 1 + for s in x.shape: + numel *= s + numel /= x.shape[axis] + return [grad.astype(x.dtype) * 1 / numel] + class TestPnormOp2(TestPnormOp): def init_test_case(self): @@ -118,6 +147,45 @@ class TestPnormOp2(TestPnormOp): self.check_grad(['X'], 'Out') +class TestPnormOp3(TestPnormOp): + def init_test_case(self): + self.shape = [3, 20, 3] + self.axis = 2 + self.epsilon = 1e-12 + self.porder = np.inf + self.keepdim = True + self.dtype = "float32" + + def test_check_grad(self): + self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) + + +class TestPnormOp4(TestPnormOp): + def init_test_case(self): + self.shape = [3, 20, 3] + self.axis = 2 + self.epsilon = 1e-12 + self.porder = -np.inf + self.keepdim = True + self.dtype = "float32" + + def test_check_grad(self): + self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) + + +class TestPnormOp5(TestPnormOp): + def init_test_case(self): + self.shape = [3, 20, 3] + self.axis = 2 + self.epsilon = 1e-12 + self.porder = 0 + self.keepdim = True + self.dtype = "float32" + + def test_check_grad(self): + self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) + + def run_out(self, p, axis, shape_x, shape_y, dtype): with fluid.program_guard(fluid.Program()): data1 = fluid.data(name="X", shape=shape_x, dtype=dtype) @@ -170,6 +238,9 @@ class API_NormTest(unittest.TestCase): run_fro(self, p='fro', axis=[0, 1], shape_x=[3, 3, 4], dtype="float64") run_pnorm(self, p=2, axis=None, shape_x=[3, 4], dtype="float32") run_pnorm(self, p=2, axis=1, shape_x=[3, 4], dtype="float64") + run_pnorm(self, p=np.inf, axis=1, shape_x=[3, 4], dtype="float32") + run_pnorm(self, p=-np.inf, axis=1, shape_x=[3, 4], dtype="float64") + run_pnorm(self, p=0, axis=1, shape_x=[3, 4], dtype="float64") def test_name(self): with fluid.program_guard(fluid.Program()):