add xpu ops for training transformer in kunlun (#29539)
* 1. fix matmul bug  2. add one hot
* add xpu error msg
parent 0fdd365665
commit 760d015c14
@@ -0,0 +1,170 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#ifdef PADDLE_WITH_XPU
#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "xpu/refactor/math.h"

namespace paddle {

namespace operators {
typedef enum { XPU_OR, XPU_AND } XpuLogicalType;

std::string XpuLogicalType2Str(XpuLogicalType ty) {
  switch (ty) {
    case XpuLogicalType::XPU_OR:
      return std::string("logical or");
    case XpuLogicalType::XPU_AND:
      return std::string("logical and");
    default:
      return std::string("unknown type");
  }
  return std::string("unknown");
}

template <XpuLogicalType xpu_type, typename T>
class BinaryLogicalOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<framework::Tensor>("X");
    auto* y = context.Input<framework::Tensor>("Y");
    auto* out = context.Output<framework::Tensor>("Out");
    T* out_ptr = out->mutable_data<T>(context.GetPlace());
    const T* x_ptr = x->data<T>();
    const T* y_ptr = y->data<T>();
    auto& dev_ctx =
        context.template device_context<paddle::platform::XPUDeviceContext>();
    framework::Tensor broadcast_x;
    framework::Tensor broadcast_y;
    bool need_broad_cast = false;
    if (x->numel() != out->numel()) {
      // x needs broadcast
      T* broadcast_x_ptr =
          broadcast_x.mutable_data<T>(context.GetPlace(), out->numel());
      auto& out_dim = out->dims();
      auto& x_dim = x->dims();
      int dims = out_dim.size();
      std::vector<int> bcast_xdims;
      std::vector<int> bcast_ydims;
      for (int i = 0; i < dims; ++i) {
        if (out_dim[i] == x_dim[i]) {
          bcast_xdims.push_back(x_dim[i]);
          bcast_ydims.push_back(x_dim[i]);
          continue;
        }
        bcast_xdims.push_back(1);
        bcast_xdims.push_back(x_dim[i]);
        bcast_ydims.push_back(out_dim[i] / x_dim[i]);
        bcast_ydims.push_back(x_dim[i]);
      }

      // bool tensors are broadcast byte-wise via the int8_t specialization
      int ret = xpu::broadcast<int8_t>(
          dev_ctx.x_context(), reinterpret_cast<const int8_t*>(x_ptr),
          reinterpret_cast<int8_t*>(broadcast_x_ptr), bcast_xdims,
          bcast_ydims);
      PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                        platform::errors::External(
                            "XPU broadcast kernel return wrong value[%d %s]",
                            ret, XPUAPIErrorMsg[ret]));
      x_ptr = (const T*)broadcast_x_ptr;
      need_broad_cast = true;
    }
    if (y->numel() != out->numel()) {
      // y needs broadcast
      T* broadcast_y_ptr =
          broadcast_y.mutable_data<T>(context.GetPlace(), out->numel());
      auto& out_dim = out->dims();
      auto& y_dim = y->dims();
      int dims = out_dim.size();
      std::vector<int> bcast_xdims;
      std::vector<int> bcast_ydims;
      for (int i = 0; i < dims; ++i) {
        if (out_dim[i] == y_dim[i]) {
          bcast_xdims.push_back(y_dim[i]);
          bcast_ydims.push_back(y_dim[i]);
          continue;
        }
        bcast_xdims.push_back(1);
        bcast_xdims.push_back(y_dim[i]);
        bcast_ydims.push_back(out_dim[i] / y_dim[i]);
        bcast_ydims.push_back(y_dim[i]);
      }

      int ret = xpu::broadcast<int8_t>(
          dev_ctx.x_context(), reinterpret_cast<const int8_t*>(y_ptr),
          reinterpret_cast<int8_t*>(broadcast_y_ptr), bcast_xdims,
          bcast_ydims);
      PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                        platform::errors::External(
                            "XPU broadcast kernel return wrong value[%d %s]",
                            ret, XPUAPIErrorMsg[ret]));
      y_ptr = (const T*)broadcast_y_ptr;
      need_broad_cast = true;
    }

    // logical kernel
    int ret = XPU_SUCCESS;
    switch (xpu_type) {
      case XpuLogicalType::XPU_OR:
        ret = xpu::logical_or<bool>(dev_ctx.x_context(), x_ptr, y_ptr, out_ptr,
                                    out->numel());
        break;
      case XpuLogicalType::XPU_AND:
        ret = xpu::logical_and<bool>(dev_ctx.x_context(), x_ptr, y_ptr,
                                     out_ptr, out->numel());
        break;
      default:
        LOG(ERROR) << "xpu does not support logical xpu type = "
                   << XpuLogicalType2Str(xpu_type);
        break;
    }
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External("XPU API return wrong value[%d %s] in "
                                   "op_name[%s].",
                                   ret, XPUAPIErrorMsg[ret],
                                   XpuLogicalType2Str(xpu_type)));

    // wait so the broadcast temporaries are not freed before the async
    // kernels that read them have finished
    if (need_broad_cast && dev_ctx.x_context()->xpu_stream != nullptr) {
      xpu_wait();
    }
  }
};

template <typename T>
class UnaryLogicalOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<framework::Tensor>("X");
    auto* out = context.Output<framework::Tensor>("Out");
    if (x->numel() == 0) {
      return;
    }
    out->mutable_data<T>(context.GetPlace());
    auto& dev_ctx =
        context.template device_context<paddle::platform::XPUDeviceContext>();
    int ret = xpu::logical_not<bool>(dev_ctx.x_context(), x->data<T>(),
                                     out->data<T>(), x->numel());
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External("XPU API return wrong value[%d %s].", ret,
                                   XPUAPIErrorMsg[ret]));
  }
};

}  // namespace operators
}  // namespace paddle
#endif
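
Note: the broadcast preparation in BinaryLogicalOpXPUKernel keeps a dimension that already matches the output and otherwise splits it into a (1, x_dim[i]) source pair and an (out_dim[i] / x_dim[i], x_dim[i]) target pair, so a single xpu::broadcast call can tile the smaller input up to out->numel() elements before the elementwise logical kernel runs. A minimal host-side NumPy sketch of that dim splitting (illustrative shapes only, not the XPU API):

    import numpy as np

    def split_bcast_dims(src_shape, out_shape):
        # mirrors the bcast_xdims / bcast_ydims loop in the kernel above
        src_dims, dst_dims = [], []
        for sd, od in zip(src_shape, out_shape):
            if sd == od:              # dimension already matches: keep it
                src_dims.append(sd)
                dst_dims.append(sd)
            else:                     # split into a (repeat, original) pair
                src_dims += [1, sd]
                dst_dims += [od // sd, sd]
        return src_dims, dst_dims

    x = np.random.choice([True, False], size=(2, 1, 4))
    out_shape = (2, 3, 4)
    src_dims, dst_dims = split_bcast_dims(x.shape, out_shape)
    # src_dims == [2, 1, 1, 4], dst_dims == [2, 3, 1, 4]
    tiled = np.broadcast_to(x.reshape(src_dims), dst_dims).reshape(out_shape)
    assert tiled.size == np.prod(out_shape)   # materialized to out->numel() elements
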
@@ -0,0 +1,21 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    logical_and,
    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, bool>);
#endif
@@ -0,0 +1,19 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(logical_not, ops::UnaryLogicalOpXPUKernel<bool>);
#endif
@@ -0,0 +1,22 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    logical_or,
    ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, bool>);
#endif
@@ -0,0 +1,71 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_XPU
#include <string>
#include <vector>

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/one_hot_op.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class OneHotXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
    int depth = context.Attr<int>("depth");
    // depth may be given at runtime via the depth_tensor input; read it back
    // to the host and resize the output's last dimension accordingly
    if (context.HasInput("depth_tensor")) {
      auto* depth_tensor = context.Input<Tensor>("depth_tensor");
      auto* depth_data = depth_tensor->data<int32_t>();
      if (depth_tensor->place() == platform::XPUPlace()) {
        xpu_memcpy(static_cast<void*>(&depth),
                   static_cast<const void*>(depth_data), sizeof(int32_t),
                   XPU_DEVICE_TO_HOST);
      } else {
        depth = depth_data[0];
      }
      auto in_dims = in->dims();
      framework::DDim out_dims(in_dims);
      out_dims[out_dims.size() - 1] = depth;
      out->Resize(out_dims);
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();
    int len = in->numel();
    int ret = xpu::one_hot<T>(dev_ctx.x_context(), in->data<T>(),
                              out->mutable_data<float>(context.GetPlace()),
                              len, depth);

    PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                      platform::errors::External(
                          "XPU one_hot kernel return wrong value[%d %s]", ret,
                          XPUAPIErrorMsg[ret]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    one_hot, ops::OneHotXPUKernel<paddle::platform::XPUDeviceContext, int>,
    ops::OneHotXPUKernel<paddle::platform::XPUDeviceContext, int64_t>);
#endif
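
For reference, a NumPy sketch of the output the XPU one_hot kernel is expected to produce; it mirrors the expected-output construction used in the unit tests below, and the bounds check is an assumption corresponding to allow_out_of_range rather than a transcription of the device code:

    import numpy as np

    def one_hot_ref(x, depth):
        # x: integer indices of shape [N, 1]; result: float32 of shape [N, depth]
        out = np.zeros((x.shape[0], depth), dtype='float32')
        for i, idx in enumerate(x.reshape(-1)):
            if 0 <= idx < depth:      # out-of-range indices leave an all-zero row
                out[i, idx] = 1.0
        return out

    x = np.array([[1], [3], [0]], dtype='int32')
    print(one_hot_ref(x, depth=5))
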
@@ -0,0 +1,235 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import sys
sys.path.append("..")
from paddle.fluid.op import Operator
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle
from op_test_xpu import XPUOpTest
from paddle.static import Program, program_guard

TEST_META_OP_DATA = [{
    'op_str': 'logical_and',
    'binary_op': True
}, {
    'op_str': 'logical_or',
    'binary_op': True
}, {
    'op_str': 'logical_not',
    'binary_op': False
}]

TEST_META_SHAPE_DATA = {
    'XDimLargerThanYDim1': {
        'x_shape': [2, 3, 4, 5],
        'y_shape': [4, 5]
    },
    'XDimLargerThanYDim2': {
        'x_shape': [2, 3, 4, 5],
        'y_shape': [4, 1]
    },
    'XDimLargerThanYDim3': {
        'x_shape': [2, 3, 4, 5],
        'y_shape': [1, 4, 1]
    },
    'XDimLargerThanYDim4': {
        'x_shape': [2, 3, 4, 5],
        'y_shape': [3, 4, 1]
    },
    'XDimLargerThanYDim5': {
        'x_shape': [2, 3, 1, 5],
        'y_shape': [3, 1, 1]
    },
    'XDimLessThanYDim1': {
        'x_shape': [4, 1],
        'y_shape': [2, 3, 4, 5]
    },
    'XDimLessThanYDim2': {
        'x_shape': [1, 4, 1],
        'y_shape': [2, 3, 4, 5]
    },
    'XDimLessThanYDim3': {
        'x_shape': [3, 4, 1],
        'y_shape': [2, 3, 4, 5]
    },
    'XDimLessThanYDim4': {
        'x_shape': [3, 1, 1],
        'y_shape': [2, 3, 1, 5]
    },
    'XDimLessThanYDim5': {
        'x_shape': [4, 5],
        'y_shape': [2, 3, 4, 5]
    },
    'Axis1InLargerDim': {
        'x_shape': [1, 4, 5],
        'y_shape': [2, 3, 1, 5]
    },
    'EqualDim1': {
        'x_shape': [10, 7],
        'y_shape': [10, 7]
    },
    'EqualDim2': {
        'x_shape': [1, 1, 4, 5],
        'y_shape': [2, 3, 1, 5]
    }
}

TEST_META_WRONG_SHAPE_DATA = {
    'ErrorDim1': {
        'x_shape': [2, 3, 4, 5],
        'y_shape': [3, 4]
    },
    'ErrorDim2': {
        'x_shape': [2, 3, 4, 5],
        'y_shape': [4, 3]
    }
}


def run_static_xpu(x_np, y_np, op_str, binary_op=True):
    paddle.enable_static()
    startup_program = fluid.Program()
    main_program = fluid.Program()
    place = paddle.XPUPlace(0)
    exe = fluid.Executor(place)
    with fluid.program_guard(main_program, startup_program):
        x = paddle.static.data(name='x', shape=x_np.shape, dtype='bool')
        op = getattr(paddle, op_str)
        feed_list = {'x': x_np}
        if not binary_op:
            res = op(x)
        else:
            y = paddle.static.data(name='y', shape=y_np.shape, dtype='bool')
            feed_list['y'] = y_np
            res = op(x, y)
        exe.run(startup_program)
        static_result = exe.run(main_program, feed=feed_list, fetch_list=[res])
    return static_result


def run_dygraph_xpu(x_np, y_np, op_str, binary_op=True):
    place = paddle.XPUPlace(0)
    paddle.disable_static(place)
    op = getattr(paddle, op_str)
    x = paddle.to_tensor(x_np)
    if not binary_op:
        dygraph_result = op(x)
    else:
        y = paddle.to_tensor(y_np)
        dygraph_result = op(x, y)
    return dygraph_result


def np_data_generator(np_shape, *args, **kwargs):
    return np.random.choice(a=[True, False], size=np_shape).astype(bool)


def test_xpu(unit_test, test_error=False):
    for op_data in TEST_META_OP_DATA:
        meta_data = dict(op_data)
        np_op = getattr(np, meta_data['op_str'])
        META_DATA = dict(TEST_META_SHAPE_DATA)
        if test_error:
            META_DATA = dict(TEST_META_WRONG_SHAPE_DATA)
        for shape_data in META_DATA.values():
            meta_data['x_np'] = np_data_generator(shape_data['x_shape'])
            meta_data['y_np'] = np_data_generator(shape_data['y_shape'])
            if meta_data['binary_op'] and test_error:
                # catch C++ Exception
                unit_test.assertRaises(BaseException, run_static_xpu,
                                       **meta_data)
                continue
            static_result = run_static_xpu(**meta_data)
            dygraph_result = run_dygraph_xpu(**meta_data)
            if meta_data['binary_op']:
                np_result = np_op(meta_data['x_np'], meta_data['y_np'])
            else:
                np_result = np_op(meta_data['x_np'])
            unit_test.assertTrue((static_result == np_result).all())
            unit_test.assertTrue((dygraph_result.numpy() == np_result).all())


def test_type_error(unit_test, type_str_map):
    def check_type(op_str, x, y, binary_op):
        op = getattr(paddle, op_str)
        error_type = TypeError
        if isinstance(x, np.ndarray):
            x = paddle.to_tensor(x)
            y = paddle.to_tensor(y)
            error_type = BaseException
        if binary_op:
            if type_str_map['x'] != 'bool' or type_str_map['y'] != 'bool':
                unit_test.assertRaises(error_type, op, x=x, y=y)
            if not fluid.in_dygraph_mode():
                unit_test.assertRaises(error_type, op, x=x, y=y, out=1)
        else:
            if type_str_map['x'] != 'bool':
                unit_test.assertRaises(error_type, op, x=x)
            if not fluid.in_dygraph_mode():
                unit_test.assertRaises(error_type, op, x=x, out=1)

    place = paddle.XPUPlace(0)

    for op_data in TEST_META_OP_DATA:
        meta_data = dict(op_data)
        binary_op = meta_data['binary_op']

        paddle.disable_static(place)
        x = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['x'])
        y = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['y'])
        check_type(meta_data['op_str'], x, y, binary_op)

        paddle.enable_static()
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                name='x', shape=[10], dtype=type_str_map['x'])
            y = paddle.static.data(
                name='y', shape=[10], dtype=type_str_map['y'])
            check_type(meta_data['op_str'], x, y, binary_op)


def type_map_factory():
    x_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
    y_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
    return [{
        'x': x_type,
        'y': y_type
    } for x_type in x_type_list for y_type in y_type_list]


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPU(unittest.TestCase):
    def test(self):
        test_xpu(self)

    def test_error(self):
        test_xpu(self, True)

    def test_type_error(self):
        type_map_list = type_map_factory()
        for type_map in type_map_list:
            test_type_error(self, type_map)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,184 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import sys
sys.path.append("..")
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import time

paddle.enable_static()


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp(XPUOpTest):
    def setUp(self):
        self.use_xpu = True
        self.op_type = 'one_hot'
        depth = 10
        depth_np = np.array(10).astype('int32')
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])

        out = np.zeros(shape=(np.product(x.shape[:-1]),
                              depth)).astype('float32')

        for i in range(np.product(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp_attr(XPUOpTest):
    def setUp(self):
        self.op_type = 'one_hot'
        depth = 10
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])

        out = np.zeros(shape=(np.product(x.shape[:-1]),
                              depth)).astype('float32')

        for i in range(np.product(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32), 'depth': depth}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp_default_dtype(XPUOpTest):
    def setUp(self):
        self.op_type = 'one_hot'
        depth = 10
        depth_np = np.array(10).astype('int32')
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])

        out = np.zeros(shape=(np.product(x.shape[:-1]),
                              depth)).astype('float32')

        for i in range(np.product(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
        self.attrs = {}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp_default_dtype_attr(XPUOpTest):
    def setUp(self):
        self.op_type = 'one_hot'
        depth = 10
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])

        out = np.zeros(shape=(np.product(x.shape[:-1]),
                              depth)).astype('float32')

        for i in range(np.product(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'depth': depth}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp_out_of_range(XPUOpTest):
    def setUp(self):
        self.op_type = 'one_hot'
        depth = 10
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])

        out = np.zeros(shape=(np.product(x.shape[:-1]),
                              depth)).astype('float32')

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'depth': depth, 'allow_out_of_range': True}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOpError(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):
            # the input must be Variable
            in_w = np.random.random((4, 1)).astype("int32")
            self.assertRaises(TypeError, fluid.layers.one_hot, in_w)
            # the input must be int32 or int64
            in_w2 = fluid.layers.data(
                name="in_w2",
                shape=[4, 1],
                append_batch_size=False,
                dtype="float32")
            self.assertRaises(TypeError, fluid.layers.one_hot, in_w2)
            # the depth must be int, long or Variable
            in_r = fluid.layers.data(
                name="in_r",
                shape=[4, 1],
                append_batch_size=False,
                dtype="int32")
            depth_w = np.array([4])
            self.assertRaises(TypeError, fluid.layers.one_hot, in_r, 4.1)
            self.assertRaises(TypeError, fluid.layers.one_hot, in_r, depth_w)


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()