|
|
|
@ -26,13 +26,18 @@ public:
|
|
|
|
|
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
|
|
|
|
|
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
|
|
|
|
|
|
|
|
|
|
auto input0 = context.Input<Tensor>("X");
|
|
|
|
|
auto input1 = context.Input<Tensor>("Y");
|
|
|
|
|
auto output = context.Output<Tensor>(0);
|
|
|
|
|
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
|
|
|
|
|
EigenMatrix<T>::From(*context.Input<Tensor>("X"))
|
|
|
|
|
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
|
|
|
|
|
dim_pair);
|
|
|
|
|
auto X = EigenMatrix<T>::From(*input0);
|
|
|
|
|
auto Y = EigenMatrix<T>::From(*input1);
|
|
|
|
|
auto Z = EigenMatrix<T>::From(*output);
|
|
|
|
|
auto place = *context.GetEigenDevice<Place>();
|
|
|
|
|
|
|
|
|
|
Z.device(place) = X.contract(Y, dim_pair);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|