|
|
@ -41,7 +41,7 @@ struct HuberLossForward {
|
|
|
|
T delta;
|
|
|
|
T delta;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T, typename AttrType = T>
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
class HuberLossKernel : public framework::OpKernel<T> {
|
|
|
|
class HuberLossKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
@ -49,7 +49,7 @@ class HuberLossKernel : public framework::OpKernel<T> {
|
|
|
|
auto* in1 = context.Input<Tensor>("Y");
|
|
|
|
auto* in1 = context.Input<Tensor>("Y");
|
|
|
|
auto* out0 = context.Output<Tensor>("Residual");
|
|
|
|
auto* out0 = context.Output<Tensor>("Residual");
|
|
|
|
auto* out1 = context.Output<Tensor>("Out");
|
|
|
|
auto* out1 = context.Output<Tensor>("Out");
|
|
|
|
auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
|
|
|
|
auto delta = static_cast<T>(context.Attr<float>("delta"));
|
|
|
|
auto& place =
|
|
|
|
auto& place =
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
|
|
|
|
|
|
@ -86,7 +86,7 @@ struct HuberLossBackward {
|
|
|
|
T delta;
|
|
|
|
T delta;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T, typename AttrType = T>
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
class HuberLossGradKernel : public framework::OpKernel<T> {
|
|
|
|
class HuberLossGradKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
@ -94,7 +94,7 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
|
|
|
|
auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
auto delta = static_cast<T>(context.op().Attr<AttrType>("delta"));
|
|
|
|
auto delta = static_cast<T>(context.op().Attr<float>("delta"));
|
|
|
|
auto& place =
|
|
|
|
auto& place =
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
|
|
|
|
|
|
|