fix huber loss op attr type, test=develop (#19937)

expand_as_op_1
Zeng Jinle 6 years ago committed by GitHub
parent cc157d5990
commit b1e83b33b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,7 +41,7 @@ struct HuberLossForward {
T delta; T delta;
}; };
template <typename DeviceContext, typename T, typename AttrType = T> template <typename DeviceContext, typename T>
class HuberLossKernel : public framework::OpKernel<T> { class HuberLossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
@ -49,7 +49,7 @@ class HuberLossKernel : public framework::OpKernel<T> {
auto* in1 = context.Input<Tensor>("Y"); auto* in1 = context.Input<Tensor>("Y");
auto* out0 = context.Output<Tensor>("Residual"); auto* out0 = context.Output<Tensor>("Residual");
auto* out1 = context.Output<Tensor>("Out"); auto* out1 = context.Output<Tensor>("Out");
auto delta = static_cast<T>(context.Attr<AttrType>("delta")); auto delta = static_cast<T>(context.Attr<float>("delta"));
auto& place = auto& place =
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
@ -86,7 +86,7 @@ struct HuberLossBackward {
T delta; T delta;
}; };
template <typename DeviceContext, typename T, typename AttrType = T> template <typename DeviceContext, typename T>
class HuberLossGradKernel : public framework::OpKernel<T> { class HuberLossGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
@ -94,7 +94,7 @@ class HuberLossGradKernel : public framework::OpKernel<T> {
auto* in1 = context.Input<Tensor>(framework::GradVarName("Out")); auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X")); auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto* out1 = context.Output<Tensor>(framework::GradVarName("Y")); auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
auto delta = static_cast<T>(context.op().Attr<AttrType>("delta")); auto delta = static_cast<T>(context.op().Attr<float>("delta"));
auto& place = auto& place =
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();

Loading…
Cancel
Save