|
|
|
@ -13,9 +13,9 @@
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/roll_op.h"
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -142,3 +142,17 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
ops::RollGradKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::RollGradKernel<paddle::platform::CPUDeviceContext, int>,
|
|
|
|
|
ops::RollGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_VERSION(roll)
|
|
|
|
|
.AddCheckpoint(
|
|
|
|
|
R"ROC(
|
|
|
|
|
Upgrade roll add 1 attribute [axis], delete 1 attribute[dims].
|
|
|
|
|
)ROC",
|
|
|
|
|
paddle::framework::compatible::OpVersionDesc()
|
|
|
|
|
.NewAttr("axis",
|
|
|
|
|
"(std::vector<int64_t>) Axis along which to roll. "
|
|
|
|
|
"It must have the same size with shifts.",
|
|
|
|
|
std::vector<int64_t>())
|
|
|
|
|
.DeleteAttr("dims",
|
|
|
|
|
"(std::vector<int64_t>) Dims along which to roll. "
|
|
|
|
|
"It must have the same size with shifts."));
|
|
|
|
|