|
|
|
@ -61,8 +61,15 @@ void elementwise_floor_div(const framework::ExecutionContext &ctx,
|
|
|
|
|
const framework::Tensor *x,
|
|
|
|
|
const framework::Tensor *y, framework::Tensor *z) {
|
|
|
|
|
int axis = ctx.Attr<int>("axis");
|
|
|
|
|
ElementwiseComputeEx<FloorDivFunctor<T>, DeviceContext, T>(
|
|
|
|
|
ctx, x, y, axis, FloorDivFunctor<T>(), z);
|
|
|
|
|
auto x_dims = x->dims();
|
|
|
|
|
auto y_dims = y->dims();
|
|
|
|
|
if (x_dims.size() >= y_dims.size()) {
|
|
|
|
|
ElementwiseComputeEx<FloorDivFunctor<T>, DeviceContext, T>(
|
|
|
|
|
ctx, x, y, axis, FloorDivFunctor<T>(), z);
|
|
|
|
|
} else {
|
|
|
|
|
ElementwiseComputeEx<InverseFloorDivFunctor<T>, DeviceContext, T>(
|
|
|
|
|
ctx, x, y, axis, InverseFloorDivFunctor<T>(), z);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|