|
|
|
@ -35,77 +35,77 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
|
|
|
|
|
|
|
|
|
|
struct SumFunctor {
|
|
|
|
|
template <typename DeviceContext, typename X, typename Y, typename Dim>
|
|
|
|
|
void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
|
|
|
|
|
y.device(place) = x.sum(dim);
|
|
|
|
|
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
|
|
|
|
|
y->device(place) = x->sum(dim);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct SumGradFunctor {
|
|
|
|
|
template <typename DeviceContext, typename X, typename Y, typename DX,
|
|
|
|
|
typename DY, typename Dim>
|
|
|
|
|
void operator()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy,
|
|
|
|
|
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
|
|
|
|
|
const Dim& dim, int size) {
|
|
|
|
|
dx.device(place) = dy.broadcast(dim);
|
|
|
|
|
dx->device(place) = dy->broadcast(dim);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct MeanFunctor {
|
|
|
|
|
template <typename DeviceContext, typename X, typename Y, typename Dim>
|
|
|
|
|
void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
|
|
|
|
|
y.device(place) = x.mean(dim);
|
|
|
|
|
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
|
|
|
|
|
y->device(place) = x->mean(dim);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct MeanGradFunctor {
|
|
|
|
|
template <typename DeviceContext, typename X, typename Y, typename DX,
|
|
|
|
|
typename DY, typename Dim>
|
|
|
|
|
void operator()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy,
|
|
|
|
|
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
|
|
|
|
|
const Dim& dim, int size) {
|
|
|
|
|
dx.device(place) = dy.broadcast(dim) / dx.constant(size);
|
|
|
|
|
dx->device(place) = dy->broadcast(dim) / dx->constant(size);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct MaxFunctor {
|
|
|
|
|
template <typename DeviceContext, typename X, typename Y, typename Dim>
|
|
|
|
|
void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
|
|
|
|
|
y.device(place) = x.maximum(dim);
|
|
|
|
|
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
|
|
|
|
|
y->device(place) = x->maximum(dim);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct MinFunctor {
|
|
|
|
|
template <typename DeviceContext, typename X, typename Y, typename Dim>
|
|
|
|
|
void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
|
|
|
|
|
y.device(place) = x.minimum(dim);
|
|
|
|
|
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
|
|
|
|
|
y->device(place) = x->minimum(dim);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct MaxOrMinGradFunctor {
|
|
|
|
|
template <typename DeviceContext, typename X, typename Y, typename DX,
|
|
|
|
|
typename DY, typename Dim>
|
|
|
|
|
void operator()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy,
|
|
|
|
|
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
|
|
|
|
|
const Dim& dim, int size) {
|
|
|
|
|
auto equals = x == y.broadcast(dim);
|
|
|
|
|
auto ones = dx.constant(1);
|
|
|
|
|
auto zeros = dx.constant(0);
|
|
|
|
|
auto equals = (*x) == y->broadcast(dim);
|
|
|
|
|
auto ones = dx->constant(1);
|
|
|
|
|
auto zeros = dx->constant(0);
|
|
|
|
|
// If there are multiple minimum or maximum elements, the subgradient of
|
|
|
|
|
// each is the set [0, 1], and we pass gradient to all of them here.
|
|
|
|
|
dx.device(place) = dy.broadcast(dim) * equals.select(ones, zeros);
|
|
|
|
|
dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ProdFunctor {
|
|
|
|
|
template <typename DeviceContext, typename X, typename Y, typename Dim>
|
|
|
|
|
void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) {
|
|
|
|
|
y.device(place) = x.prod(dim);
|
|
|
|
|
void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) {
|
|
|
|
|
y->device(place) = x->prod(dim);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ProdGradFunctor {
|
|
|
|
|
template <typename DeviceContext, typename X, typename Y, typename DX,
|
|
|
|
|
typename DY, typename Dim>
|
|
|
|
|
void operator()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy,
|
|
|
|
|
void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
|
|
|
|
|
const Dim& dim, int size) {
|
|
|
|
|
dx.device(place) = dy.broadcast(dim) * y.broadcast(dim) * x.inverse();
|
|
|
|
|
dx->device(place) = dy->broadcast(dim) * y->broadcast(dim) * x->inverse();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -125,7 +125,7 @@ class ReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
auto reduce_dim = Eigen::array<int, 1>({{0}});
|
|
|
|
|
Functor functor;
|
|
|
|
|
functor(place, x, out, reduce_dim);
|
|
|
|
|
functor(place, &x, &out, reduce_dim);
|
|
|
|
|
} else {
|
|
|
|
|
int rank = context.Input<Tensor>("X")->dims().size();
|
|
|
|
|
switch (rank) {
|
|
|
|
@ -178,10 +178,10 @@ class ReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
if (D == 1) {
|
|
|
|
|
auto out = EigenScalar<T>::From(*output);
|
|
|
|
|
functor(place, x, out, reduce_dim);
|
|
|
|
|
functor(place, &x, &out, reduce_dim);
|
|
|
|
|
} else {
|
|
|
|
|
auto out = EigenTensor<T, (D - 1)>::From(*output, dims);
|
|
|
|
|
functor(place, x, out, reduce_dim);
|
|
|
|
|
functor(place, &x, &out, reduce_dim);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -206,7 +206,7 @@ class ReduceGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto broadcast_dim =
|
|
|
|
|
Eigen::array<int, 1>({{static_cast<int>(input0->numel())}});
|
|
|
|
|
Functor functor;
|
|
|
|
|
functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
|
|
|
|
|
functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
|
|
|
|
|
broadcast_dim[0]);
|
|
|
|
|
} else {
|
|
|
|
|
int rank = context.Input<Tensor>("X")->dims().size();
|
|
|
|
@ -258,7 +258,7 @@ class ReduceGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto& place =
|
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
Functor functor;
|
|
|
|
|
functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim,
|
|
|
|
|
functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim,
|
|
|
|
|
broadcast_dim[dim]);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|