Use 32-bit index to improve activation ops performance (#24206)

* improve activation ops performance, test=develop

* use 32-bit index only for GPU computation, test=develop
revert-24314-dev/fix_err_msg
Zhang Ting 5 years ago committed by GitHub
parent 89c76a5342
commit b71abeee1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -37,6 +37,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::To32BitIndex;
enum ActBwdOpFwdDeps {
kNoDeps = 0x00, // Do not need any forward input/output
kDepX = 0x01, // Only need forward input X
@ -177,7 +179,14 @@ class ActivationKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(*place, x, out);
// use 32bit index to speed up computation
bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
if (use_32bit_index && is_gpu_place) {
functor(*place, To32BitIndex(x), To32BitIndex(out));
} else {
functor(*place, x, out);
}
}
};
@ -208,7 +217,15 @@ class ActivationGradKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(*place, x, out, dout, dx);
// use 32bit index to speed up computation
bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
if (use_32bit_index && is_gpu_place) {
functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout),
To32BitIndex(dx));
} else {
functor(*place, x, out, dout, dx);
}
}
};

Loading…
Cancel
Save