|
|
@ -13,6 +13,7 @@
|
|
|
|
// limitations under the License.
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/trace_op.h"
|
|
|
|
#include "paddle/fluid/operators/trace_op.h"
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
@ -88,13 +89,13 @@ class TraceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken.
|
|
|
|
R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken.
|
|
|
|
Can be either positive or negative. Default: 0.
|
|
|
|
Can be either positive or negative. Default: 0.
|
|
|
|
)DOC")
|
|
|
|
)DOC")
|
|
|
|
.SetDefault(-2);
|
|
|
|
.SetDefault(0);
|
|
|
|
AddAttr<int>(
|
|
|
|
AddAttr<int>(
|
|
|
|
"axis2",
|
|
|
|
"axis2",
|
|
|
|
R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken.
|
|
|
|
R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken.
|
|
|
|
Can be either positive or negative. Default: 1.
|
|
|
|
Can be either positive or negative. Default: 1.
|
|
|
|
)DOC")
|
|
|
|
)DOC")
|
|
|
|
.SetDefault(-1);
|
|
|
|
.SetDefault(1);
|
|
|
|
AddComment(R"DOC(
|
|
|
|
AddComment(R"DOC(
|
|
|
|
Trace Operator.
|
|
|
|
Trace Operator.
|
|
|
|
Return the sum along diagonals of the input tensor.
|
|
|
|
Return the sum along diagonals of the input tensor.
|
|
|
@ -177,3 +178,21 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
paddle::platform::complex64>,
|
|
|
|
paddle::platform::complex64>,
|
|
|
|
ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
paddle::platform::complex128>);
|
|
|
|
paddle::platform::complex128>);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/* ========================== register checkpoint ===========================*/
|
|
|
|
|
|
|
|
REGISTER_OP_VERSION(trace)
|
|
|
|
|
|
|
|
.AddCheckpoint(
|
|
|
|
|
|
|
|
R"ROC(Upgrade trace add a new attribute [axis2])ROC",
|
|
|
|
|
|
|
|
paddle::framework::compatible::OpVersionDesc()
|
|
|
|
|
|
|
|
.NewAttr("axis1",
|
|
|
|
|
|
|
|
"The added attribute 'axis1' is not yet registered.",
|
|
|
|
|
|
|
|
std::vector<float>{0.0f})
|
|
|
|
|
|
|
|
.NewAttr("axis2",
|
|
|
|
|
|
|
|
"The added attribute 'axis2' is not yet registered.",
|
|
|
|
|
|
|
|
std::vector<float>{1.0f})
|
|
|
|
|
|
|
|
.DeleteAttr("dim1",
|
|
|
|
|
|
|
|
"The attribute 'dim1' is not recommend according to "
|
|
|
|
|
|
|
|
"the specification 2.0.")
|
|
|
|
|
|
|
|
.DeleteAttr("dim2",
|
|
|
|
|
|
|
|
"The attribute 'dim2' is not recommend according to "
|
|
|
|
|
|
|
|
"the specification 2.0."));
|
|
|
|