Norm op support 2-axis (#26492)

yongqiangma 5 years ago committed by GitHub
parent dc56c89822
commit e4cc6a28b0
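What the change enables, as a short dygraph sketch mirroring the run_graph test further down (the norm calls are taken from the diff; the surrounding setup is assumed, not part of the commit):

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(np.arange(24).reshape(2, 3, 4).astype('float32') - 12)
# p-norm and Frobenius norm can now reduce over two axes at once
out_p2 = paddle.norm(x, p=2, axis=[0, 1])
out_fro = paddle.norm(x, p='fro', axis=[0, 1])
paddle.enable_static()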

@@ -42,6 +42,11 @@ class PnormOpMaker : public framework::OpProtoAndCheckerMaker {
"keepdim",
"(bool, default false) Whether to keep the dimensions as the input.")
.SetDefault(false);
AddAttr<bool>("asvector",
"(bool, default false) as vector norm when axis is None and "
"input is matrix, ")
.SetDefault(false);
AddOutput("Out", "(Tensor) Output result tensor of p-norm");
AddComment(R"DOC(
Pnorm Operator.
@@ -96,10 +101,15 @@ class PnormOp : public framework::OperatorWithKernel {
"Current Input(X)'s shape is=[%s].",
axis, x_rank, x_dim));
if (axis < 0) axis = x_dim.size() + axis;
std::vector<int> reduce_dims;
for (int i = 0; i < x_dim.size(); ++i) {
if (i != axis) reduce_dims.emplace_back(x_dim[i]);
bool asvector = ctx->Attrs().Get<bool>("asvector");
if (asvector) {
reduce_dims.emplace_back(1);
} else {
if (axis < 0) axis = x_dim.size() + axis;
for (int i = 0; i < x_dim.size(); ++i) {
if (i != axis) reduce_dims.emplace_back(x_dim[i]);
}
}
x_dim[axis] = 1;
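In plain terms: with asvector=true the op reduces over every element, so the inferred output collapses to a single value; otherwise only the given axis is dropped. A small Python sketch of that shape rule (the helper name is hypothetical and keepdim handling is omitted):

def pnorm_out_shape(x_shape, axis, asvector=False):
    # mirrors the InferShape branch above (keepdim not shown)
    if axis < 0:
        axis += len(x_shape)
    if asvector:
        return [1]                       # whole tensor reduced to one value
    return [d for i, d in enumerate(x_shape) if i != axis]

# pnorm_out_shape([2, 3, 4], axis=1)                 -> [2, 4]
# pnorm_out_shape([2, 3, 4], axis=1, asvector=True)  -> [1]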

@@ -129,9 +129,10 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
auto ndim = out_norm->dims();
float porder = ctx.Attr<float>("porder");
int axis = ctx.Attr<int>("axis");
bool asvector = ctx.Attr<bool>("asvector");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
GetDims(xdim, axis, &pre, &n, &post, asvector);
auto& dev_ctx = ctx.cuda_device_context();
@@ -230,9 +231,10 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
float porder = ctx.Attr<float>("porder");
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
int axis = ctx.Attr<int>("axis");
bool asvector = ctx.Attr<bool>("asvector");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
GetDims(xdim, axis, &pre, &n, &post, asvector);
auto& dev_ctx = ctx.cuda_device_context();

@@ -20,15 +20,19 @@ namespace paddle {
namespace operators {
inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
int* post) {
int* post, bool asvector) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
if (asvector) {
*n = product(dim);
} else {
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}
}
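A Python sketch of the same decomposition, to make the asvector branch concrete (get_dims here is a stand-in for the C++ helper above, not a Paddle API):

import numpy as np

def get_dims(shape, axis, asvector):
    if asvector:
        # treat the whole tensor as one vector of length prod(shape)
        return 1, int(np.prod(shape)), 1
    pre = int(np.prod(shape[:axis]))       # dims before the reduced axis
    post = int(np.prod(shape[axis + 1:]))  # dims after the reduced axis
    return pre, shape[axis], post

# get_dims((2, 3, 4), axis=2, asvector=False) -> (6, 4, 1)
# get_dims((2, 3, 4), axis=0, asvector=True)  -> (1, 24, 1)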
@@ -43,9 +47,10 @@ class PnormKernel : public framework::OpKernel<T> {
auto xdim = in_x->dims();
float porder = ctx.Attr<float>("porder");
int axis = ctx.Attr<int>("axis");
bool asvector = ctx.Attr<bool>("asvector");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
GetDims(xdim, axis, &pre, &n, &post, asvector);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
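The kernel then views the input as (pre, n, post) and reduces along the middle dimension; in numpy terms (an illustrative equivalent using numpy stand-ins, not the Eigen code itself):

x3 = x.reshape(pre, n, post)
# general p-norm along the middle (reduced) dimension -> shape (pre, post)
out = np.sum(np.abs(x3) ** porder, axis=1) ** (1.0 / porder)
# porder == inf case reduces with a max instead
out_inf = np.max(np.abs(x3), axis=1)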
@@ -91,9 +96,10 @@ class PnormGradKernel : public framework::OpKernel<T> {
float porder = ctx.Attr<float>("porder");
int axis = ctx.Attr<int>("axis");
bool asvector = ctx.Attr<bool>("asvector");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
GetDims(xdim, axis, &pre, &n, &post, asvector);
Eigen::DSizes<int, 3> shape(pre, n, post);
Eigen::DSizes<int, 3> rshape(pre, 1, post);
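For the gradient, the (pre, n, post) view and the broadcast shape (pre, 1, post) implement the usual p-norm derivative; as a hedged numpy sketch of the math, with numpy stand-ins for the kernel's tensors (not the Eigen implementation):

x3 = x.reshape(pre, n, post)
norm3 = out_norm.reshape(pre, 1, post)   # broadcast over the reduced axis
dy3 = dy.reshape(pre, 1, post)
dx = dy3 * np.sign(x3) * np.abs(x3) ** (porder - 1) / (norm3 + eps) ** (porder - 1)
dx = dx.reshape(x.shape)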

@@ -14,7 +14,6 @@
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/p_norm_op.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_v2_op.h"

@@ -33,6 +33,19 @@ limitations under the License. */
namespace paddle {
namespace operators {
inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
int* post) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}
template <typename T, typename Type>
static void FullTopK(Type input_height, Type input_width, int input_dim,
const framework::Tensor* input, T* t_out, Type* t_indices,

@@ -22,9 +22,40 @@ import paddle.fluid as fluid
def p_norm(x, axis, porder, keepdims=False):
if axis is None: axis = -1
r = np.linalg.norm(
x, ord=porder, axis=axis, keepdims=keepdims).astype(x.dtype)
r = []
if axis is None:
x = x.flatten()
if porder == np.inf:
r = np.amax(np.abs(x))
elif porder == -np.inf:
r = np.amin(np.abs(x))
else:
r = np.linalg.norm(x, ord=porder)
elif isinstance(axis, (list, tuple)) and len(axis) == 2:
if porder == np.inf:
axis = tuple(axis)
r = np.amax(np.abs(x), axis=axis, keepdims=keepdims)
elif porder == -np.inf:
axis = tuple(axis)
r = np.amin(np.abs(x), axis=axis, keepdims=keepdims)
elif porder == 0:
axis = tuple(axis)
r = x.astype(bool)
r = np.sum(r, axis)
elif porder == 1:
axis = tuple(axis)
r = np.sum(np.abs(x), axis)
else:
axis = tuple(axis)
xp = np.power(np.abs(x), porder)
s = np.sum(xp, axis=axis, keepdims=keepdims)
r = np.power(s, 1.0 / porder)
else:
if isinstance(axis, list):
axis = tuple(axis)
r = np.linalg.norm(
x, ord=porder, axis=axis, keepdims=keepdims).astype(x.dtype)
return r
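The two-axis branch above is equivalent to flattening the selected axes and taking a vector norm; a quick numpy check of that equivalence (illustrative only, not part of the test file):

import numpy as np

x = (np.arange(24).reshape(2, 3, 4) - 12).astype('float64')
p = 2.0
ref = np.sum(np.abs(x) ** p, axis=(0, 1)) ** (1.0 / p)           # two-axis branch
alt = np.linalg.norm(x.reshape(-1, x.shape[-1]), ord=p, axis=0)  # flatten axes 0 and 1
assert np.allclose(ref, alt)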
@@ -186,22 +217,10 @@ class TestPnormOp5(TestPnormOp):
self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)
def run_out(self, p, axis, shape_x, shape_y, dtype):
with fluid.program_guard(fluid.Program()):
data1 = fluid.data(name="X", shape=shape_x, dtype=dtype)
data2 = fluid.data(name="Y", shape=shape_y, dtype=dtype)
out = paddle.norm(input=data1, p=p, axis=axis, out=data2)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result = exe.run(feed={"X": np.random.rand(*shape_x).astype(dtype)},
fetch_list=[data2, out])
self.assertEqual((result[0] == result[1]).all(), True)
def run_fro(self, p, axis, shape_x, dtype):
with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=shape_x, dtype=dtype)
out = paddle.norm(input=data, p=p, axis=axis)
out = paddle.norm(x=data, p=p, axis=axis)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
@@ -213,35 +232,73 @@ def run_fro(self, p, axis, shape_x, dtype):
def run_pnorm(self, p, axis, shape_x, dtype):
with fluid.program_guard(fluid.Program()):
data = fluid.data(name="X", shape=shape_x, dtype=dtype)
out = paddle.norm(input=data, p=p, axis=axis)
out = paddle.norm(x=data, p=p, axis=axis)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
np_input = (np.random.rand(*shape_x) + 1.0).astype(dtype)
expected_result = p_norm(np_input, porder=p, axis=axis).astype(dtype)
result, = exe.run(feed={"X": np_input}, fetch_list=[out])
self.assertEqual((np.abs(result - expected_result) < 1e-6).all(), True)
def run_graph(self, p, axis, shape_x, dtype):
paddle.disable_static()
shape = [2, 3, 4]
np_input = np.arange(24).astype('float32') - 12
np_input = np_input.reshape(shape)
x = paddle.to_tensor(np_input)
#[[[-12. -11. -10. -9.] [ -8. -7. -6. -5.] [ -4. -3. -2. -1.]]
# [[ 0. 1. 2. 3.] [ 4. 5. 6. 7.] [ 8. 9. 10. 11.]]]
out_pnorm = paddle.norm(x, p=2, axis=-1)
# compute frobenius norm along last two dimensions.
out_fro = paddle.norm(x, p='fro')
out_fro = paddle.norm(x, p='fro', axis=[0, 1])
# compute 2-order norm along [0,1] dimension.
out_pnorm = paddle.norm(x, p=2, axis=[0, 1])
out_pnorm = paddle.norm(x, p=2)
#out_pnorm = [17.43559577 16.91153453 16.73320053 16.91153453]
# compute inf-order norm
out_pnorm = paddle.norm(x, p=np.inf)
#out_pnorm = [12.]
out_pnorm = paddle.norm(x, p=np.inf, axis=0)
#out_pnorm = [[12. 11. 10. 9.] [8. 7. 6. 7.] [8. 9. 10. 11.]]
# compute -inf-order norm
out_pnorm = paddle.norm(x, p=-np.inf)
#out_pnorm = [0.]
out_pnorm = paddle.norm(x, p=-np.inf, axis=0)
#out_pnorm = [[0. 1. 2. 3.] [4. 5. 6. 5.] [4. 3. 2. 1.]]
paddle.enable_static()
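A hedged way to sanity-check the last dygraph result above against numpy, if placed just before the paddle.enable_static() call (out_pnorm is then the p=-inf, axis=0 result; this check is not part of the test):

expected = np.amin(np.abs(np_input), axis=0)   # -inf norm along axis 0
assert np.allclose(out_pnorm.numpy(), expected)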
class API_NormTest(unittest.TestCase):
def test_output_result(self):
run_out(self, p=2, axis=1, shape_x=[3, 4], shape_y=[3], dtype="float32")
run_out(
self,
p='fro',
axis=None,
shape_x=[3, 4],
shape_y=[1],
dtype="float32")
def test_basic(self):
run_fro(self, p='fro', axis=None, shape_x=[3, 3, 4], dtype="float32")
run_fro(self, p='fro', axis=[0, 1], shape_x=[3, 3, 4], dtype="float64")
run_fro(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32")
run_fro(self, p='fro', axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=2, axis=None, shape_x=[3, 4], dtype="float32")
run_pnorm(self, p=2, axis=1, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=np.inf, axis=1, shape_x=[3, 4], dtype="float32")
run_pnorm(self, p=-np.inf, axis=1, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=np.inf, axis=0, shape_x=[2, 3, 4], dtype="float32")
run_pnorm(self, p=np.inf, axis=None, shape_x=[2, 3, 4], dtype="float32")
run_pnorm(self, p=-np.inf, axis=0, shape_x=[2, 3, 4], dtype="float64")
run_pnorm(
self, p=-np.inf, axis=None, shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=0, axis=1, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=1, axis=1, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=0, axis=None, shape_x=[3, 4], dtype="float64")
run_pnorm(self, p=2, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=2, axis=-1, shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=1, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(self, p=0, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(
self, p=np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
run_pnorm(
self, p=-np.inf, axis=[0, 1], shape_x=[2, 3, 4], dtype="float64")
def test_dygraph(self):
run_graph(self, p='fro', axis=None, shape_x=[2, 3, 4], dtype="float32")
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[10, 10], dtype="float32")
@@ -268,11 +325,7 @@ class API_NormTest(unittest.TestCase):
self.assertRaises(ValueError, paddle.norm, data, p="unsupport norm")
self.assertRaises(ValueError, paddle.norm, data, p=[1])
self.assertRaises(ValueError, paddle.norm, data, p=[1], axis=-1)
self.assertRaises(
ValueError, paddle.norm, data, p='unspport', axis=[-2, -1])
data = fluid.data(name="data_3d", shape=[2, 2, 2], dtype="float64")
self.assertRaises(
ValueError, paddle.norm, data, p='unspport', axis=[-2, -1])
self.assertRaises(
ValueError, paddle.norm, data, p='unspport', axis=[-3, -2, -1])

File diff suppressed because it is too large