support transformer v2.0 (#30381)

Branch: revert-31562-mean
Author: taixiurong, committed by GitHub
parent e85be1b1b2
commit 6a3c8725b0

@@ -10,7 +10,7 @@ if (WITH_AARCH64)
elseif(WITH_SUNWAY)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2020_1227.tar.gz" CACHE STRING "" FORCE)
else()
-SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_0105.tar.gz" CACHE STRING "" FORCE)
+SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
endif()
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")

@@ -45,15 +45,13 @@ class LayerNormXPUKernel : public framework::OpKernel<T> {
    auto* mean_data = mean->mutable_data<T>(ctx.GetPlace());
    auto* variance_data = variance->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::layer_norm(dev_ctx.x_context(), left, right, x_data, y_data,
-                            scale_data, bias_data, epsilon, mean_data,
-                            variance_data, false);
-    PADDLE_ENFORCE_EQ(
-        r, XPU_SUCCESS,
-        platform::errors::External("XPU API(layer_norm) return wrong "
-                                   "value[%d], please check whether Baidu "
-                                   "Kunlun Card is properly installed.",
-                                   r));
+    int r = xpu::layer_norm(dev_ctx.x_context(), x_data, y_data, left, right,
+                            epsilon, scale_data, bias_data, mean_data,
+                            variance_data);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU layer_norm kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
}
};
@@ -87,15 +85,14 @@ class LayerNormGradXPUKernel : public framework::OpKernel<T> {
    auto* dx_data =
        (dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()));
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::layer_norm_backward(
-        dev_ctx.x_context(), left, right, x_data, scale_data, variance_data,
-        mean_data, dy_data, dx_data, dscale_data, dbias_data, epsilon);
+    int r = xpu::layer_norm_grad(dev_ctx.x_context(), x_data, dy_data, dx_data,
+                                 left, right, epsilon, scale_data, mean_data,
+                                 variance_data, dscale_data, dbias_data);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
-        platform::errors::External("XPU API(layer_norm_backward) return wrong "
-                                   "value[%d], please check whether Baidu "
-                                   "Kunlun Card is properly installed.",
-                                   r));
+        platform::errors::External(
+            "XPU layer_norm_grad kernel return wrong value[%d %s]", r,
+            XPUAPIErrorMsg[r]));
}
};
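Note: the updated XPU 2021 API puts data pointers first, then the (left, right) row/column split, then epsilon and parameters, and failures are now reported through XPUAPIErrorMsg. As a reading aid, here is a minimal numpy sketch of what the forward kernel computes, assuming `left` rows of `right` elements are normalized independently; the function name is illustrative, not part of the Paddle API.

import numpy as np

def layer_norm_ref(x, scale, bias, epsilon, left, right):
    # Illustrative reference: view x as a (left, right) matrix and
    # normalize each row, then apply the per-column scale and bias.
    xr = x.reshape(left, right).astype('float32')
    mean = xr.mean(axis=1)
    var = xr.var(axis=1)
    y = (xr - mean[:, None]) / np.sqrt(var[:, None] + epsilon)
    y = y * scale[None, :] + bias[None, :]
    return y.reshape(x.shape), mean, var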

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -35,7 +35,7 @@ class OneHotXPUKernel : public framework::OpKernel<T> {
    if (context.HasInput("depth_tensor")) {
      auto* depth_tensor = context.Input<Tensor>("depth_tensor");
      auto* depth_data = depth_tensor->data<int32_t>();
-      if (depth_tensor->place() == platform::XPUPlace()) {
+      if (platform::is_xpu_place(depth_tensor->place())) {
        xpu_memcpy(static_cast<void*>(&depth),
                   static_cast<const void*>(depth_data), sizeof(int32_t),
                   XPU_DEVICE_TO_HOST);
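Note: this one-liner is a genuine fix, not a style change: comparing against a default-constructed `platform::XPUPlace()` only matches XPU device 0, while `platform::is_xpu_place` matches any XPU device. A hedged Python sketch of the pitfall (the classes below are stand-ins, not Paddle's):

class XPUPlace:
    # Stand-in for the C++ place type; not the real Paddle class.
    def __init__(self, device_id=0):
        self.device_id = device_id

    def __eq__(self, other):
        return (isinstance(other, XPUPlace) and
                self.device_id == other.device_id)

def is_xpu_place(place):
    return isinstance(place, XPUPlace)

tensor_place = XPUPlace(1)           # tensor living on XPU card 1
print(tensor_place == XPUPlace())    # False: equality also compares ids
print(is_xpu_place(tensor_place))    # True: only the place type matters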

@@ -0,0 +1,70 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU

#include <string>
#include <vector>

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/one_hot_op.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class OneHotV2XPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
    int depth = context.Attr<int>("depth");
    if (context.HasInput("depth_tensor")) {
      auto* depth_tensor = context.Input<Tensor>("depth_tensor");
      auto* depth_data = depth_tensor->data<int32_t>();
      if (platform::is_xpu_place(depth_tensor->place())) {
        xpu_memcpy(static_cast<void*>(&depth),
                   static_cast<const void*>(depth_data), sizeof(int32_t),
                   XPU_DEVICE_TO_HOST);
      } else {
        depth = depth_data[0];
      }
      auto out_dims = out->dims();
      out_dims[out_dims.size() - 1] = depth;
      out->Resize(out_dims);
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();
    int len = in->numel();
    int ret = xpu::one_hot<T>(dev_ctx.x_context(), in->data<T>(),
                              out->mutable_data<float>(context.GetPlace()),
                              len, depth, 1.0, 0.0);
    PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                      platform::errors::External(
                          "XPU one_hot kernel return wrong value[%d %s]", ret,
                          XPUAPIErrorMsg[ret]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    one_hot_v2, ops::OneHotV2XPUKernel<paddle::platform::XPUDeviceContext, int>,
    ops::OneHotV2XPUKernel<paddle::platform::XPUDeviceContext, int64_t>);

#endif
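Note: a minimal numpy reference for what this kernel produces, mirroring the `xpu::one_hot(..., depth, 1.0, 0.0)` call above (on-value 1.0, off-value 0.0); the function name is illustrative:

import numpy as np

def one_hot_ref(labels, depth, on_value=1.0, off_value=0.0):
    # Each integer label becomes a row of length `depth` with
    # on_value at the label index and off_value elsewhere.
    out = np.full((labels.size, depth), off_value, dtype='float32')
    out[np.arange(labels.size), labels.ravel()] = on_value
    return out

print(one_hot_ref(np.array([3, 0, 7], dtype='int32'), depth=10))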

@@ -46,10 +46,13 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
                                       in->dims().to_str().c_str(),
                                       out->dims().to_str().c_str()));
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::scale(dev_ctx.x_context(), in->numel(), scale, bias,
-                       bias_after_scale, in->data<float>(), out->data<float>());
-    PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                      platform::errors::Fatal("XPU scale kernel error!"));
+    int r =
+        xpu::scale(dev_ctx.x_context(), in->data<float>(), out->data<float>(),
+                   in->numel(), bias_after_scale, scale, bias);
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU scale kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
}
};
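Note: beyond the new argument order (data pointers first), the call still follows Paddle's documented scale semantics: out = scale * x + bias when bias_after_scale is true, and out = scale * (x + bias) otherwise. A one-line numpy sketch of that distinction, assuming nothing beyond those documented semantics:

import numpy as np

def scale_ref(x, scale, bias, bias_after_scale):
    # bias_after_scale=True:  out = scale * x + bias
    # bias_after_scale=False: out = scale * (x + bias)
    return scale * x + bias if bias_after_scale else scale * (x + bias)

x = np.array([1.0, 2.0, 3.0])
print(scale_ref(x, scale=2.0, bias=1.0, bias_after_scale=True))   # [3. 5. 7.]
print(scale_ref(x, scale=2.0, bias=1.0, bias_after_scale=False))  # [4. 6. 8.]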

@@ -41,8 +41,21 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
    }
    auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::softmax<T>(dev_ctx.x_context(), x->data<float>(),
-                            out->data<float>(), x_dims, axis);
+    int r = XPU_SUCCESS;
+    Tensor clip_x;
+    int len = x->numel();
+    T* clip_x_data =
+        clip_x.mutable_data<T>(platform::XPUPlace(), len * sizeof(T));
+    r = xpu::clip(dev_ctx.x_context(), x->data<float>(), clip_x_data, len,
+                  -1e30, 1e30);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External("XPU API(clip) return wrong "
+                                                 "value[%d %s]",
+                                                 r, XPUAPIErrorMsg[r]));
+    r = xpu::softmax<T>(dev_ctx.x_context(), clip_x_data, out->data<float>(),
+                        x_dims, axis);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External("XPU API(softmax2d_forward) return wrong "

@@ -0,0 +1,196 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import sys
sys.path.append("..")
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import time

paddle.enable_static()


class TestOneHotOp(XPUOpTest):
    def setUp(self):
        self.use_xpu = True
        self.op_type = 'one_hot_v2'
        depth = 10
        depth_np = np.array(10).astype('int32')
        # dimension = 12
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0])])

        out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
        for i in range(np.product(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


class TestOneHotOp_attr(XPUOpTest):
    def setUp(self):
        self.op_type = 'one_hot_v2'
        depth = 10
        dimension = 12
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])

        out = np.zeros(shape=(np.product(x.shape[:-1]), 1,
                              depth)).astype('float32')
        for i in range(np.product(x.shape)):
            out[i, 0, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32), 'depth': depth}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


class TestOneHotOp_default_dtype(XPUOpTest):
    def setUp(self):
        self.op_type = 'one_hot_v2'
        depth = 10
        depth_np = np.array(10).astype('int32')
        dimension = 12
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0])])

        out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
        for i in range(np.product(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
        self.attrs = {}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


class TestOneHotOp_default_dtype_attr(XPUOpTest):
    def setUp(self):
        self.op_type = 'one_hot_v2'
        depth = 10
        dimension = 12
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])

        out = np.zeros(shape=(np.product(x.shape[:-1]), 1,
                              depth)).astype('float32')
        for i in range(np.product(x.shape)):
            out[i, 0, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'depth': depth}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


class TestOneHotOp_out_of_range(XPUOpTest):
    def setUp(self):
        self.op_type = 'one_hot_v2'
        depth = 10
        x_lod = [[4, 1, 3, 3]]
        x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))]
        x = np.array(x).astype('int32').reshape([sum(x_lod[0])])

        out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'depth': depth, 'allow_out_of_range': True}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)


class TestOneHotOpApi(unittest.TestCase):
    def test_api(self):
        depth = 10
        self._run(depth)

    def test_api_with_depthTensor(self):
        depth = fluid.layers.assign(input=np.array([10], dtype=np.int32))
        self._run(depth)

    def test_api_with_dygraph(self):
        depth = 10
        label = np.array([np.random.randint(0, depth - 1)
                          for i in range(6)]).reshape([6, 1])
        with fluid.dygraph.guard():
            one_hot_label = fluid.one_hot(
                input=fluid.dygraph.to_variable(label), depth=depth)

    def _run(self, depth):
        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
        one_hot_label = fluid.one_hot(input=label, depth=depth)

        place = fluid.XPUPlace(0)
        label_data = np.array([np.random.randint(0, 10 - 1)
                               for i in range(6)]).reshape([6, 1])

        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        ret = exe.run(feed={'label': label_data, },
                      fetch_list=[one_hot_label],
                      return_numpy=False)


class BadInputTestOnehotV2(unittest.TestCase):
    def test_error(self):
        with fluid.program_guard(fluid.Program()):

            def test_bad_x():
                label = fluid.layers.data(
                    name="label",
                    shape=[4],
                    append_batch_size=False,
                    dtype="float32")
                one_hot_label = fluid.one_hot(input=label, depth=4)

            self.assertRaises(TypeError, test_bad_x)


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()