|
|
|
@ -37,6 +37,8 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using framework::To32BitIndex;
|
|
|
|
|
|
|
|
|
|
enum ActBwdOpFwdDeps {
|
|
|
|
|
kNoDeps = 0x00, // Do not need any forward input/output
|
|
|
|
|
kDepX = 0x01, // Only need forward input X
|
|
|
|
@ -177,7 +179,14 @@ class ActivationKernel
|
|
|
|
|
for (auto& attr : attrs) {
|
|
|
|
|
*attr.second = context.Attr<float>(attr.first);
|
|
|
|
|
}
|
|
|
|
|
functor(*place, x, out);
|
|
|
|
|
// use 32bit index to speed up computation
|
|
|
|
|
bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
|
|
|
|
|
bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
|
|
|
|
|
if (use_32bit_index && is_gpu_place) {
|
|
|
|
|
functor(*place, To32BitIndex(x), To32BitIndex(out));
|
|
|
|
|
} else {
|
|
|
|
|
functor(*place, x, out);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -208,7 +217,15 @@ class ActivationGradKernel
|
|
|
|
|
for (auto& attr : attrs) {
|
|
|
|
|
*attr.second = context.Attr<float>(attr.first);
|
|
|
|
|
}
|
|
|
|
|
functor(*place, x, out, dout, dx);
|
|
|
|
|
// use 32bit index to speed up computation
|
|
|
|
|
bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
|
|
|
|
|
bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
|
|
|
|
|
if (use_32bit_index && is_gpu_place) {
|
|
|
|
|
functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout),
|
|
|
|
|
To32BitIndex(dx));
|
|
|
|
|
} else {
|
|
|
|
|
functor(*place, x, out, dout, dx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|