|
|
|
@ -24,18 +24,21 @@ class DGCClipByNormKernel : public ClipByNormKernel<DeviceContext, T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
|
|
|
|
|
if (static_cast<int>(rampup_begin_step) >= 0) {
|
|
|
|
|
auto current_step_tensor =
|
|
|
|
|
context.Input<framework::Tensor>("current_step");
|
|
|
|
|
auto* current_step = current_step_tensor->data<T>();
|
|
|
|
|
|
|
|
|
|
if (static_cast<int>(*current_step) <
|
|
|
|
|
static_cast<int>(rampup_begin_step)) {
|
|
|
|
|
VLOG(10) << "current_step:" << *current_step
|
|
|
|
|
<< " < rampup_begin_step:" << rampup_begin_step
|
|
|
|
|
<< " so does't use dgc_clip_by_norm";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (static_cast<int>(rampup_begin_step) < 0) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto current_step_tensor = context.Input<framework::Tensor>("current_step");
|
|
|
|
|
auto* current_step = current_step_tensor->data<T>();
|
|
|
|
|
|
|
|
|
|
VLOG(10) << "current_step:" << *current_step
|
|
|
|
|
<< ", rampup_begin_step:" << rampup_begin_step;
|
|
|
|
|
|
|
|
|
|
if (static_cast<int>(*current_step) < static_cast<int>(rampup_begin_step)) {
|
|
|
|
|
VLOG(10) << "current_step:" << *current_step
|
|
|
|
|
<< " < rampup_begin_step:" << rampup_begin_step
|
|
|
|
|
<< " so does't use dgc_clip_by_norm";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ClipByNormKernel<DeviceContext, T>::Compute(context);
|
|
|
|
|