|
|
@ -81,7 +81,12 @@ class ClipByNormKernel : public framework::OpKernel<T> {
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
*context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
|
|
|
|
|
|
|
auto temp = (x_norm <= max_norm).template cast<T>();
|
|
|
|
auto temp = (x_norm <= max_norm).template cast<T>();
|
|
|
|
auto scaling = temp + (static_cast<T>(1) - temp) * max_norm / x_norm;
|
|
|
|
auto epsilon =
|
|
|
|
|
|
|
|
((x_norm <= static_cast<T>(1e-30)).all().template cast<T>()) *
|
|
|
|
|
|
|
|
static_cast<T>(1e-6);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto scaling =
|
|
|
|
|
|
|
|
temp + (static_cast<T>(1) - temp) * max_norm / (x_norm + epsilon);
|
|
|
|
Eigen::array<int, 1> one_dim{{1}};
|
|
|
|
Eigen::array<int, 1> one_dim{{1}};
|
|
|
|
Eigen::DSizes<int, 1> m_dsize(input->numel());
|
|
|
|
Eigen::DSizes<int, 1> m_dsize(input->numel());
|
|
|
|
if (context.GetPlace() == platform::CPUPlace()) {
|
|
|
|
if (context.GetPlace() == platform::CPUPlace()) {
|
|
|
|