|
|
|
@ -43,23 +43,6 @@ struct ElementwiseSubGradFunctor {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ElementwiseSubOneGradFunctor {
|
|
|
|
|
template <typename Device, typename X, typename Y, typename Z, typename dX,
|
|
|
|
|
typename dY, typename dZ>
|
|
|
|
|
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
|
|
|
|
|
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
|
|
|
|
|
if (dx) {
|
|
|
|
|
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
|
|
|
|
|
dx_e.device(d) = dz_e;
|
|
|
|
|
}
|
|
|
|
|
if (dy) {
|
|
|
|
|
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
|
|
|
|
|
dy_e.device(d) = (-1.0) * dz_e.sum();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ElementwiseSubBroadCastGradFunctor {
|
|
|
|
|
template <typename Device, typename X, typename Y, typename Z, typename dX,
|
|
|
|
@ -106,7 +89,6 @@ class ElementwiseSubGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>,
|
|
|
|
|
ElementwiseSubOneGradFunctor<T>,
|
|
|
|
|
ElementwiseSubBroadCastGradFunctor<T>,
|
|
|
|
|
ElementwiseSubBroadCast2GradFunctor<T>>(ctx);
|
|
|
|
|
}
|
|
|
|
|