|
|
|
@ -21,7 +21,7 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
|
|
|
|
|
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &context) const override {
|
|
|
|
|
const framework::Tensor *X = context.Input<framework::Tensor>("X");
|
|
|
|
@ -48,7 +48,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
|
|
|
|
|
|
|
|
|
|
// dX = sigmoid(X) - labels
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel {
|
|
|
|
|
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext &context) const override {
|
|
|
|
|
const framework::Tensor *X = context.Input<framework::Tensor>("X");
|
|
|
|
|