support elementwise add, activation, matmul on Baidu Kunlun (#27143)
* support elementwise add, activation, matmul on Baidu Kunlun
* test=kunlun
* minor
* test=kunlun
* reconstruct the xpu directory
* test=kunlun
* minor
* test=kunlun
* minor
* test=kunlun
* minor
* test=kunlun
* minor
* test=kunlun
* minor
* test=kunlun
parent d37b3774fd
commit 6b727e08b1
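For context, here is a minimal usage sketch (not part of this diff) of how the newly registered XPU kernels could be exercised from Python. It assumes a WITH_XPU build where paddle.enable_static() is available, and follows the paddle.XPUPlace(0) / paddle.is_compiled_with_xpu() usage seen in the unit tests below.

import numpy as np
import paddle
import paddle.fluid as fluid

# Static-graph sketch: relu and elementwise_add dispatched to the XPU kernels.
paddle.enable_static()
if paddle.is_compiled_with_xpu():
    x = fluid.data(name="x", shape=[2, 3], dtype="float32")
    y = fluid.data(name="y", shape=[2, 3], dtype="float32")
    out = fluid.layers.relu(fluid.layers.elementwise_add(x, y))

    exe = fluid.Executor(paddle.XPUPlace(0))
    exe.run(fluid.default_startup_program())
    res, = exe.run(feed={"x": np.ones([2, 3], "float32"),
                         "y": -2 * np.ones([2, 3], "float32")},
                   fetch_list=[out])
    print(res)  # relu(1 - 2) == 0 everywhere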
@@ -0,0 +1,179 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include "paddle/fluid/operators/activation_op.h"
#include <string>
#include "paddle/fluid/platform/xpu_header.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

template <typename Functor>
class XPUActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    Functor functor;

    // Copy the op's float attributes (e.g. "factor" for pow) into the functor.
    auto attrs = functor.GetAttrs();
    for (auto &attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(context);
  }
};

template <typename Functor>
class XPUActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto &attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(context);
  }
};

template <typename DeviceContext, typename T>
void xpu_activation_forward(const framework::ExecutionContext &ctx,
                            xpu::Activation_t type) {
  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");
  const T *x_data = x->data<T>();
  T *y_data = y->mutable_data<T>(ctx.GetPlace());
  int r = 0;
  if (xpu::Activation_t::ACT_POW == type.type) {
    type.pow_factor = ctx.Attr<float>("factor");
  }
  auto xpu_context = ctx.device_context<DeviceContext>().x_context();
  r = xpu::activation_forward(xpu_context, type, x->numel(),
                              reinterpret_cast<const float *>(x_data),
                              reinterpret_cast<float *>(y_data));
  PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API returned wrong value[%d], please check "
                        "whether Baidu Kunlun Card is properly installed.",
                        r));
}

template <typename DeviceContext, typename T>
void xpu_activation_backward(const framework::ExecutionContext &ctx,
                             xpu::Activation_t type) {
  /* TODO: relu, tanh and sigmoid are inplace */
  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Input<Tensor>("Out");
  auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
  auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
  const T *x_data = nullptr;
  const T *y_data = nullptr;
  const T *y_grad = nullptr;
  if (x != nullptr) x_data = x->data<T>();
  if (y != nullptr) y_data = y->data<T>();
  if (dOut != nullptr) y_grad = dOut->data<T>();
  T *x_grad = dX->mutable_data<T>(ctx.GetPlace());
  auto xpu_context = ctx.device_context<DeviceContext>().x_context();
  int r = xpu::activation_backward(xpu_context, type, dX->numel(),
                                   reinterpret_cast<const float *>(x_data),
                                   reinterpret_cast<const float *>(y_data),
                                   reinterpret_cast<const float *>(y_grad),
                                   reinterpret_cast<float *>(x_grad));
  PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API returned wrong value[%d], please check "
                        "whether Baidu Kunlun Card is properly installed.",
                        r));
}

template <typename T, xpu::Activation_t::act_enum algorithm>
struct XPUActivationFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(ctx,
                                                                  algorithm);
  }
};

template <typename T, xpu::Activation_t::act_enum algorithm>
struct XPUActivationGradFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T>(ctx,
                                                                   algorithm);
  }
};

template <typename T>
using XPUReluFunctor = XPUActivationFunc<T, xpu::Activation_t::RELU>;
template <typename T>
using XPUSigmoidFunctor = XPUActivationFunc<T, xpu::Activation_t::SIGMOID>;
template <typename T>
using XPUTanhFunctor = XPUActivationFunc<T, xpu::Activation_t::TANH>;
template <typename T>
using XPUGeluFunctor = XPUActivationFunc<T, xpu::Activation_t::GELU>;
template <typename T>
using XPULogFunctor = XPUActivationFunc<T, xpu::Activation_t::LOG>;
template <typename T>
using XPUSquareFunctor = XPUActivationFunc<T, xpu::Activation_t::SQUARE>;
template <typename T>
using XPUSquareGradFunctor =
    XPUActivationGradFunc<T, xpu::Activation_t::SQUARE>;
template <typename T>
using XPUReluGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::RELU>;
template <typename T>
using XPUSigmoidGradFunctor =
    XPUActivationGradFunc<T, xpu::Activation_t::SIGMOID>;
template <typename T>
using XPUTanhGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::TANH>;
template <typename T>
using XPUGeluGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::GELU>;
template <typename T>
using XPUSqrtFunctor = XPUActivationFunc<T, xpu::Activation_t::SQRT>;
template <typename T>
using XPUSqrtGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::SQRT>;
template <typename T>
using XPUACTPowFunctor = XPUActivationFunc<T, xpu::Activation_t::ACT_POW>;
template <typename T>
using XPUABSFunctor = XPUActivationFunc<T, xpu::Activation_t::ABS>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

#define REGISTER_ACTIVATION_XPU_KERNEL(act_type, functor, grad_functor)  \
  REGISTER_OP_XPU_KERNEL(act_type,                                       \
                         ops::XPUActivationKernel<ops::functor<float>>); \
  REGISTER_OP_XPU_KERNEL(                                                \
      act_type##_grad,                                                   \
      ops::XPUActivationGradKernel<ops::grad_functor<float>>);

REGISTER_ACTIVATION_XPU_KERNEL(relu, XPUReluFunctor, XPUReluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(tanh, XPUTanhFunctor, XPUTanhGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor,
                               XPUSigmoidGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(gelu, XPUGeluFunctor, XPUGeluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSquareGradFunctor)
REGISTER_OP_XPU_KERNEL(log,
                       ops::XPUActivationKernel<ops::XPULogFunctor<float>>);
REGISTER_OP_XPU_KERNEL(pow,
                       ops::XPUActivationKernel<ops::XPUACTPowFunctor<float>>);
REGISTER_OP_XPU_KERNEL(abs,
                       ops::XPUActivationKernel<ops::XPUABSFunctor<float>>);

#endif  // PADDLE_WITH_XPU
@@ -0,0 +1,162 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"

#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class ElementwiseAddXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    XPUElementwise<T, XPUAddFunctor<T>>(ctx);
  }
};

template <typename DeviceContext, typename T>
class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    ElemwiseGradKernel<T>::Compute(ctx);
    using Tensor = framework::Tensor;

    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));

    auto dx_dims = dout->dims();
    auto dy_dims_untrimed = dout->dims();
    T *dx_data = nullptr;
    T *dy_data = nullptr;

    int axis = ctx.Attr<int>("axis");
    PADDLE_ENFORCE_GE(dx_dims.size(), dy_dims_untrimed.size(),
                      "Rank of first input must >= rank of second input.");

    if (dx != nullptr) {
      dx->mutable_data<T>(ctx.GetPlace());
      dx_dims = dx->dims();
      dx_data = dx->data<T>();
    }

    if (dy != nullptr) {
      dy->mutable_data<T>(ctx.GetPlace());
      dy_dims_untrimed = dy->dims();
      dy_data = dy->data<T>();
    }

    // Flatten the broadcast into (pre, n, post) blocks; len is dout's numel.
    int pre, n, post, is_common_broadcast;
    if (dx_dims == dy_dims_untrimed) {
      pre = post = 1;
      n = dout->numel();
    } else {
      axis = (axis == -1 ? dx_dims.size() - dy_dims_untrimed.size() : axis);
      PADDLE_ENFORCE(axis >= 0 && axis < dx_dims.size(),
                     "Axis should be in range [0, dx_dims)");
      auto dy_dims = trim_trailing_singular_dims(dy_dims_untrimed);
      axis = (dy_dims.size() == 0) ? dx_dims.size() : axis;
      get_mid_dims(dx_dims, dy_dims, axis, &pre, &n, &post,
                   &is_common_broadcast);
    }
    int len = pre * n * post;

    auto &dev_ctx =
        ctx.template device_context<paddle::platform::XPUDeviceContext>();
    if (post == 1) {
      int r = xpu::matrix_vector_add_grad(
          dev_ctx.x_context(), dout->data<T>(), dout->data<T>(),
          dout->data<T>(), dout->data<T>(), dx_data, dy_data, pre, n);
      PADDLE_ENFORCE_EQ(
          r, XPU_SUCCESS,
          platform::errors::External(
              "XPU API returned wrong value[%d], please check whether "
              "Baidu Kunlun Card is properly installed.",
              r));
      return;
    }

    // Allocate scratch buffers when a gradient output is skipped or when dy
    // still has to be reduced from the broadcast shape.
    if (dx == nullptr) {
      PADDLE_ENFORCE_EQ(
          xpu_malloc(reinterpret_cast<void **>(&dx_data), len * sizeof(float)),
          XPU_SUCCESS,
          platform::errors::External("XPU does not have enough memory"));
    }

    if (dy == nullptr) {
      PADDLE_ENFORCE_EQ(
          xpu_malloc(reinterpret_cast<void **>(&dy_data), len * sizeof(float)),
          XPU_SUCCESS,
          platform::errors::External("XPU does not have enough memory"));
    } else {
      if (len != n) {
        PADDLE_ENFORCE_EQ(
            xpu_malloc(reinterpret_cast<void **>(&dy_data),
                       len * sizeof(float)),
            XPU_SUCCESS,
            platform::errors::External("XPU does not have enough memory"));
      }
    }

    int r = xpu::elementwise_add_grad(
        dev_ctx.x_context(), dout->data<T>() /*x*/, dout->data<T>() /*y*/,
        dout->data<T>() /*out*/, dout->data<T>(), dx_data, dy_data, len);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External(
            "XPU API returned wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            r));

    // Reduce the broadcast dy back to Y's shape.
    if ((dy != nullptr) && (len != n)) {
      r = xpu::reduce_ew(dev_ctx.x_context(), dy_data, dy->data<T>(), pre, n,
                         post, xpu::ElementwiseOp::ASSIGN);
      PADDLE_ENFORCE_EQ(
          r, XPU_SUCCESS,
          platform::errors::External(
              "XPU API returned wrong value[%d], please check whether "
              "Baidu Kunlun Card is properly installed.",
              r));
      dev_ctx.Wait();
      xpu_free(dy_data);
    }

    if ((dx == nullptr || dy == nullptr) && !(dy != nullptr && len != n)) {
      dev_ctx.Wait();
    }

    if (dx == nullptr) {
      xpu_free(dx_data);
    }
    if (dy == nullptr) {
      xpu_free(dy_data);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_XPU_KERNEL(
    elementwise_add,
    ops::ElementwiseAddXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(elementwise_add_grad,
                       ops::ElementwiseAddGradXPUKernel<
                           paddle::platform::XPUDeviceContext, float>);
#endif
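As an illustrative aside (not from the diff), the gradient computed by ElementwiseAddGradXPUKernel can be stated in plain numpy: for Out = X + broadcast(Y), dX is simply dOut, and dY is dOut reduced over the broadcast (pre/post) axes, which is what matrix_vector_add_grad and reduce_ew produce on the device. The shapes below are made up for the demonstration.

import numpy as np

# dOut has X's shape; Y of shape (3,) was broadcast along axis = 1 of (2, 3, 4).
dout = np.random.rand(2, 3, 4).astype("float32")
dx = dout                      # d(x + y)/dx == 1, so dX = dOut
dy = dout.sum(axis=(0, 2))     # reduce over the pre (0) and post (2) axes
assert dx.shape == (2, 3, 4) and dy.shape == (3,)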
@@ -0,0 +1,113 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {

template <typename T>
struct XPUAddFunctor {
  int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) {
    return xpu::elementwise_add(ctx, x, y, z, len);
  }
};

template <typename T>
struct XPUMulFunctor {
  int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) {
    return xpu::elementwise_mul(ctx, x, y, z, len);
  }
};

template <typename T, typename Functor>
void XPUElementwise(const framework::ExecutionContext& ctx) {
  PADDLE_ENFORCE(platform::is_xpu_place(ctx.GetPlace()),
                 "This kernel only runs on XPU device.");
  auto x_var = ctx.InputVar("X");
  PADDLE_ENFORCE_NE(x_var, nullptr,
                    platform::errors::Fatal("Cannot get input Variable X"));
  PADDLE_ENFORCE(x_var->IsType<framework::LoDTensor>(),
                 "XPU only supports LoDTensor");

  auto x = x_var->Get<framework::LoDTensor>();
  auto* y = ctx.Input<framework::LoDTensor>("Y");
  auto* z = ctx.Output<framework::LoDTensor>("Out");
  z->mutable_data<T>(ctx.GetPlace());

  // Flatten the broadcast of Y over X into (pre, n, post) blocks.
  int axis = ctx.Attr<int>("axis");
  auto x_dims = x.dims();
  auto y_dims_untrimed = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
                    "Rank of first input must >= rank of second input.");
  axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");
  auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;
  int pre, n, post, is_common_broadcast;
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, &is_common_broadcast);
  int len = pre * n * post;

  const T* x_data = x.data<T>();
  const T* y_data = y->data<T>();
  T* z_data = z->data<T>();
  T* y_broadcast = nullptr;

  auto& dev_ctx =
      ctx.template device_context<paddle::platform::XPUDeviceContext>();

  // Fast path: Y is a vector combined along the middle dimension.
  if (post == 1) {
    if (std::is_same<Functor, XPUAddFunctor<T>>::value) {
      int res = xpu::matrix_vector_add(dev_ctx.x_context(), x_data, y_data,
                                       z_data, pre, n);
      PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d",
                     res);
      return;
    }
    if (std::is_same<Functor, XPUMulFunctor<T>>::value) {
      int res = xpu::matrix_vector_mul(dev_ctx.x_context(), x_data, y_data,
                                       z_data, pre, n);
      PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d",
                     res);
      return;
    }
  }

  // General case: materialize the broadcast Y, then run the flat elementwise op.
  if (pre != 1 || post != 1) {
    PADDLE_ENFORCE(xpu_malloc(reinterpret_cast<void**>(&y_broadcast),
                              len * sizeof(T)) == XPU_SUCCESS);
    int res = xpu::broadcast_ew(dev_ctx.x_context(), y_data, y_broadcast, pre,
                                n, post, xpu::ElementwiseOp::ASSIGN);
    PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d",
                   res);
    y_data = y_broadcast;
  }

  Functor functor;
  int res = functor(dev_ctx.x_context(), x_data, y_data, z_data, len);
  PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d",
                 res);

  if (pre != 1 || post != 1) {
    dev_ctx.Wait();
    xpu_free(y_broadcast);
  }
}

}  // namespace operators
}  // namespace paddle
#endif
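A small numpy illustration (it only mimics get_mid_dims, it is not Paddle's helper) of the pre/n/post flattening XPUElementwise relies on: with X of shape (2, 3, 4), Y of shape (3,) and axis = 1, pre = 2, n = 3, post = 4 and len = pre * n * post = 24, so broadcasting Y amounts to tiling it across the pre and post blocks before the flat elementwise call.

import numpy as np

def mid_dims(x_shape, y_shape, axis):
    # Simplified stand-in for get_mid_dims (ignores the common-broadcast case).
    pre = int(np.prod(x_shape[:axis]))
    n = int(np.prod(y_shape))
    post = int(np.prod(x_shape[axis + len(y_shape):]))
    return pre, n, post

x = np.random.rand(2, 3, 4).astype("float32")
y = np.random.rand(3).astype("float32")
pre, n, post = mid_dims(x.shape, y.shape, 1)               # (2, 3, 4)
y_broadcast = np.tile(y[None, :, None], (pre, 1, post)).ravel()
z = (x.ravel() + y_broadcast).reshape(x.shape)             # elementwise_add
assert np.allclose(z, x + y.reshape(1, 3, 1))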
File diff suppressed because it is too large
@@ -0,0 +1,215 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
from scipy.special import expit, erf
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.fluid import compiler, Program, program_guard


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUActivation(OpTest):
    def setUp(self):
        self.op_type = "exp"
        self.init_dtype()
        self.init_kernel_type()

        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
        out = np.exp(x)

        self.attrs = {'use_xpu': True}
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place, atol=1e-3)

    def init_kernel_type(self):
        pass


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUSigmoid(TestXPUActivation):
    def setUp(self):
        self.op_type = "sigmoid"
        self.init_dtype()

        x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
        out = 1 / (1 + np.exp(-x))

        self.attrs = {'use_xpu': True}
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}

    def test_check_grad(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_grad_with_place(
                place, ['X'], 'Out', max_relative_error=0.01)


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUTanh(TestXPUActivation):
    def setUp(self):
        self.op_type = "tanh"
        self.init_dtype()
        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
        out = np.tanh(x)

        self.attrs = {'use_xpu': True}
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUSqrt(TestXPUActivation):
    def setUp(self):
        self.op_type = "sqrt"
        self.init_dtype()

        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
        out = np.sqrt(x)

        self.attrs = {'use_xpu': True}
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUAbs(TestXPUActivation):
    def setUp(self):
        self.op_type = "abs"
        self.init_dtype()

        x = np.random.uniform(-1, 1, [4, 25]).astype(self.dtype)
        # Because we set delta = 0.005 in calculating the numeric gradient,
        # if x is too small (e.g. 0.002), x_neg will be -0.003 and x_pos will
        # be 0.007, making the numeric gradient inaccurate. We avoid that here.
        x[np.abs(x) < 0.005] = 0.02
        out = np.abs(x)

        self.attrs = {'use_xpu': True}
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPURelu(TestXPUActivation):
    def setUp(self):
        self.op_type = "relu"
        self.init_dtype()

        x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
        # The same reason as in TestXPUAbs
        x[np.abs(x) < 0.005] = 0.02
        out = np.maximum(x, 0)

        self.attrs = {'use_xpu': True}
        self.inputs = {'X': x}
        self.outputs = {'Out': out}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUGelu(TestXPUActivation):
    def setUp(self):
        self.op_type = "gelu"
        self.init_dtype()
        approximate = False
        x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
        out = gelu(x, approximate)

        self.inputs = {'X': x}
        self.outputs = {'Out': out}
        self.attrs = {"approximate": approximate, 'use_xpu': True}


def gelu(x, approximate):
    if approximate:
        y_ref = 0.5 * x * (1.0 + np.tanh(
            np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
    else:
        y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2)))
    return y_ref.astype(x.dtype)


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPULog(TestXPUActivation):
    def setUp(self):
        self.op_type = "log"
        self.init_dtype()

        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
        out = np.log(x)

        self.attrs = {'use_xpu': True}
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUSquare(TestXPUActivation):
    def setUp(self):
        self.op_type = "square"
        self.init_dtype()

        x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
        out = np.square(x)

        self.attrs = {'use_xpu': True}
        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUPow(TestXPUActivation):
    def setUp(self):
        self.op_type = "pow"
        self.init_dtype()

        x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        out = np.power(x, 3)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.attrs = {'factor': 3.0, 'use_xpu': True}
        self.outputs = {'Out': out}


if __name__ == "__main__":
    unittest.main()
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,161 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import sys
sys.path.append("..")
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard


class TestMulOp(OpTest):
    def setUp(self):
        self.op_type = "mul"
        self.dtype = np.float64
        self.init_dtype_type()
        self.inputs = {
            'X': np.random.random((20, 5)).astype(self.dtype),
            'Y': np.random.random((5, 21)).astype(self.dtype)
        }
        self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out')

    def test_check_grad_ignore_x(self):
        self.check_grad(
            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))

    def test_check_grad_ignore_y(self):
        self.check_grad(
            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))


class TestMulOpError(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):
            # The input type of mul_op must be Variable.
            x1 = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            x2 = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            self.assertRaises(TypeError, fluid.layers.mul, x1, x2)
            # The input dtype of mul_op must be float32 or float64.
            x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32")
            x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32")
            self.assertRaises(TypeError, fluid.layers.mul, x3, x4)


class TestMulOp2(OpTest):
    def setUp(self):
        self.op_type = "mul"
        self.dtype = np.float64
        self.init_dtype_type()
        self.inputs = {
            'X': np.random.random((3, 4, 2, 9)).astype(self.dtype),
            'Y': np.random.random((3, 6, 1, 2, 3)).astype(self.dtype)
        }
        self.attrs = {
            'x_num_col_dims': 2,
            'y_num_col_dims': 2,
        }
        result = np.dot(self.inputs['X'].reshape(3 * 4, 2 * 9),
                        self.inputs['Y'].reshape(3 * 6, 1 * 2 * 3))
        result = result.reshape(3, 4, 1, 2, 3)
        self.outputs = {'Out': result}

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out')

    def test_check_grad_ignore_x(self):
        self.check_grad(
            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set('X'))

    def test_check_grad_ignore_y(self):
        self.check_grad(
            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUMulOp1(TestMulOp):
    def init_dtype_type(self):
        self.dtype = np.float32

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, atol=1e-1)

    def test_check_grad_normal(self):
        place = paddle.XPUPlace(0)
        self.check_grad_with_place(
            place, ['X', 'Y'], 'Out', max_relative_error=0.5)

    def test_check_grad_ignore_x(self):
        place = paddle.XPUPlace(0)
        self.check_grad_with_place(
            place, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))

    def test_check_grad_ignore_y(self):
        place = paddle.XPUPlace(0)
        self.check_grad_with_place(
            place, ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUMulOp2(TestMulOp2):
    def init_dtype_type(self):
        self.dtype = np.float32

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, atol=2e-1)

    def test_check_grad_normal(self):
        place = paddle.XPUPlace(0)
        self.check_grad_with_place(
            place, ['X', 'Y'], 'Out', max_relative_error=0.9)

    def test_check_grad_ignore_x(self):
        place = paddle.XPUPlace(0)
        self.check_grad_with_place(
            place, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))

    def test_check_grad_ignore_y(self):
        place = paddle.XPUPlace(0)
        self.check_grad_with_place(
            place, ['X'], 'Out', max_relative_error=0.9, no_grad_set=set('Y'))


if __name__ == "__main__":
    unittest.main()