add gru op_register_version; test=op_version; (#29931)

* add gru op_register_version; test=op_version;

* Update fc,mul version;test=op_version;
revert-31562-mean
Jack Zhou 5 years ago committed by GitHub
parent 2b1d796cd0
commit 5a4e42ca9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -203,11 +203,11 @@ REGISTER_PASS_CAPABILITY(mul_gru_fuse_pass)
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("gru", 0)
.EQ("fusion_gru", 0));
.LE("fusion_gru", 1));
REGISTER_PASS_CAPABILITY(fc_gru_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("elementwise_add", 0)
.EQ("gru", 0)
.EQ("fusion_gru", 0));
.LE("fusion_gru", 1));

@ -16,6 +16,7 @@ limitations under the License. */
#include <cstring> // for memcpy
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc.h"
@ -479,3 +480,13 @@ REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker);
REGISTER_OP_CPU_KERNEL(fusion_gru, ops::FusionGRUKernel<float>,
ops::FusionGRUKernel<double>);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(fusion_gru)
.AddCheckpoint(
R"ROC(Upgrade fusion_gru add a new attribute [Scale_weights])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Scale_weights",
"The added attribute 'Scale_weights' is not yet "
"registered.",
{1.0f}));

Loading…
Cancel
Save