refine SoftmaxFunctor

tonyyang-svail-feed-op-desgin
qijun 8 years ago
parent 79def5e634
commit 84ff7e9784

@ -36,7 +36,7 @@ struct ValueClip {
template <typename Place, typename T>
class SoftmaxFunctor {
public:
void operator()(const framework::ExecutionContext& context,
void operator()(const platform::DeviceContext& context,
const framework::Tensor* X, framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
@ -58,8 +58,8 @@ class SoftmaxFunctor {
.broadcast(one_by_class))
.unaryExpr(ValueClip<T>());
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(context.GetEigenDevice<Place>()) =
softmax.device(*context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(*context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()

@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel<T> {
// allocate memory on device.
Y->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<Place, T>()(context, X, Y);
math::SoftmaxFunctor<Place, T>()(context.device_context(), X, Y);
}
};

@ -66,7 +66,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);
math::SoftmaxFunctor<platform::GPUPlace, T>()(context.device_context(),
logits, softmax);
math::CrossEntropyFunctor<platform::GPUPlace, T>()(
context, loss, softmax, labels, context.Attr<bool>("softLabel"));
}

@ -40,7 +40,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::CPUPlace, T>()(context, logits, softmax);
math::SoftmaxFunctor<platform::CPUPlace, T>()(context.device_context(),
logits, softmax);
math::CrossEntropyFunctor<platform::CPUPlace, T>()(
context.device_context(), loss, softmax, labels,
context.Attr<bool>("softLabel"));

Loading…
Cancel
Save