|
|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
|
#include <cstring> // for memcpy
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/jit/kernels.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/fc.h"
|
|
|
|
@ -479,3 +480,13 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>,
|
|
|
|
|
ops::FusionGRUKernel<double>);
|
|
|
|
|
|
|
|
|
|
/* ========================== register checkpoint ===========================*/
|
|
|
|
|
REGISTER_OP_VERSION(fusion_gru)
|
|
|
|
|
.AddCheckpoint(
|
|
|
|
|
R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC",
|
|
|
|
|
paddle::framework::compatible::OpVersionDesc().NewAttr(
|
|
|
|
|
"Scale_weights",
|
|
|
|
|
"The added attribute 'Scale_weights' is not yet "
|
|
|
|
|
"registered.",
|
|
|
|
|
{1.0f}));
|
|
|
|
|