|
|
|
@ -31,6 +31,15 @@ struct ModFunctor {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct InverseModFunctor {
|
|
|
|
|
inline HOSTDEVICE T operator()(T a, T b) const {
|
|
|
|
|
T res = b % a;
|
|
|
|
|
if ((res != 0) && ((res < 0) != (a < 0))) res += a;
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ModFunctorFP {
|
|
|
|
|
inline HOSTDEVICE T operator()(T a, T b) const {
|
|
|
|
@ -40,13 +49,29 @@ struct ModFunctorFP {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct InverseModFunctorFP {
|
|
|
|
|
inline HOSTDEVICE T operator()(T a, T b) const {
|
|
|
|
|
T res = fmod(b, a);
|
|
|
|
|
if ((res != 0) && ((a < 0) != (res < 0))) res += a;
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
void elementwise_mod(const framework::ExecutionContext &ctx,
|
|
|
|
|
const framework::Tensor *x, const framework::Tensor *y,
|
|
|
|
|
framework::Tensor *z) {
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
|
|
|
|
|
ModFunctor<T>(), z);
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto y_dims = y->dims();
|
|
|
|
|
if (x_dims.size() >= y_dims.size()) {
|
|
|
|
|
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
|
|
|
|
|
ModFunctor<T>(), z);
|
|
|
|
|
} else {
|
|
|
|
|
ElementwiseComputeEx<InverseModFunctor<T>, DeviceContext, T>(
|
|
|
|
|
ctx, x, y, axis, InverseModFunctor<T>(), z);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
@ -54,8 +79,15 @@ void elementwise_mod_fp(const framework::ExecutionContext &ctx,
|
|
|
|
|
const framework::Tensor *x, const framework::Tensor *y,
|
|
|
|
|
framework::Tensor *z) {
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(ctx, x, y, axis,
|
|
|
|
|
ModFunctorFP<T>(), z);
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto y_dims = y->dims();
|
|
|
|
|
if (x_dims.size() >= y_dims.size()) {
|
|
|
|
|
ElementwiseComputeEx<ModFunctorFP<T>, DeviceContext, T>(
|
|
|
|
|
ctx, x, y, axis, ModFunctorFP<T>(), z);
|
|
|
|
|
} else {
|
|
|
|
|
ElementwiseComputeEx<InverseModFunctorFP<T>, DeviceContext, T>(
|
|
|
|
|
ctx, x, y, axis, InverseModFunctorFP<T>(), z);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|