Lookup table v2 xpu (#27888)
* add lookup_table_v2_op_xpu, test=kunlun
* change some tips, test=kunlun
parent 6150cc86e3
commit 3eb106da6d
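For context, a minimal sketch of how the new kernel gets exercised end to end, assuming a PADDLE_WITH_XPU build with a Kunlun device visible as XPUPlace(0); it mirrors the TestLookupTableApi case in the test file below (fluid.embedding lowers to the lookup_table_v2 op added here):

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

# Build a static-graph embedding lookup; shape=[20] means (-1, 20) ids.
x = fluid.layers.data(name='x', shape=[20], dtype='int64')
emb = fluid.embedding(input=x, size=[128, 64])

place = paddle.XPUPlace(0)  # assumes an XPU device is available
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

x_data = np.random.randint(0, 127, [2, 20]).astype("int64")
out, = exe.run(feed={'x': x_data}, fetch_list=[emb])
print(out.shape)  # (2, 20, 64)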
@@ -0,0 +1,125 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/lookup_table_v2_op.h"
#include <memory>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"

namespace paddle {
namespace operators {

#ifdef PADDLE_WITH_XPU
template <typename DeviceContext, typename T>
class LookupTableV2XPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *ids_t = context.Input<LoDTensor>("Ids");      // int64
    auto *output_t = context.Output<LoDTensor>("Out");  // float
    auto *table_var = context.InputVar("W");
    PADDLE_ENFORCE_EQ(
        (std::is_same<DeviceContext, platform::XPUDeviceContext>::value), true,
        platform::errors::InvalidArgument("Unsupported place!"));

    PADDLE_ENFORCE_EQ(table_var->IsType<LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "W in LookupTableV2XPUKernel should be LoDTensor"));

    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
    int64_t ids_numel = ids_t->numel();

    auto *table_t = context.Input<LoDTensor>("W");
    auto &dev_ctx = context.template device_context<DeviceContext>();
    size_t D = table_t->dims()[1];  // embedding width

    auto *table = table_t->data<T>();
    auto *output = output_t->mutable_data<T>(context.GetPlace());
    const int64_t *ids = ids_t->data<int64_t>();

    // xpu::embedding takes an int32 element count, so guard the cast.
    PADDLE_ENFORCE_EQ(ids_numel <= std::numeric_limits<int32_t>::max(), true,
                      platform::errors::InvalidArgument(
                          "ids_numel in LookupTableV2XPUKernel should not "
                          "be greater than int32_t::max."));
    int ids_numel_int32 = static_cast<int>(ids_numel);
    int r = xpu::embedding<T>(dev_ctx.x_context(), ids_numel_int32, ids, D,
                              table, output, padding_idx);
    PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
                      platform::errors::InvalidArgument("XPU kernel error!"));
  }
};

template <typename DeviceContext, typename T>
class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_var = context.InputVar("W");
    DDim table_dim;
    PADDLE_ENFORCE_EQ(
        table_var->IsType<LoDTensor>(), true,
        platform::errors::InvalidArgument(
            "W in LookupTableV2GradXPUKernel should be LoDTensor"));
    table_dim = context.Input<LoDTensor>("W")->dims();

    bool is_sparse = context.Attr<bool>("is_sparse");
    PADDLE_ENFORCE_EQ(
        is_sparse, false,
        platform::errors::InvalidArgument(
            "LookupTableV2GradXPUKernel does NOT support is_sparse = True"));

    auto ids_t = context.Input<LoDTensor>("Ids");
    auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
    auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

    int64_t ids_numel = ids_t->numel();
    PADDLE_ENFORCE_EQ(ids_numel <= std::numeric_limits<int32_t>::max(), true,
                      platform::errors::InvalidArgument(
                          "ids_numel in LookupTableV2GradXPUKernel should not "
                          "be greater than int32_t::max."));
    int ids_numel_int32 = static_cast<int>(ids_numel);
    const int64_t *ids_data = ids_t->data<int64_t>();

    int D = d_table_t->dims()[1];
    const T *d_output_data = d_output_t->data<T>();
    T *d_table_data = d_table_t->mutable_data<T>(context.GetPlace());
    auto &dev_ctx = context.template device_context<DeviceContext>();

    // Zero-initialize d_table before scatter-accumulating the gradients.
    const int zero = 0;
    int r = xpu::memset(dev_ctx.x_context(), d_table_data, zero,
                        d_table_t->numel() * sizeof(T));
    PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
                      platform::errors::InvalidArgument("XPU kernel error!"));

    r = xpu::embedding_backward<T, int64_t>(dev_ctx.x_context(),
                                            ids_numel_int32, ids_data, D,
                                            d_output_data, d_table_data);
    PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
                      platform::errors::InvalidArgument("XPU kernel error!"));
  }
};
#endif

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(
    lookup_table_v2,
    ops::LookupTableV2XPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
    lookup_table_v2_grad,
    ops::LookupTableV2GradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
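For reference, what xpu::embedding computes above, restated as a NumPy sketch. embedding_forward_ref is a hypothetical helper, not part of this PR; it assumes padding_idx is either -1 (no padding) or a valid non-negative row index:

import numpy as np

def embedding_forward_ref(table, ids, padding_idx=-1):
    # Hypothetical reference: out[i] = table[ids[i]] for every id,
    # with a trailing embedding dimension appended to ids' shape.
    out = table[ids.reshape(-1)].reshape(ids.shape + (table.shape[1],))
    if padding_idx >= 0:
        # Rows looked up at padding_idx come back as zeros.
        out[ids == padding_idx] = 0
    return out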
@@ -0,0 +1,223 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.op import Operator
import paddle.compat as cpt
from paddle.fluid import Program, program_guard

paddle.enable_static()


class TestDygraphEmbeddingAPIError(unittest.TestCase):
    def test_errors(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            dict_size = 20
            layer = fluid.dygraph.nn.Embedding(
                size=[dict_size, 32], param_attr='emb.w', is_sparse=False)
            # the input must be Variable
            x0 = fluid.create_lod_tensor(
                np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], paddle.XPUPlace(0))
            self.assertRaises(TypeError, layer, x0)
            # the input dtype must be int64
            data_t = fluid.data(name='word', shape=[1], dtype='int32')
            self.assertRaises(TypeError, layer, data_t)


class TestLookupTableOp(OpTest):
    def setUp(self):
        self.op_type = "lookup_table_v2"
        table = np.random.random((17, 31)).astype("float64")
        ids = np.random.randint(0, 17, 4).astype("int64")
        self.inputs = {'W': table, 'Ids': ids}
        self.outputs = {'Out': table[ids]}

    def test_check_output_with_place(self):
        self.check_output_with_place(place=paddle.XPUPlace(0))

    def test_check_grad(self):
        self.check_grad_with_place(
            inputs_to_check=['W'],
            output_names='Out',
            no_grad_set=set(['Ids']),
            place=paddle.XPUPlace(0),
            in_place=True)


class TestLookupTableOpWithTensorIds(OpTest):
    def setUp(self):
        self.op_type = "lookup_table_v2"
        table = np.random.random((17, 31)).astype("float64")
        ids = np.random.randint(low=0, high=17, size=(2, 4, 5)).astype("int32")
        self.inputs = {'W': table, 'Ids': ids}
        self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))}

    def test_check_output(self):
        self.check_output_with_place(place=paddle.XPUPlace(0))

    def test_check_grad(self):
        self.check_grad_with_place(
            inputs_to_check=['W'],
            output_names='Out',
            no_grad_set=set(['Ids']),
            place=paddle.XPUPlace(0),
            in_place=True)


@skip_check_grad_ci(
    reason="Since paddings are not trainable and fixed in forward, "
    "the gradient of paddings makes no sense and we don't "
    "test the gradient here.")
class TestLookupTableOpWithPadding(TestLookupTableOp):
    def test_check_output(self):
        ids = np.squeeze(self.inputs['Ids'])
        padding_idx = np.random.choice(ids, 1)[0]
        self.outputs['Out'][ids == padding_idx] = np.zeros(31)
        self.attrs = {'padding_idx': int(padding_idx)}
        self.check_output_with_place(place=paddle.XPUPlace(0))


@skip_check_grad_ci(
    reason="Since paddings are not trainable and fixed in forward, "
    "the gradient of paddings makes no sense and we don't "
    "test the gradient here.")
class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
    def test_check_output(self):
        ids = self.inputs['Ids']
        flatten_idx = ids.flatten()
        padding_idx = np.random.choice(flatten_idx, 1)[0]
        self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
        self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
        self.check_output_with_place(place=paddle.XPUPlace(0))


class TestLookupTableWIsSelectedRows(unittest.TestCase):
    def prepare_ids(self, scope, place):
        ids_tensor = scope.var('Ids').get_tensor()
        ids_array = np.array([0, 4, 3, 5]).astype("int64")
        ids_tensor.set(ids_array, place)
        return ids_array

    def prepare_w(self, scope, place):
        rows = [0, 1, 2, 3, 4, 5, 6]
        row_numel = 12
        w_selected_rows = scope.var('W')
        w_array = np.ones((len(rows), row_numel)).astype("float32")
        for i in range(len(rows)):
            w_array[i] *= i
        w_tensor = w_selected_rows.get_tensor()
        w_tensor.set(w_array, place)

    def create_out_tensor(self, scope, place):
        return scope.var('Out').get_tensor()

    def check_result(self, ids_array, result_array):
        # all(): True if all elements of the iterable are true
        # (or if the iterable is empty)
        for idx, row in enumerate(ids_array):
            assert (row == result_array[idx]).all()

    def check_with_place(self, place):
        scope = core.Scope()
        ids_array = self.prepare_ids(scope, place)

        self.prepare_w(scope, place)

        out_tensor = self.create_out_tensor(scope, place)

        # create and run lookup_table_v2 operator
        lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
        lookup_table.run(scope, place)

        # get result from Out
        result_array = np.array(out_tensor)

        self.check_result(ids_array, result_array)

    def test_w_is_selected_rows(self):
        places = [paddle.XPUPlace(0)]
        for place in places:
            self.check_with_place(place)


class TestLookupTableWithTensorIdsWIsSelectedRows(
        TestLookupTableWIsSelectedRows):
    def prepare_ids(self, scope, place):
        ids_tensor = scope.var('Ids').get_tensor()
        ids_array = np.random.randint(
            low=0, high=6, size=(2, 4, 3)).astype("int64")
        ids_tensor.set(ids_array, place)
        return ids_array

    def check_result(self, ids_array, result_array):
        for idx, row in np.ndenumerate(ids_array):
            assert (row == result_array[idx]).all()


class TestLookupTableApi(unittest.TestCase):
    def test_api(self):
        x = fluid.layers.data(name='x', shape=[20], dtype='int64')
        emb = fluid.embedding(input=x, size=[128, 64])

        place = paddle.XPUPlace(0)
        x_data = np.random.randint(0, 127, [2, 20]).astype("int64")

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        ret = exe.run(feed={'x': x_data},
                      fetch_list=[emb],
                      return_numpy=False)


class TestEmbedOpError(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):
            input_data = np.random.randint(0, 10, (4, 6)).astype("int64")

            def test_Variable():
                # the input type must be Variable
                fluid.embedding(input=input_data, size=(10, 64))

            self.assertRaises(TypeError, test_Variable)

            def test_input_dtype():
                # the input dtype must be int64
                input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
                fluid.embedding(input=input, size=(10, 64))

            self.assertRaises(TypeError, test_input_dtype)

            def test_param_dtype():
                # dtype must be float32 or float64
                input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
                fluid.embedding(input=input2, size=(10, 64), dtype='int64')

            self.assertRaises(TypeError, test_param_dtype)
            input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
            fluid.embedding(input=input3, size=(10, 64), dtype='float16')


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
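For completeness, the gradient path restated the same way: the grad kernel zeroes W@GRAD with xpu::memset and then scatter-accumulates rows of Out@GRAD into it, which np.add.at reproduces (including accumulation over repeated ids). embedding_backward_ref is likewise a hypothetical reference helper, not code from this PR:

import numpy as np

def embedding_backward_ref(ids, d_out, vocab_size):
    # d_table starts at zero (the memset step); each d_out row is then
    # added into the d_table row selected by the corresponding id.
    D = d_out.shape[-1]
    d_table = np.zeros((vocab_size, D), dtype=d_out.dtype)
    np.add.at(d_table, ids.reshape(-1), d_out.reshape(-1, D))
    return d_table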