fix huber loss op attr type, test=develop (#19937)

expand_as_op_1
Zeng Jinle 6 years ago committed by GitHub
parent cc157d5990
commit b1e83b33b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,7 +41,7 @@ struct HuberLossForward {
T delta;
};
template <typename DeviceContext, typename T, typename AttrType = T>
template <typename DeviceContext, typename T>
class HuberLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
@ -49,7 +49,7 @@ class HuberLossKernel : public framework::OpKernel<T> {
auto* in1 = context.Input<Tensor>("Y");
auto* out0 = context.Output<Tensor>("Residual");
auto* out1 = context.Output<Tensor>("Out");
auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
auto delta = static_cast<T>(context.Attr<float>("delta"));
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
@ -86,7 +86,7 @@ struct HuberLossBackward {
T delta;
};
template <typename DeviceContext, typename T, typename AttrType = T>
template <typename DeviceContext, typename T>
class HuberLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
@ -94,7 +94,7 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
auto delta = static_cast<T>(context.op().Attr<AttrType>("delta"));
auto delta = static_cast<T>(context.op().Attr<float>("delta"));
auto& place =
*context.template device_context<DeviceContext>().eigen_device();

Loading…
Cancel
Save