Add elementwise XPU OP kernels for KUNLUN core (common broadcast is still not supported)
my_2.0rc
parent ae6ad23c3c
commit c791df09cf
@@ -0,0 +1,43 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
namespace paddle {
namespace operators {

template <typename T>
struct XPUDivFunctor {
  int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) {
    return xpu::elementwise_div(ctx, x, y, z, len);
  }
};

template <typename DeviceContext, typename T>
class ElementwiseDivXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    XPUElementwise<T, XPUDivFunctor<T>>(ctx);
  }
};

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    elementwise_div,
    ops::ElementwiseDivXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
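All four kernel files in this commit delegate to XPUElementwise<T, Functor>(ctx) from elementwise_xpu.h, which the diff does not show. Purely for orientation, below is a minimal sketch of what such a helper could look like; the exact signature, the x_context() accessor, and the 0-means-success convention are assumptions here, and the real helper must additionally handle the axis attribute and reject the broadcast cases this commit does not support.

```cpp
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

// Sketch only, not the real elementwise_xpu.h: fetch the operands, run the
// op-specific functor on the raw XPU context, and check the status code.
template <typename T, typename Functor>
void XPUElementwise(const framework::ExecutionContext& ctx) {
  auto* x = ctx.Input<framework::Tensor>("X");
  auto* y = ctx.Input<framework::Tensor>("Y");
  auto* z = ctx.Output<framework::Tensor>("Out");
  T* z_data = z->mutable_data<T>(ctx.GetPlace());

  auto& dev_ctx =
      ctx.template device_context<platform::XPUDeviceContext>();
  int len = static_cast<int>(x->numel());
  // Delegate the arithmetic to the functor; 0 is assumed to mean success.
  int ret = Functor()(dev_ctx.x_context(), x->data<T>(), y->data<T>(),
                      z_data, len);
  PADDLE_ENFORCE_EQ(
      ret, 0, platform::errors::External("XPU elementwise kernel failed"));
}

}  // namespace operators
}  // namespace paddle
```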
@@ -0,0 +1,45 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
namespace paddle {
namespace operators {

template <typename T>
struct XPUMaxFunctor {
  int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) {
    return xpu::elementwise_max(ctx, x, y, z, len);
  }
};

template <typename DeviceContext, typename T>
class ElementwiseMaxXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    XPUElementwise<T, XPUMaxFunctor<T>>(ctx);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    elementwise_max,
    ops::ElementwiseMaxXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
@@ -0,0 +1,40 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class ElementwiseMulXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    XPUElementwise<T, XPUMulFunctor<T>>(ctx);
  }
};
DEFINE_XPU_GRAD_KERNEL(Mul, mul, true);
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    elementwise_mul,
    ops::ElementwiseMulXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(elementwise_mul_grad,
                       ops::ElementwiseMulGradXPUKernel<
                           paddle::platform::XPUDeviceContext, float>);

#endif
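This file defines no local functor: XPUMulFunctor and DEFINE_XPU_GRAD_KERNEL both come from elementwise_xpu.h. The elementwise_mul_grad registration implies the macro generates a class named ElementwiseMulGradXPUKernel. A plausible expansion is sketched below with a hypothetical XPUElementwiseGrad helper; the real macro lives in elementwise_xpu.h and may well differ.

```cpp
/* Hypothetical expansion of DEFINE_XPU_GRAD_KERNEL(Mul, mul, true); the
   third argument is read here as "the backward pass needs the forward
   inputs X and Y", which is an assumption, not confirmed by this diff. */
#define DEFINE_XPU_GRAD_KERNEL(kernel_type, kernel_name, use_x_y_data)    \
  template <typename DeviceContext, typename T>                           \
  class Elementwise##kernel_type##GradXPUKernel                           \
      : public framework::OpKernel<T> {                                   \
   public:                                                                \
    void Compute(const framework::ExecutionContext& ctx) const override { \
      /* assumed shared backward helper from elementwise_xpu.h */         \
      XPUElementwiseGrad<T>(ctx, #kernel_name, use_x_y_data);             \
    }                                                                     \
  };
```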
@@ -0,0 +1,49 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
namespace paddle {
namespace operators {

template <typename T>
struct XPUSubFunctor {
  int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) {
    return xpu::elementwise_sub(ctx, x, y, z, len);
  }
};

template <typename DeviceContext, typename T>
class ElementwiseSubXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    XPUElementwise<T, XPUSubFunctor<T>>(ctx);
  }
};

DEFINE_XPU_GRAD_KERNEL(Sub, sub, false);
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    elementwise_sub,
    ops::ElementwiseSubXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(elementwise_sub_grad,
                       ops::ElementwiseSubGradXPUKernel<
                           paddle::platform::XPUDeviceContext, float>);

#endif
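Taken together, the four kernel files follow one template: wrap a single xpu:: primitive in a functor, feed the functor to XPUElementwise, and register the kernel for float. A further op would be a near-copy. The sketch below uses a hypothetical xpu::elementwise_min to show the shape of such an addition; it is not part of this commit, and the primitive's existence in the KUNLUN runtime is assumed.

```cpp
// Sketch of a hypothetical fifth op following the same pattern as above.
template <typename T>
struct XPUMinFunctor {
  int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) {
    return xpu::elementwise_min(ctx, x, y, z, len);  /* assumed primitive */
  }
};

template <typename DeviceContext, typename T>
class ElementwiseMinXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    XPUElementwise<T, XPUMinFunctor<T>>(ctx);
  }
};

// Registration would mirror the ops above:
// REGISTER_OP_XPU_KERNEL(
//     elementwise_min,
//     ops::ElementwiseMinXPUKernel<paddle::platform::XPUDeviceContext,
//                                  float>);
```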
@@ -0,0 +1,100 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
paddle.enable_static()


class TestXPUElementwiseOpBase(object):
    def setUp(self, op_type):
        self.op_type = op_type
        self.attrs = {'use_xpu': True}
        self.is_common_broadcast = False
        self.is_x_size_less_than_y = False
        self.grad_implemented = False
        self.y_grad_implemented = True
        self.dtype = np.float32
        self.__class__.op_type = self.op_type
        self.__class__.use_xpu = True
        self.__class__.dtype = self.dtype

    def net(self, place):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            x = fluid.layers.data(
                name='X', shape=self.inputs['X'].shape, dtype=self.dtype)
            y = fluid.layers.data(
                name='Y', shape=self.inputs['Y'].shape, dtype=self.dtype)
            op = getattr(fluid.layers, self.op_type)
            z = op(x, y)
            exe = fluid.Executor(place)
            z_value = exe.run(feed=self.inputs, fetch_list=[z.name])

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            if not self.is_common_broadcast and not self.is_x_size_less_than_y:
                self.check_output_with_place(place, atol=1e-3)
            else:
                with self.assertRaises(BaseException):
                    self.net(place)

    def _check_grad_xpu_helper(self,
                               inputs_to_check,
                               output_names,
                               no_grad_set=None,
                               max_relative_error=0.05):
        if self.grad_implemented and not self.is_common_broadcast \
                and not self.is_x_size_less_than_y:
            if paddle.is_compiled_with_xpu():
                place = paddle.XPUPlace(0)
                self.check_grad_with_place(
                    place,
                    inputs_to_check,
                    output_names,
                    no_grad_set=no_grad_set,
                    max_relative_error=max_relative_error)

    def test_check_grad_normal(self):
        self._check_grad_xpu_helper(['X', 'Y'], 'Out')

    def test_check_grad_ingore_x(self):
        self._check_grad_xpu_helper(['Y'], 'Out', set("X"))

    def test_check_grad_ingore_y(self):
        if self.y_grad_implemented:
            self._check_grad_xpu_helper(['X'], 'Out', set("Y"))

    def init_axis(self):
        self.axis = -1

    def make_input(self, x_shape=[13, 17], y_shape=[13, 17]):
        self.inputs = {
            'X': np.random.uniform(0.1, 1, x_shape).astype(self.dtype),
            'Y': np.random.uniform(0.1, 1, y_shape).astype(self.dtype)
        }

    def reshape_input(self, x_shape=None, y_shape=None):
        if x_shape is None:
            x = self.inputs['X']
        else:
            x = self.inputs['X'].reshape(x_shape)
        if y_shape is None:
            y = self.inputs['Y']
        else:
            y = self.inputs['Y'].reshape(y_shape)
        return x, y

    def make_output(self, x_shape=None, y_shape=None):
        pass
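A note on why the concrete tests below pass reshaped y_shape/x_shape values into make_output: Paddle's elementwise ops align Y against X starting at attrs['axis'], which is equivalent to padding Y with singleton dimensions and then applying NumPy-style broadcasting. A small NumPy sketch of the axis=0 case used by the broadcast_0 tests:

```python
import numpy as np

x = np.random.uniform(0.1, 1, [100, 3, 4]).astype(np.float32)
y = np.random.uniform(0.1, 1, [100]).astype(np.float32)

# axis=0: Y of shape [100] lines up with the first dimension of X, i.e. it
# behaves like a Y of shape [100, 1, 1] -- exactly what the tests pass as
# y_shape to make_output.
expected = x / y.reshape([100, 1, 1])
assert expected.shape == x.shape
```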
@@ -0,0 +1,138 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest, skip_check_grad_ci
from elementwise import TestXPUElementwiseOpBase
paddle.enable_static()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUElementwiseDivOp(OpTest, TestXPUElementwiseOpBase):
    def setUp(self):
        TestXPUElementwiseOpBase.setUp(self, "elementwise_div")
        self.make_input()
        self.make_output()

    def make_output(self, x_shape=None, y_shape=None):
        x, y = self.reshape_input(x_shape, y_shape)
        self.outputs = {'Out': np.divide(x, y)}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_scalar(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_scalar, self).setUp()
        self.grad_implemented = False
        self.make_input([20, 3, 4], [1])
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_Vector(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_Vector, self).setUp()
        self.make_input([100, ], [100, ])
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_broadcast_0(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_broadcast_0, self).setUp()
        self.attrs['axis'] = 0
        self.make_input([100, 3, 4], [100, ])
        self.make_output(y_shape=[100, 1, 1])


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_broadcast_1(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_broadcast_1, self).setUp()
        self.attrs['axis'] = 1
        self.make_input([2, 100, 4], [100, ])
        self.make_output(y_shape=[1, 100, 1])


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_broadcast_2(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_broadcast_2, self).setUp()
        self.make_input([2, 3, 100], [100, ])
        self.make_output(y_shape=[1, 1, 100])


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_broadcast_3(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_broadcast_3, self).setUp()
        self.attrs['axis'] = 1
        self.make_input([2, 10, 12, 5], [10, 12])
        self.make_output(y_shape=[1, 10, 12, 1])


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_broadcast_4(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_broadcast_4, self).setUp()
        self.is_common_broadcast = True
        self.make_input([2, 3, 50], [2, 1, 50])
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_broadcast_5(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_broadcast_5, self).setUp()
        self.is_common_broadcast = True
        self.make_input([2, 3, 4, 20], [2, 3, 1, 20])
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_commonuse_1(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_commonuse_1, self).setUp()
        self.is_common_broadcast = True
        self.make_input([2, 3, 100], [1, 1, 100])
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseDivOp_xsize_lessthan_ysize(TestXPUElementwiseDivOp):
    def setUp(self):
        super(TestElementwiseDivOp_xsize_lessthan_ysize, self).setUp()
        self.is_x_size_less_than_y = True
        self.attrs['axis'] = 2
        self.make_input([10, 12], [2, 3, 10, 12])
        self.make_output(x_shape=[1, 1, 10, 12])


if __name__ == '__main__':
    unittest.main()
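The _broadcast_4, _broadcast_5, _commonuse_1 and _xsize_lessthan_ysize cases above set is_common_broadcast or is_x_size_less_than_y, so test_check_output in the base class does not compare values: it builds a small program and expects execution on the XPU to raise, since the kernel cannot process common broadcast yet. A condensed, standalone illustration of that expectation (shapes mirror _broadcast_4; a sketch, not part of the suite):

```python
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

if paddle.is_compiled_with_xpu():
    place = paddle.XPUPlace(0)
    with fluid.program_guard(fluid.Program(), fluid.Program()):
        x = fluid.layers.data(
            name='X', shape=[2, 3, 50], dtype='float32',
            append_batch_size=False)
        y = fluid.layers.data(
            name='Y', shape=[2, 1, 50], dtype='float32',
            append_batch_size=False)
        z = fluid.layers.elementwise_div(x, y)
        exe = fluid.Executor(place)
        try:
            exe.run(feed={'X': np.ones([2, 3, 50], 'float32'),
                          'Y': np.ones([2, 1, 50], 'float32')},
                    fetch_list=[z])
        except BaseException:
            pass  # expected: common broadcast is unsupported on XPU here
```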
@@ -0,0 +1,129 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle
from elementwise import TestXPUElementwiseOpBase
paddle.enable_static()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUElementwiseOp(OpTest, TestXPUElementwiseOpBase):
    def setUp(self):
        TestXPUElementwiseOpBase.setUp(self, "elementwise_max")
        self.make_input()
        self.make_output()

    def make_input(self, x_shape=[13, 17], y_shape=[13, 17], idx_list=None):
        x = np.random.random(x_shape).astype(self.dtype)
        sgn = np.random.choice([-1, 1], y_shape).astype(self.dtype)
        if idx_list is None:
            y = x + sgn * np.random.uniform(0.1, 1, y_shape).astype(self.dtype)
        else:
            x_temp = x
            for idx in idx_list:
                x_temp = np.take(x_temp, [0], axis=idx)
            sgn = sgn.reshape(x_temp.shape)
            y = x_temp + sgn * np.random.uniform(0.1, 1, x_temp.shape)
            y = y.reshape(y_shape).astype(self.dtype)

        self.inputs = {'X': x, 'Y': y}

    def make_output(self, x_shape=None, y_shape=None):
        x, y = self.reshape_input(x_shape, y_shape)
        self.outputs = {'Out': np.maximum(x, y)}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMaxOp_scalar(TestXPUElementwiseOp):
    def setUp(self):
        super(TestElementwiseMaxOp_scalar, self).setUp()
        self.make_input([2, 3, 20], [1])
        self.make_output()
        self.grad_implemented = False


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMaxOp_Vector(TestXPUElementwiseOp):
    def setUp(self):
        super(TestElementwiseMaxOp_Vector, self).setUp()
        self.make_input([100, ], [100, ])
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMaxOp_broadcast_0(TestXPUElementwiseOp):
    def setUp(self):
        super(TestElementwiseMaxOp_broadcast_0, self).setUp()
        self.attrs['axis'] = 0
        self.make_input([100, 5, 2], [100, ], [1, 2])
        self.make_output(y_shape=[100, 1, 1])


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMaxOp_broadcast_1(TestXPUElementwiseOp):
    def setUp(self):
        super(TestElementwiseMaxOp_broadcast_1, self).setUp()
        self.attrs['axis'] = 1
        self.make_input([2, 100, 3], [100, ], [0, 2])
        self.make_output(y_shape=[1, 100, 1])


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMaxOp_broadcast_2(TestXPUElementwiseOp):
    def setUp(self):
        super(TestElementwiseMaxOp_broadcast_2, self).setUp()
        self.make_input([1, 3, 100], [100, ], [0, 1])
        self.make_output(y_shape=[1, 1, 100])


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMaxOp_broadcast_3(TestXPUElementwiseOp):
    def setUp(self):
        super(TestElementwiseMaxOp_broadcast_3, self).setUp()
        self.attrs['axis'] = 1
        self.make_input([2, 50, 2, 1], [50, 2], [0, 3])
        self.make_output(y_shape=[1, 50, 2, 1])


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMaxOp_broadcast_4(TestXPUElementwiseOp):
    def setUp(self):
        super(TestElementwiseMaxOp_broadcast_4, self).setUp()
        self.make_input([2, 3, 4, 5], [2, 3, 1, 5])
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMaxOp_broadcast_5(TestXPUElementwiseOp):
    def setUp(self):
        super(TestElementwiseMaxOp_broadcast_5, self).setUp()
        self.make_input([2, 3, 100], [1, 1, 100])
        self.make_output()


if __name__ == '__main__':
    unittest.main()
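Unlike the other tests, make_input here builds Y from X plus a signed offset drawn from U(0.1, 1), presumably so that X and Y never tie elementwise; a tie would make the gradient of max ambiguous for the finite-difference checks. The idx_list path does the same for broadcast shapes by sampling Y from a slice of X. A standalone sketch of the non-broadcast case:

```python
import numpy as np

x = np.random.random([13, 17]).astype(np.float32)
# Each element of Y is pushed at least 0.1 above or below its X counterpart,
# so np.maximum(x, y) has a well-defined winner everywhere.
sgn = np.random.choice([-1, 1], [13, 17]).astype(np.float32)
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
assert not np.any(x == y)
```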
@@ -0,0 +1,153 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
import paddle
from elementwise import TestXPUElementwiseOpBase
paddle.enable_static()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUElementwiseMulOp(OpTest, TestXPUElementwiseOpBase):
    def init_kernel_type(self):
        self.use_mkldnn = False

    def setUp(self):
        TestXPUElementwiseOpBase.setUp(self, "elementwise_mul")
        self.init_kernel_type()
        self.init_axis()
        self.attrs['axis'] = self.axis
        self.attrs['use_mkldnn'] = self.use_mkldnn
        self.grad_implemented = True
        self.make_input()
        self.make_output()

    def make_output(self, x_shape=None, y_shape=None):
        x, y = self.reshape_input(x_shape, y_shape)
        self.outputs = {'Out': np.multiply(x, y)}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUElementwiseMulOp_scalar(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestXPUElementwiseMulOp_scalar, self).setUp()
        self.make_input((10, 3, 4), (1, ))
        self.make_output()
        self.grad_implemented = False


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUElementwiseMulOp_Vector(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestXPUElementwiseMulOp_Vector, self).setUp()
        self.make_input((100, ), (100, ))
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUElementwiseMulOp_broadcast_0(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestXPUElementwiseMulOp_broadcast_0, self).setUp()
        self.make_input((100, 2, 3), (100, ))
        self.make_output(y_shape=(100, 1, 1))
        self.y_grad_implemented = False

    def init_axis(self):
        self.axis = 0


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMulOp_broadcast_1(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestElementwiseMulOp_broadcast_1, self).setUp()
        self.attrs['axis'] = 1
        self.y_grad_implemented = False
        self.make_input((2, 100, 3), (100, ))
        self.make_output(y_shape=(1, 100, 1))


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMulOp_broadcast_2(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestElementwiseMulOp_broadcast_2, self).setUp()
        self.y_grad_implemented = False
        self.make_input((2, 3, 100), (100, ))
        self.make_output(y_shape=(1, 1, 100))


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMulOp_broadcast_3(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestElementwiseMulOp_broadcast_3, self).setUp()
        self.attrs['axis'] = 1
        self.y_grad_implemented = False
        self.make_input((2, 10, 12, 3), (10, 12))
        self.make_output(y_shape=(1, 10, 12, 1))


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMulOp_broadcast_4(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestElementwiseMulOp_broadcast_4, self).setUp()
        self.is_common_broadcast = True
        self.make_input((10, 2, 11), (10, 1, 11))
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseMulOp_broadcast_5(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestElementwiseMulOp_broadcast_5, self).setUp()
        self.is_common_broadcast = True
        self.make_input((10, 4, 2, 3), (10, 4, 1, 3))
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUElementwiseMulOp_commonuse_1(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestXPUElementwiseMulOp_commonuse_1, self).setUp()
        self.is_common_broadcast = True
        self.make_input((2, 3, 100), (1, 1, 100))
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUElementwiseMulOp_xsize_lessthan_ysize(TestXPUElementwiseMulOp):
    def setUp(self):
        super(TestXPUElementwiseMulOp_xsize_lessthan_ysize, self).setUp()
        self.attrs['axis'] = 2
        self.is_x_size_less_than_y = True
        self.make_input((10, 10), (2, 2, 10, 10))
        self.make_output(x_shape=(1, 1, 10, 10))


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,128 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import paddle
from elementwise import TestXPUElementwiseOpBase
paddle.enable_static()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUElementwiseSubOp(OpTest, TestXPUElementwiseOpBase):
    def setUp(self):
        TestXPUElementwiseOpBase.setUp(self, "elementwise_sub")
        self.make_input()
        self.make_output()
        self.grad_implemented = True

    def make_output(self, x_shape=None, y_shape=None):
        x, y = self.reshape_input(x_shape, y_shape)
        self.outputs = {'Out': x - y}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_scalar(TestXPUElementwiseSubOp):
    def setUp(self):
        super(TestElementwiseSubOp_scalar, self).setUp()
        self.grad_implemented = False
        self.make_input((10, 3, 4), (1, ))
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_Vector(TestXPUElementwiseSubOp):
    def setUp(self):
        super(TestElementwiseSubOp_Vector, self).setUp()
        self.make_input((100, ), (100, ))
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_broadcast_0(TestXPUElementwiseSubOp):
    def setUp(self):
        super(TestElementwiseSubOp_broadcast_0, self).setUp()
        self.attrs['axis'] = 0
        self.make_input((100, 3, 2), (100, ))
        self.make_output(y_shape=(100, 1, 1))


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_broadcast_1(TestXPUElementwiseSubOp):
    def setUp(self):
        super(TestElementwiseSubOp_broadcast_1, self).setUp()
        self.attrs['axis'] = 1
        self.make_input((2, 100, 3), (100, ))
        self.make_output(y_shape=(1, 100, 1))


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_broadcast_2(TestXPUElementwiseSubOp):
    def setUp(self):
        super(TestElementwiseSubOp_broadcast_2, self).setUp()
        self.make_input((2, 3, 100), (100, ))
        self.make_output(y_shape=(1, 1, 100))


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_broadcast_3(TestXPUElementwiseSubOp):
    def setUp(self):
        super(TestElementwiseSubOp_broadcast_3, self).setUp()
        self.attrs['axis'] = 1
        self.make_input((2, 10, 12, 3), (10, 12))
        self.make_output(y_shape=(1, 10, 12, 1))


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_broadcast_4(TestXPUElementwiseSubOp):
    def setUp(self):
        super(TestElementwiseSubOp_broadcast_4, self).setUp()
        self.is_common_broadcast = True
        self.make_input((2, 5, 3, 12), (2, 5, 1, 12))
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_commonuse_1(TestXPUElementwiseSubOp):
    def setUp(self):
        super(TestElementwiseSubOp_commonuse_1, self).setUp()
        self.is_common_broadcast = True
        self.make_input((2, 3, 100), (1, 1, 100))
        self.make_output()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestElementwiseSubOp_xsize_lessthan_ysize(TestXPUElementwiseSubOp):
    def setUp(self):
        super(TestElementwiseSubOp_xsize_lessthan_ysize, self).setUp()
        self.attrs['axis'] = 2
        self.is_x_size_less_than_y = True
        self.make_input((10, 12), (2, 3, 10, 12))
        self.make_output(x_shape=(1, 1, 10, 12))


if __name__ == '__main__':
    unittest.main()