|
|
|
@ -28,7 +28,7 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
platform::is_xpu_place(context.GetPlace()), true,
|
|
|
|
|
platform::errors::InvalidArgument("This kernel only runs on XPU."));
|
|
|
|
|
platform::errors::PreconditionNotMet("This kernel only runs on XPU."));
|
|
|
|
|
const Tensor* logits = context.Input<Tensor>("Logits");
|
|
|
|
|
const Tensor* labels = context.Input<Tensor>("Label");
|
|
|
|
|
Tensor* softmax = context.Output<Tensor>("Softmax");
|
|
|
|
@ -46,8 +46,11 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
context.template device_context<platform::XPUDeviceContext>();
|
|
|
|
|
int r = xpu::softmax2d_forward(dev_ctx.x_context(), logits->data<float>(),
|
|
|
|
|
softmax->data<float>(), n, d);
|
|
|
|
|
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
|
|
|
|
|
platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
r, xpu::Error_t::SUCCESS,
|
|
|
|
|
platform::errors::External("XPU kernel error. Softmax2d_forward "
|
|
|
|
|
"execution not succeed, error code=%d",
|
|
|
|
|
r));
|
|
|
|
|
// cross_entropy
|
|
|
|
|
auto ignore_index = context.Attr<int>("ignore_index");
|
|
|
|
|
const bool soft_label = context.Attr<bool>("soft_label");
|
|
|
|
@ -61,10 +64,13 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
int* labels_int32_host =
|
|
|
|
|
reinterpret_cast<int*>(std::malloc(n * sizeof(int)));
|
|
|
|
|
int* labels_int32_device = NULL;
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
xpu_malloc(reinterpret_cast<void**>(&labels_int32_device),
|
|
|
|
|
n * sizeof(int)),
|
|
|
|
|
XPU_SUCCESS, platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
int ret = xpu_malloc(reinterpret_cast<void**>(&labels_int32_device),
|
|
|
|
|
n * sizeof(int));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
|
|
|
|
|
platform::errors::External(
|
|
|
|
|
"XPU API return wrong value[%d], please check "
|
|
|
|
|
"where Baidu Kunlun Card is properly installed.",
|
|
|
|
|
ret));
|
|
|
|
|
dev_ctx.Wait();
|
|
|
|
|
memory::Copy(platform::CPUPlace(), labels_int64_host,
|
|
|
|
|
BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()),
|
|
|
|
@ -78,8 +84,11 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
int r = xpu::cross_entropy_forward(
|
|
|
|
|
dev_ctx.x_context(), n, d, softmax->data<float>(),
|
|
|
|
|
labels_int32_device, loss->data<float>(), nullptr, ignore_index);
|
|
|
|
|
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
|
|
|
|
|
platform::errors::InvalidArgument("XPU kernel error!"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
r, xpu::Error_t::SUCCESS,
|
|
|
|
|
platform::errors::External("XPU kernel error. Cross_entropy_forward "
|
|
|
|
|
"execution not succeed, error code=%d",
|
|
|
|
|
r));
|
|
|
|
|
dev_ctx.Wait();
|
|
|
|
|
std::free(labels_int32_host);
|
|
|
|
|
std::free(labels_int64_host);
|
|
|
|
|