update linspace, equal operators to API 2.0 (#23274)
* update linspace, equal operators to API 2.0, test=develop
* equal support higher performance CUDA kernel, test=develop
* update comment of equal & linspace operator, test=develop

branch: revert-23830-2.0-beta
parent 03deb41d73
commit a2e10930cf
@@ -0,0 +1,150 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/controlflow/compare_reduce_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename Functor>
class CompareReduceOpKernel
    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    using T = typename Functor::ELEM_TYPE;
    using Tensor = framework::Tensor;

    auto* x = context.Input<Tensor>("X");
    auto* y = context.Input<Tensor>("Y");
    auto* z = context.Output<Tensor>("Out");
    int axis = context.Attr<int>("axis");

    Tensor tmp;
    framework::DDim x_dims = x->dims();
    framework::DDim y_dims = y->dims();
    int max_dim = std::max(x_dims.size(), y_dims.size());
    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
    std::vector<int> x_dims_array(max_dim);
    std::vector<int> y_dims_array(max_dim);
    std::vector<int> tmp_dims_array(max_dim);
    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
                           y_dims_array.data(), tmp_dims_array.data(), max_dim,
                           axis);
    tmp.mutable_data<bool>(framework::make_ddim(tmp_dims_array),
                           context.GetPlace());

    if (x->numel() == 1 && y->numel() == 1) {
      bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
      z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
    } else {
      ElementwiseComputeEx<Functor, platform::CPUDeviceContext, T, bool>(
          context, x, y, axis, Functor(), &tmp);
    }

    // Reduce with the 'logical and' operator.
    z->mutable_data<bool>(context.GetPlace());
    auto ipt = framework::EigenVector<bool>::Flatten(tmp);
    auto out = framework::EigenScalar<bool>::From(*z);
    auto& place =
        *context.template device_context<platform::CPUDeviceContext>()
             .eigen_device();
    auto reduce_dim = Eigen::array<int, 1>({{0}});
    out.device(place) = ipt.all(reduce_dim);
  }
};

template <typename OpComment>
class CompareReduceOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    OpComment comment;
    AddInput("X", string::Sprintf("the left hand operand of %s operator",
                                  comment.type));
    AddInput("Y", string::Sprintf("the right hand operand of %s operator",
                                  comment.type));
    AddAttr<int>(
        "axis",
        "The start dimension index for broadcasting Y onto X. [default -1]")
        .SetDefault(-1)
        .EqualGreaterThan(-1);
    AddOutput("Out", string::Sprintf(
                         "tensor with a single bool element. If all elements "
                         "satisfy %s, Out is [True], else [False]",
                         comment.equation));
    AddComment(string::Sprintf(R"DOC(
This operator compares X and Y element-wise. X and Y are N-dim tensors of
any supported type. If every element pair satisfies $%s$, the Out tensor
is [True]; otherwise it is [False].
)DOC",
                               comment.equation));
  }
};

template <typename OpComment>
class CompareReduceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* context) const override {
    OpComment comment;
    PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "%s operator must have input X", comment.type));
    PADDLE_ENFORCE_EQ(context->HasInput("Y"), true,
                      platform::errors::InvalidArgument(
                          "%s operator must have input Y", comment.type));
    auto dim_x = context->GetInputDim("X");
    auto dim_y = context->GetInputDim("Y");
    PADDLE_ENFORCE_GE(
        dim_x.size(), dim_y.size(),
        platform::errors::InvalidArgument(
            "The rank of input Y should not be greater than the rank of "
            "input X."));

    context->SetOutputDim("Out", {1});
    context->ShareLoD("X", "Out");
  }
};

}  // namespace operators
}  // namespace paddle

#define REGISTER_COMPARE_REDUCE_OP(op_type, _equation)                     \
  struct _##op_type##Comment {                                             \
    static char type[];                                                    \
    static char equation[];                                                \
  };                                                                       \
  char _##op_type##Comment::type[]{#op_type};                              \
  char _##op_type##Comment::equation[]{_equation};                         \
  REGISTER_OPERATOR(                                                       \
      op_type, ::paddle::operators::CompareReduceOp<_##op_type##Comment>,  \
      ::paddle::operators::CompareReduceOpProtoMaker<_##op_type##Comment>, \
      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,    \
      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

#define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor)            \
  REGISTER_OP_CPU_KERNEL(                                               \
      op_type, ::paddle::operators::CompareReduceOpKernel<              \
                   ::paddle::platform::CPUDeviceContext, functor<int>>, \
      ::paddle::operators::CompareReduceOpKernel<                       \
          ::paddle::platform::CPUDeviceContext, functor<int64_t>>,      \
      ::paddle::operators::CompareReduceOpKernel<                       \
          ::paddle::platform::CPUDeviceContext, functor<float>>,        \
      ::paddle::operators::CompareReduceOpKernel<                       \
          ::paddle::platform::CPUDeviceContext, functor<double>>);

REGISTER_COMPARE_REDUCE_OP(equal_reduce, "X == Y");

REGISTER_COMPARE_REDUCE_CPU_KERNEL(equal_reduce,
                                   paddle::operators::EqualReduceFunctor);
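For reference, a minimal NumPy sketch of what the equal_reduce CPU kernel above computes: broadcast-compare X and Y elementwise, then fold everything with logical AND down to a single bool, matching the Eigen `ipt.all(...)` reduction. The helper name is illustrative, not part of Paddle's API.

    import numpy as np

    def equal_reduce_reference(x, y):
        # Elementwise compare with NumPy broadcasting; the kernel instead
        # uses ElementwiseComputeEx with an explicit `axis` attribute.
        tmp = np.equal(x, y)
        # Reduce every element with logical AND, like `ipt.all(...)`.
        return np.array(tmp.all())

    print(equal_reduce_reference(np.ones((2, 3)), np.ones(3)))   # True
    print(equal_reduce_reference(np.ones((2, 3)), np.zeros(3)))  # False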
@@ -0,0 +1,88 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <numeric>
#include "paddle/fluid/operators/controlflow/compare_reduce_op.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"

namespace paddle {
namespace operators {

template <typename T>
struct IdentityFunctor {
  HOSTDEVICE explicit inline IdentityFunctor() {}

  HOSTDEVICE inline T operator()(const T& x) const { return x; }
};

struct BitwiseAdd {
  // Bitwise AND operator; returns a & b. For the bool operands used here
  // this coincides with logical AND.
  template <typename T>
  __host__ __device__ __forceinline__ T operator()(const T& a,
                                                   const T& b) const {
    return a & b;
  }
};

template <typename DeviceContext, typename Functor>
class CompareReduceOpKernel
    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    using T = typename Functor::ELEM_TYPE;
    using Tensor = framework::Tensor;

    auto* x = context.Input<Tensor>("X");
    auto* y = context.Input<Tensor>("Y");
    auto* z = context.Output<Tensor>("Out");
    int axis = context.Attr<int>("axis");

    Tensor tmp;
    framework::DDim x_dims = x->dims();
    framework::DDim y_dims = y->dims();
    int max_dim = std::max(x_dims.size(), y_dims.size());
    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
    std::vector<int> x_dims_array(max_dim);
    std::vector<int> y_dims_array(max_dim);
    std::vector<int> tmp_dims_array(max_dim);
    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
                           y_dims_array.data(), tmp_dims_array.data(), max_dim,
                           axis);
    tmp.mutable_data<bool>(framework::make_ddim(tmp_dims_array),
                           context.GetPlace());
    ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
                                                          Functor(), &tmp);
    // Reduce over every dimension with the 'bitwise and' operator.
    std::vector<int> reduce_dims(tmp.dims().size());
    std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
    auto stream = context.cuda_device_context().stream();
    TensorReduce<bool, bool, BitwiseAdd, IdentityFunctor<bool>>(
        tmp, z, reduce_dims, true, BitwiseAdd(), IdentityFunctor<bool>(),
        stream);
  }
};

}  // namespace operators
}  // namespace paddle

#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor)          \
  REGISTER_OP_CUDA_KERNEL(                                             \
      op_type, paddle::operators::CompareReduceOpKernel<               \
                   paddle::platform::CUDADeviceContext, functor<int>>, \
      paddle::operators::CompareReduceOpKernel<                        \
          paddle::platform::CUDADeviceContext, functor<int64_t>>,      \
      paddle::operators::CompareReduceOpKernel<                        \
          paddle::platform::CUDADeviceContext, functor<float>>,        \
      paddle::operators::CompareReduceOpKernel<                        \
          paddle::platform::CUDADeviceContext, functor<double>>);

REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_reduce,
                                    paddle::operators::EqualReduceFunctor);
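A quick NumPy check (illustration only) of why the CUDA path may reduce with bitwise AND: on the bool outputs of the elementwise compare, `a & b` coincides with logical AND, so folding `tmp` with BitwiseAdd yields the same [True]/[False] scalar as the CPU path's `.all()` reduction.

    import numpy as np

    tmp = np.array([True, True, False])
    # Bitwise AND over bools equals logical AND, hence equals .all().
    assert np.bitwise_and.reduce(tmp) == tmp.all()          # both False
    assert np.bitwise_and.reduce(tmp[:2]) == tmp[:2].all()  # both True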
@@ -0,0 +1,43 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <math.h>
#include <algorithm>
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

template <typename T>
struct EqualReduceFunctor {
  using ELEM_TYPE = T;
  HOSTDEVICE bool operator()(const T& a, const T& b) const {
    if (std::is_floating_point<T>::value) {
      // When T is an integral type this branch is compiled out, so casting
      // a and b to double here is always safe.
      return fabs(static_cast<double>(a - b)) < 1e-8;
    } else {
      return (a == b);
    }
  }
};

}  // namespace operators
}  // namespace paddle
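The functor treats floating-point values as equal when |a - b| < 1e-8. A small NumPy illustration (for this writeup only) of how that differs from exact `==` comparison:

    import numpy as np

    a = np.float64(0.1) * 3   # 0.30000000000000004
    b = np.float64(0.3)
    print(a == b)             # False: exact comparison fails
    print(abs(a - b) < 1e-8)  # True: tolerance-based comparison succeeds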
@@ -0,0 +1,27 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from six.moves import reduce
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator
from paddle.fluid.framework import Variable, device_guard
from paddle.fluid.initializer import Constant
from paddle.fluid.core import VarDesc
from paddle.fluid import core
from paddle.fluid.data_feeder import check_type, check_dtype, convert_dtype
from paddle.fluid.layers import utils
from paddle.fluid.layers import fill_constant
import numpy
import warnings
@@ -0,0 +1,156 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import op_test
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


def create_test_broadcast_class(op_type, args, callback, idx):
    class Cls(op_test.OpTest):
        def setUp(self):
            x = np.random.random(size=args['x_size']).astype('int32')
            y = np.random.random(size=args['y_size']).astype('int32')
            z = callback(x, y)
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': z}
            self.op_type = op_type
            # Pass axis through self.attrs so OpTest forwards it to the op.
            self.attrs = {'axis': args['axis']}

        def test_output(self):
            self.check_output()

    # Include idx in the name so each broadcast config gets its own class
    # instead of overwriting the previous one in globals().
    cls_name = "{0}_{1}_{2}".format(op_type, 'broadcast', idx)
    Cls.__name__ = cls_name
    globals()[cls_name] = Cls


def create_test_not_equal_class(op_type, typename, callback):
    class Cls(op_test.OpTest):
        def setUp(self):
            # x and y are drawn independently, so for float types they are
            # (almost surely) unequal and the expected output is [False].
            x = np.random.random(size=(10, 7)).astype(typename)
            y = np.random.random(size=(10, 7)).astype(typename)
            z = callback(x, y)
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': z}
            self.op_type = op_type

        def test_output(self):
            self.check_output()

    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_equal')
    Cls.__name__ = cls_name
    globals()[cls_name] = Cls


def create_test_equal_class(op_type, typename, callback):
    class Cls(op_test.OpTest):
        def setUp(self):
            x = y = np.random.random(size=(10, 7)).astype(typename)
            z = callback(x, y)
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': z}
            self.op_type = op_type

        def test_output(self):
            self.check_output()

    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal')
    Cls.__name__ = cls_name
    globals()[cls_name] = Cls


def create_test_dim1_class(op_type, typename, callback):
    class Cls(op_test.OpTest):
        def setUp(self):
            x = y = np.random.random(size=(1, )).astype(typename)
            z = callback(x, y)
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': z}
            self.op_type = op_type

        def test_output(self):
            self.check_output()

    # Use 'dim1' in the name so this class does not overwrite the one
    # produced by create_test_equal_class.
    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'dim1')
    Cls.__name__ = cls_name
    globals()[cls_name] = Cls


np_equal = lambda _x, _y: np.array(np.array_equal(_x, _y))

for _type_name in {'float32', 'float64', 'int32', 'int64'}:
    create_test_not_equal_class('equal_reduce', _type_name, np_equal)
    create_test_equal_class('equal_reduce', _type_name, np_equal)
    create_test_dim1_class('equal_reduce', _type_name, np_equal)

broadcast_args = [
    {'x_size': (100, 2, 3), 'y_size': (100, ), 'axis': 0},
    {'x_size': (2, 100, 3), 'y_size': (100, ), 'axis': 1},
    {'x_size': (2, 3, 100), 'y_size': (1, 1), 'axis': -1},
    {'x_size': (2, 10, 12, 3), 'y_size': (10, 12), 'axis': 1},
    {'x_size': (100, 2, 3, 4), 'y_size': (100, 1), 'axis': 0},
    {'x_size': (10, 3, 12), 'y_size': (10, 1, 12), 'axis': -1},
    {'x_size': (2, 12, 3, 5), 'y_size': (2, 12, 1, 5), 'axis': -1},
    {'x_size': (2, 12, 3, 5), 'y_size': (3, 5), 'axis': 2},
]


def np_broadcast_equal(_x, _y):
    res = np.all(np.equal(_x, _y))
    return np.array(res)


for idx, args in enumerate(broadcast_args):
    create_test_broadcast_class('equal_reduce', args, np_broadcast_equal, idx)


class TestEqualReduceAPI(unittest.TestCase):
    def test_name(self):
        x = fluid.layers.assign(np.array([3, 4], dtype="int32"))
        y = fluid.layers.assign(np.array([3, 4], dtype="int32"))
        out = paddle.equal(x, y, name='equal_res')
        assert 'equal_res' in out.name


if __name__ == '__main__':
    unittest.main()
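A hedged end-to-end usage sketch (assumes a Paddle build containing this PR, where paddle.equal is backed by the equal_reduce operator registered above):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    x = fluid.layers.assign(np.array([3, 4], dtype='int32'))
    y = fluid.layers.assign(np.array([3, 4], dtype='int32'))
    out = paddle.equal(x, y)  # one bool: [True] iff every element pair matches

    exe = fluid.Executor(fluid.CPUPlace())
    res, = exe.run(fluid.default_main_program(), fetch_list=[out])
    print(res)  # expected: [True]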