add trace op_register_version and fix version bug; test=op_version (#30000)

* add trace op_register_version and fix defaulf bug; test=op_version

* add trace op_register_version; test=op_version

* add trace op_register_version; test=op_version

* add trace op_register_version; test=op_version

* fix missing the template bug of vector; test=op_version
revert-31562-mean
chentianyu03 4 years ago committed by GitHub
parent 9f34374b48
commit a5e422c85d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/trace_op.h" #include "paddle/fluid/operators/trace_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -88,13 +89,13 @@ class TraceOpMaker : public framework::OpProtoAndCheckerMaker {
R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken. R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken.
Can be either positive or negative. Default: 0. Can be either positive or negative. Default: 0.
)DOC") )DOC")
.SetDefault(-2); .SetDefault(0);
AddAttr<int>( AddAttr<int>(
"axis2", "axis2",
R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken. R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken.
Can be either positive or negative. Default: 1. Can be either positive or negative. Default: 1.
)DOC") )DOC")
.SetDefault(-1); .SetDefault(1);
AddComment(R"DOC( AddComment(R"DOC(
Trace Operator. Trace Operator.
Return the sum along diagonals of the input tensor. Return the sum along diagonals of the input tensor.
@ -177,3 +178,21 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex64>, paddle::platform::complex64>,
ops::TraceGradKernel<paddle::platform::CPUDeviceContext, ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>); paddle::platform::complex128>);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(trace)
.AddCheckpoint(
R"ROC(Upgrade trace add a new attribute [axis2])ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("axis1",
"The added attribute 'axis1' is not yet registered.",
std::vector<float>{0.0f})
.NewAttr("axis2",
"The added attribute 'axis2' is not yet registered.",
std::vector<float>{1.0f})
.DeleteAttr("dim1",
"The attribute 'dim1' is not recommend according to "
"the specification 2.0.")
.DeleteAttr("dim2",
"The attribute 'dim2' is not recommend according to "
"the specification 2.0."));

Loading…
Cancel
Save