|
|
@ -20,6 +20,7 @@ limitations under the License. */
|
|
|
|
#include <unordered_map>
|
|
|
|
#include <unordered_map>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
#include "paddle/fluid/operators/common_infer_shape_functions.h"
|
|
|
|
#include "paddle/fluid/operators/common_infer_shape_functions.h"
|
|
|
|
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
|
|
|
|
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
|
|
|
|
#include "paddle/fluid/platform/port.h"
|
|
|
|
#include "paddle/fluid/platform/port.h"
|
|
|
@ -1231,3 +1232,24 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
ops::AbsGradFunctor<int64_t>>);
|
|
|
|
ops::AbsGradFunctor<int64_t>>);
|
|
|
|
/* ========================================================================== */
|
|
|
|
/* ========================================================================== */
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/* ========================== register checkpoint ===========================*/
|
|
|
|
|
|
|
|
REGISTER_OP_VERSION(leaky_relu)
|
|
|
|
|
|
|
|
.AddCheckpoint(
|
|
|
|
|
|
|
|
R"ROC(fix leaky_relu, bahavior changed when alpha < 0 or alpha > 1)ROC",
|
|
|
|
|
|
|
|
paddle::framework::compatible::OpVersionDesc()
|
|
|
|
|
|
|
|
.BugfixWithBehaviorChanged(
|
|
|
|
|
|
|
|
"leaky_relu calculate formula before checkponit: out = max(x, "
|
|
|
|
|
|
|
|
"alpha * x); after checkpoint: out = x if x > 0 else alpha * "
|
|
|
|
|
|
|
|
"x"));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_VERSION(hard_shrink)
|
|
|
|
|
|
|
|
.AddCheckpoint(
|
|
|
|
|
|
|
|
R"ROC(fix hard_shrink, bahavior changed when threshold<0)ROC",
|
|
|
|
|
|
|
|
paddle::framework::compatible::OpVersionDesc()
|
|
|
|
|
|
|
|
.BugfixWithBehaviorChanged(
|
|
|
|
|
|
|
|
"hard_shrink calculate formula before checkponit: out = x * "
|
|
|
|
|
|
|
|
"((x < -threshold) + (x > threshold)); after checkpoint: out = "
|
|
|
|
|
|
|
|
"x * (((x < -threshold) + (x > threshold)) > 0)"));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/* ========================================================================== */
|
|
|
|