|
|
|
@ -26,10 +26,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of TemporalShiftOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of TemporalShiftOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input(X) of TemporalShiftOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output(Out) of TemporalShiftOp should not be null.");
|
|
|
|
|
|
|
|
|
|
auto dim_x = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_x.size(), 4,
|
|
|
|
@ -38,9 +38,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
|
|
|
|
|
int seg_num = ctx->Attrs().Get<int>("seg_num");
|
|
|
|
|
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
|
|
|
|
|
PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
|
|
|
|
|
PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5,
|
|
|
|
|
"Attr(shift_ratio) should be greater than 0 and less "
|
|
|
|
|
"than 0.5.");
|
|
|
|
|
PADDLE_ENFORCE_GT(shift_ratio, 0.,
|
|
|
|
|
"Attr(shift_ratio) should be greater than 0");
|
|
|
|
|
PADDLE_ENFORCE_LT(shift_ratio, 0.5,
|
|
|
|
|
"Attr(shift_ratio) should be less than 0.5");
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|