|
|
@ -82,7 +82,12 @@ class ClipByNormKernel : public framework::OpKernel<T> {
|
|
|
|
auto scaling = temp + (static_cast<T>(1) - temp) * max_norm / x_norm;
|
|
|
|
auto scaling = temp + (static_cast<T>(1) - temp) * max_norm / x_norm;
|
|
|
|
Eigen::array<int, 1> one_dim{{1}};
|
|
|
|
Eigen::array<int, 1> one_dim{{1}};
|
|
|
|
Eigen::DSizes<int, 1> m_dsize(input->numel());
|
|
|
|
Eigen::DSizes<int, 1> m_dsize(input->numel());
|
|
|
|
out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);
|
|
|
|
if (context.GetPlace() == platform::CPUPlace()) {
|
|
|
|
|
|
|
|
out.device(place) =
|
|
|
|
|
|
|
|
x * scaling.reshape(one_dim).eval().broadcast(m_dsize);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|