Add bce_loss op (#23388)
* add bce_loss * fix mistake * replace paddle_enforce,test=develop * fix,test=develop * update,test=develop * remove duplicate,test=develop * update,test=develop * update error,test=develop * update,test=develop * fix unittest, test=develop * update, test=developrevert-23830-2.0-beta
parent
faf284a9b3
commit
ab05cdc46e
@ -0,0 +1,180 @@
|
|||||||
|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/bce_loss_op.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using framework::Tensor;
|
||||||
|
|
||||||
|
class BCELossOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
ctx->HasInput("X"), true,
|
||||||
|
platform::errors::InvalidArgument("Input(X) should be not null."));
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
ctx->HasInput("Label"), true,
|
||||||
|
platform::errors::InvalidArgument("Input(Label) should be not null."));
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
ctx->HasOutput("Out"), true,
|
||||||
|
platform::errors::InvalidArgument("Output(Out) should be not null."));
|
||||||
|
|
||||||
|
auto x_dims = ctx->GetInputDim("X");
|
||||||
|
auto label_dims = ctx->GetInputDim("Label");
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
x_dims.size(), label_dims.size(),
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"Input(X) and Input(Label) shall have the same shape."));
|
||||||
|
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
|
||||||
|
framework::contain_unknown_dim(label_dims);
|
||||||
|
bool check = ctx->IsRuntime() || !contain_unknown_dim;
|
||||||
|
if (check) {
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
x_dims.size(), label_dims.size(),
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"ShapeError: Input(X) and Input(Label) shall have the same shape "
|
||||||
|
"But received: the shape of Input(X) is [%s], the shape of "
|
||||||
|
"Input(Label) is [%s].",
|
||||||
|
x_dims, label_dims));
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->ShareDim("X", "Out");
|
||||||
|
ctx->ShareLoD("X", "Out");
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const override {
|
||||||
|
return framework::OpKernelType(
|
||||||
|
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
|
||||||
|
ctx.device_context());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class BCELossGradOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
ctx->HasInput("X"), true,
|
||||||
|
platform::errors::InvalidArgument("Input(X) should be not null."));
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
ctx->HasInput("Label"), true,
|
||||||
|
platform::errors::InvalidArgument("Input(Label) should be not null."));
|
||||||
|
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"Input(Out@GRAD) shoudl be not null."));
|
||||||
|
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"Output(X@GRAD) should be not null."));
|
||||||
|
|
||||||
|
auto x_dims = ctx->GetInputDim("X");
|
||||||
|
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
||||||
|
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
|
||||||
|
framework::contain_unknown_dim(dout_dims);
|
||||||
|
bool check = ctx->IsRuntime() || !contain_unknown_dim;
|
||||||
|
if (check) {
|
||||||
|
PADDLE_ENFORCE_EQ(x_dims, dout_dims,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"ShapeError:The Input(X) and Input(Out@Grad) "
|
||||||
|
"should have the same "
|
||||||
|
"shape, But received: the shape of Input(X) is "
|
||||||
|
"[%s], the shape of "
|
||||||
|
"Input(Out@GRAD) is [%s].",
|
||||||
|
x_dims, dout_dims));
|
||||||
|
}
|
||||||
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
||||||
|
ctx->ShareLoD("X", framework::GradVarName("X"));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext& ctx) const override {
|
||||||
|
return framework::OpKernelType(
|
||||||
|
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
|
||||||
|
ctx.device_context());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class BCELossOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
void Make() override {
|
||||||
|
AddInput("X",
|
||||||
|
"(Tensor, default Tensor<float>), the input is a tensor of logits"
|
||||||
|
"computed by the previous operator, which is always the result of"
|
||||||
|
"a sigmoid operator. Input must between in 0 and 1.");
|
||||||
|
AddInput("Label",
|
||||||
|
"(Tensor, default Tensor<float>), have same shape with input"
|
||||||
|
"label should between in 0 and 1.");
|
||||||
|
AddOutput("Out",
|
||||||
|
"(Tensor, default Tensor<float>), have same shape with"
|
||||||
|
"input");
|
||||||
|
AddComment(R"DOC(
|
||||||
|
BinaryCrossEntropy operator.
|
||||||
|
|
||||||
|
This measures the element-wise probability error in classification tasks
|
||||||
|
in which each class is independent.
|
||||||
|
|
||||||
|
The logitstic loss is given as follows:
|
||||||
|
$$loss = -Label * \log(X) - (1 - Label) * \log(1 - X)$$
|
||||||
|
)DOC");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class BCELossGradOpMaker : public framework::SingleGradOpMaker<T> {
|
||||||
|
public:
|
||||||
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void Apply(GradOpPtr<T> op) const override {
|
||||||
|
op->SetType("bce_loss_grad");
|
||||||
|
op->SetInput("X", this->Input("X"));
|
||||||
|
op->SetInput("Label", this->Input("Label"));
|
||||||
|
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
||||||
|
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
||||||
|
// op->SetAttrMap(this->Attrs());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
DECLARE_INPLACE_OP_INFERER(BCELossInplaceInferer, {"X", "Out"});
|
||||||
|
DECLARE_INPLACE_OP_INFERER(BCELossGradInplaceInferer,
|
||||||
|
{framework::GradVarName("Out"),
|
||||||
|
framework::GradVarName("X")});
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OPERATOR(bce_loss, ops::BCELossOp, ops::BCELossOpMaker,
|
||||||
|
ops::BCELossGradOpMaker<paddle::framework::OpDesc>,
|
||||||
|
ops::BCELossGradOpMaker<paddle::imperative::OpBase>,
|
||||||
|
ops::BCELossInplaceInferer);
|
||||||
|
REGISTER_OPERATOR(bce_loss_grad, ops::BCELossGradOp,
|
||||||
|
ops::BCELossGradInplaceInferer);
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
bce_loss, ops::BCELossOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
ops::BCELossOpKernel<paddle::platform::CPUDeviceContext, double>);
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
bce_loss_grad,
|
||||||
|
ops::BCELossGradOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
ops::BCELossGradOpKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,133 @@
|
|||||||
|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
#include <algorithm>
|
||||||
|
#include "cub/cub.cuh"
|
||||||
|
#include "paddle/fluid/operators/bce_loss_op.h"
|
||||||
|
#include "paddle/fluid/operators/math.h"
|
||||||
|
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||||
|
#include "paddle/fluid/platform/gpu_launch_config.h"
|
||||||
|
#include "paddle/fluid/platform/hostdevice.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using Tensor = framework::Tensor;
|
||||||
|
|
||||||
|
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||||
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||||
|
i += blockDim.x * gridDim.x)
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void GPUBCELossForward(const T* x_data, const T* label_data,
|
||||||
|
T* out_data, const int in_numel) {
|
||||||
|
CUDA_1D_KERNEL_LOOP(i, in_numel) {
|
||||||
|
T x = x_data[i];
|
||||||
|
T label = label_data[i];
|
||||||
|
T one = static_cast<T>(1.);
|
||||||
|
T neg_100 = static_cast<T>(-100.);
|
||||||
|
|
||||||
|
T term1 = max(real_log(x), neg_100);
|
||||||
|
T term2 = max(real_log(one - x), neg_100);
|
||||||
|
|
||||||
|
out_data[i] = ((label - one) * term2) - (label * term1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void GPUBCELossBackward(const T* x_data, const T* label_data,
|
||||||
|
const T* dout_data, T* dx_data,
|
||||||
|
const int in_numel) {
|
||||||
|
CUDA_1D_KERNEL_LOOP(i, in_numel) {
|
||||||
|
T x = x_data[i];
|
||||||
|
T label = label_data[i];
|
||||||
|
T dout = dout_data[i];
|
||||||
|
T one = static_cast<T>(1.);
|
||||||
|
T eps = static_cast<T>(1e-12);
|
||||||
|
|
||||||
|
T term1 = max((one - x) * x, eps);
|
||||||
|
|
||||||
|
dx_data[i] = dout * (x - label) / term1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class BCELossCUDAKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* x = ctx.Input<Tensor>("X");
|
||||||
|
auto* labels = ctx.Input<Tensor>("Label");
|
||||||
|
auto* out = ctx.Output<Tensor>("Out");
|
||||||
|
|
||||||
|
auto x_data = x->data<T>();
|
||||||
|
auto out_data = out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
int x_numel = x->numel();
|
||||||
|
platform::GpuLaunchConfig config =
|
||||||
|
platform::getGpuLaunchConfig(x_numel, ctx);
|
||||||
|
|
||||||
|
Tensor x_cpu;
|
||||||
|
framework::TensorCopy(*x, platform::CPUPlace(), &x_cpu);
|
||||||
|
T* x_cpu_data = x_cpu.data<T>();
|
||||||
|
|
||||||
|
for (int i = 0; i < x_numel; ++i) {
|
||||||
|
PADDLE_ENFORCE_GE(
|
||||||
|
x_cpu_data[i], static_cast<T>(0),
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"Illegal input, input must be greater than or equal to 0"));
|
||||||
|
PADDLE_ENFORCE_LE(
|
||||||
|
x_cpu_data[i], static_cast<T>(1),
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"Illegal input, input must be less than or equal to 1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& dev_ctx = ctx.cuda_device_context();
|
||||||
|
|
||||||
|
GPUBCELossForward<
|
||||||
|
T><<<config.blocks, config.threads, 0, dev_ctx.stream()>>>(
|
||||||
|
x_data, labels->data<T>(), out_data, x_numel);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class BCELossGradCUDAKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* x = ctx.Input<Tensor>("X");
|
||||||
|
auto* labels = ctx.Input<Tensor>("Label");
|
||||||
|
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||||
|
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||||
|
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
|
||||||
|
|
||||||
|
int x_numel = x->numel();
|
||||||
|
platform::GpuLaunchConfig config =
|
||||||
|
platform::getGpuLaunchConfig(x_numel, ctx);
|
||||||
|
auto& dev_ctx = ctx.cuda_device_context();
|
||||||
|
|
||||||
|
GPUBCELossBackward<
|
||||||
|
T><<<config.blocks, config.threads, 0, dev_ctx.stream()>>>(
|
||||||
|
x->data<T>(), labels->data<T>(), dout->data<T>(), dx_data, x_numel);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OP_CUDA_KERNEL(
|
||||||
|
bce_loss,
|
||||||
|
ops::BCELossCUDAKernel<paddle::platform::CUDADeviceContext, float>,
|
||||||
|
ops::BCELossCUDAKernel<paddle::platform::CUDADeviceContext, double>);
|
||||||
|
REGISTER_OP_CUDA_KERNEL(
|
||||||
|
bce_loss_grad,
|
||||||
|
ops::BCELossGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
|
||||||
|
ops::BCELossGradCUDAKernel<paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,85 @@
|
|||||||
|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <algorithm> // for max
|
||||||
|
#include "paddle/fluid/framework/eigen.h"
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
#include "paddle/fluid/operators/math.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
using Tensor = framework::Tensor;
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class BCELossOpKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* x = ctx.Input<Tensor>("X");
|
||||||
|
auto* labels = ctx.Input<Tensor>("Label");
|
||||||
|
auto* out = ctx.Output<Tensor>("Out");
|
||||||
|
|
||||||
|
auto x_data = x->data<T>();
|
||||||
|
auto label_data = labels->data<T>();
|
||||||
|
auto out_data = out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
int x_numel = x->numel();
|
||||||
|
|
||||||
|
// out = -(label * ln(x) + (1 - label) * ln(1 - x)) = (label - 1) * ln(1 -
|
||||||
|
// x) - label * ln(x)
|
||||||
|
for (int i = 0; i < x_numel; ++i) {
|
||||||
|
PADDLE_ENFORCE_GE(
|
||||||
|
x_data[i], static_cast<T>(0),
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"Illegal input, input must be greater than or equal to 0"));
|
||||||
|
PADDLE_ENFORCE_LE(
|
||||||
|
x_data[i], static_cast<T>(1),
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"Illegal input, input must be less than or equal to 1"));
|
||||||
|
out_data[i] =
|
||||||
|
(label_data[i] - static_cast<T>(1)) *
|
||||||
|
std::max(real_log(static_cast<T>(1) - x_data[i]), (T)(-100)) -
|
||||||
|
label_data[i] * std::max(real_log(x_data[i]), (T)(-100));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename DeviceContext, typename T>
|
||||||
|
class BCELossGradOpKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||||
|
auto* x = ctx.Input<Tensor>("X");
|
||||||
|
auto* labels = ctx.Input<Tensor>("Label");
|
||||||
|
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
||||||
|
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
||||||
|
|
||||||
|
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
|
||||||
|
auto dout_data = dout->data<T>();
|
||||||
|
auto x_data = x->data<T>();
|
||||||
|
auto label_data = labels->data<T>();
|
||||||
|
|
||||||
|
int x_numel = x->numel();
|
||||||
|
|
||||||
|
// dx = dout * ((x - label)/(x - x^2))
|
||||||
|
for (int i = 0; i < x_numel; ++i) {
|
||||||
|
dx_data[i] =
|
||||||
|
dout_data[i] * ((x_data[i] - label_data[i]) /
|
||||||
|
std::max((static_cast<T>(1) - x_data[i]) * x_data[i],
|
||||||
|
static_cast<T>(1e-12)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
@ -0,0 +1,135 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import numpy as np
|
||||||
|
import unittest
|
||||||
|
from op_test import OpTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestBCELoss(unittest.TestCase):
|
||||||
|
def test_BCELoss(self):
|
||||||
|
input_np = np.random.random(size=(20, 30)).astype(np.float64)
|
||||||
|
label_np = np.random.random(size=(20, 30)).astype(np.float64)
|
||||||
|
prog = fluid.Program()
|
||||||
|
startup_prog = fluid.Program()
|
||||||
|
places = [fluid.CPUPlace()]
|
||||||
|
if fluid.core.is_compiled_with_cuda():
|
||||||
|
places.append(fluid.CUDAPlace(0))
|
||||||
|
reductions = ['sum', 'mean', 'none']
|
||||||
|
for place in places:
|
||||||
|
for red in reductions:
|
||||||
|
with fluid.program_guard(prog, startup_prog):
|
||||||
|
input = fluid.data(
|
||||||
|
name='input', shape=[None, 30], dtype='float64')
|
||||||
|
label = fluid.data(
|
||||||
|
name='label', shape=[None, 30], dtype='float64')
|
||||||
|
bce_loss = paddle.nn.loss.BCELoss(reduction=red)
|
||||||
|
res = bce_loss(input, label)
|
||||||
|
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
static_result = exe.run(
|
||||||
|
prog,
|
||||||
|
feed={"input": input_np,
|
||||||
|
"label": label_np},
|
||||||
|
fetch_list=[res])
|
||||||
|
|
||||||
|
with fluid.dygraph.guard():
|
||||||
|
bce_loss = paddle.nn.loss.BCELoss(reduction=red)
|
||||||
|
dy_res = bce_loss(
|
||||||
|
fluid.dygraph.to_variable(input_np),
|
||||||
|
fluid.dygraph.to_variable(label_np))
|
||||||
|
dy_result = dy_res.numpy()
|
||||||
|
|
||||||
|
expected = -1 * (label_np * np.log(input_np) +
|
||||||
|
(1. - label_np) * np.log(1. - input_np))
|
||||||
|
if red == 'mean':
|
||||||
|
expected = np.mean(expected)
|
||||||
|
elif red == 'sum':
|
||||||
|
expected = np.sum(expected)
|
||||||
|
else:
|
||||||
|
expected = expected
|
||||||
|
self.assertTrue(np.allclose(static_result, expected))
|
||||||
|
self.assertTrue(np.allclose(static_result, dy_result))
|
||||||
|
self.assertTrue(np.allclose(dy_result, expected))
|
||||||
|
|
||||||
|
def test_BCELoss_weight(self):
|
||||||
|
input_np = np.random.random(size=(20, 30)).astype(np.float64)
|
||||||
|
label_np = np.random.random(size=(20, 30)).astype(np.float64)
|
||||||
|
weight_np = np.random.random(size=(20, 30)).astype(np.float64)
|
||||||
|
prog = fluid.Program()
|
||||||
|
startup_prog = fluid.Program()
|
||||||
|
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
|
||||||
|
) else fluid.CPUPlace()
|
||||||
|
with fluid.program_guard(prog, startup_prog):
|
||||||
|
input = fluid.data(name='input', shape=[None, 30], dtype='float64')
|
||||||
|
label = fluid.data(name='label', shape=[None, 30], dtype='float64')
|
||||||
|
weight = fluid.data(
|
||||||
|
name='weight', shape=[None, 30], dtype='float64')
|
||||||
|
bce_loss = paddle.nn.loss.BCELoss(weight=weight)
|
||||||
|
res = bce_loss(input, label)
|
||||||
|
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
static_result = exe.run(prog,
|
||||||
|
feed={
|
||||||
|
"input": input_np,
|
||||||
|
"label": label_np,
|
||||||
|
"weight": weight_np
|
||||||
|
},
|
||||||
|
fetch_list=[res])
|
||||||
|
|
||||||
|
with fluid.dygraph.guard():
|
||||||
|
bce_loss = paddle.nn.loss.BCELoss(
|
||||||
|
weight=fluid.dygraph.to_variable(weight_np))
|
||||||
|
dy_res = bce_loss(
|
||||||
|
fluid.dygraph.to_variable(input_np),
|
||||||
|
fluid.dygraph.to_variable(label_np))
|
||||||
|
dy_result = dy_res.numpy()
|
||||||
|
|
||||||
|
expected = np.mean(-1 * weight_np *
|
||||||
|
(label_np * np.log(input_np) +
|
||||||
|
(1. - label_np) * np.log(1. - input_np)))
|
||||||
|
self.assertTrue(np.allclose(static_result, expected))
|
||||||
|
self.assertTrue(np.allclose(static_result, dy_result))
|
||||||
|
self.assertTrue(np.allclose(dy_result, expected))
|
||||||
|
|
||||||
|
|
||||||
|
def bce_loss(input, label):
|
||||||
|
return -1 * (label * np.log(input) + (1. - label) * np.log(1. - input))
|
||||||
|
|
||||||
|
|
||||||
|
class TestBceLossOp(OpTest):
|
||||||
|
def setUp(self):
|
||||||
|
self.init_test_case()
|
||||||
|
self.op_type = "bce_loss"
|
||||||
|
input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64")
|
||||||
|
label_np = np.random.randint(0, 2, self.shape).astype("float64")
|
||||||
|
output_np = bce_loss(input_np, label_np)
|
||||||
|
|
||||||
|
self.inputs = {'X': input_np, 'Label': label_np}
|
||||||
|
self.outputs = {'Out': output_np}
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
|
||||||
|
def test_check_grad(self):
|
||||||
|
self.check_grad(['X'], 'Out')
|
||||||
|
|
||||||
|
def init_test_case(self):
|
||||||
|
self.shape = [10, 10]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue