Merge pull request #16135 from heavengate/shift
Add temporal_shift op for TSM model
commit
63ac947e2f
paddle/fluid/operators/temporal_shift_op.cc
@@ -0,0 +1,155 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/temporal_shift_op.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class TemporalShiftOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of TemporalShiftOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of TemporalShiftOp should not be null.");

    auto dim_x = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(dim_x.size(), 4,
                      "Input(X) rank should be 4 in shape of [N*T, C, H, W].");

    int seg_num = ctx->Attrs().Get<int>("seg_num");
    float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
    PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
    // Both bounds must hold, so the check needs &&, not ||.
    PADDLE_ENFORCE(shift_ratio > 0 && shift_ratio < .5,
                   "Attr(shift_ratio) should be greater than 0 and less "
                   "than 0.5.");

    if (ctx->IsRuntime()) {
      PADDLE_ENFORCE_EQ(
          dim_x[0] % seg_num, 0,
          "Input(X) dims[0] should be divided exactly by Attr(seg_num).");
    }

    ctx->SetOutputDim("Out", dim_x);
    ctx->ShareLoD("X", "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input tensor of temporal shift operator. "
             "This is a 4-D tensor with shape of [N*T, C, H, W], "
             "where N is the batch size, T is the temporal segment "
             "number, C is the channel number, H is the height of "
             "features and W is the width of features.");
    AddOutput("Out",
              "The output tensor of temporal shift operator. "
              "This is a 4-D tensor in the same shape as Input(X).");

    AddAttr<int>("seg_num",
                 "The temporal segment number, this should be a positive "
                 "integer.");
    AddAttr<float>(
        "shift_ratio",
        "The shift ratio of the channels, the first :attr:`shift_ratio` part "
        "of channels will be shifted by -1 along the temporal dimension, "
        "and the second :attr:`shift_ratio` part of channels will be shifted "
        "by 1 along the temporal dimension. Default 0.25.")
        .SetDefault(0.25);

    AddComment(R"DOC(
This operator calculates the temporal shifting features for Input(X).

Input(X) should be in shape of [N*T, C, H, W], where N is the batch
size, T is the temporal segment number specified by :attr:`seg_num`,
C is the channel number, and H and W are the height and width of features.

Temporal shifting is calculated as follows:

Step 1: Reshape Input(X) to [N, T, C, H, W].

Step 2: Pad the reshaped tensor with zeros along the 2nd (T) dimension,
with a padding width of 1 on each side; the padded result is in shape
of [N, T+2, C, H, W].

Step 3: Assume :attr:`shift_ratio` is :math:`1/4`, slice the padded
result as follows:

$$
slice1 = x[:, :T, :C/4, :, :]
$$
$$
slice2 = x[:, 2:T+2, C/4:C/2, :, :]
$$
$$
slice3 = x[:, 1:T+1, C/2:, :, :]
$$

Step 4: Concatenate the three slices along the 3rd (C) dimension and
reshape the result to [N*T, C, H, W].

For details of temporal shifting, please refer to the paper:
`Temporal Shift Module <http://arxiv.org/abs/1811.08383>`_ .

)DOC");
  }
};

class TemporalShiftOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto dim_x = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
                  ops::TemporalShiftOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
                       ops::TemporalShiftKernel<double>);
REGISTER_OP_CPU_KERNEL(temporal_shift_grad, ops::TemporalShiftGradKernel<float>,
                       ops::TemporalShiftGradKernel<double>);
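
For reference, the four steps described in the DOC string above can be sketched in NumPy; a minimal illustration with made-up shapes, assuming shift_ratio = 1/4 (none of the names or values below come from the PR itself):

import numpy as np

# Illustrative shapes only; N, T, C, H, W are hypothetical.
N, T, C, H, W = 2, 4, 8, 3, 3
x = np.random.rand(N * T, C, H, W).astype('float32')

y = x.reshape(N, T, C, H, W)                              # Step 1
pad = np.pad(y, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
             'constant')                                  # Step 2: [N, T+2, C, H, W]
slice1 = pad[:, :T, :C // 4, :, :]                        # Step 3: shifted by -1 along T
slice2 = pad[:, 2:T + 2, C // 4:C // 2, :, :]             #         shifted by +1 along T
slice3 = pad[:, 1:T + 1, C // 2:, :, :]                   #         unshifted
out = np.concatenate([slice1, slice2, slice3], axis=2)    # Step 4
out = out.reshape(N * T, C, H, W)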
paddle/fluid/operators/temporal_shift_op.cu
@@ -0,0 +1,168 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/temporal_shift_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
                                  const int tchw, const int chw, const int hw,
                                  const int w, const int t, const int c,
                                  const float shift_ratio) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;
  // Grid-stride loop: each thread handles every `stride`-th output entry.
  for (; tid < ntchw; tid += stride) {
    int in = tid / tchw;
    int it = (tid % tchw) / chw;
    int ic = (tid % chw) / hw;
    int ih = (tid % hw) / w;
    int iw = tid % w;

    // Channel split points are integers, so cast to int rather than T.
    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    if (ic < c1) {
      src_it = it - 1;
    } else if (ic < c2) {
      src_it = it + 1;
    } else {
      src_it = it;
    }

    if (src_it < 0 || src_it >= t) {
      output[tid] = 0;
    } else {
      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
      output[tid] = input[src_idx];
    }
  }
}

template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
                                  const int ntchw, const int tchw,
                                  const int chw, const int hw, const int w,
                                  const int t, const int c,
                                  const float shift_ratio) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;
  for (; tid < ntchw; tid += stride) {
    int in = tid / tchw;
    int it = (tid % tchw) / chw;
    int ic = (tid % chw) / hw;
    int ih = (tid % hw) / w;
    int iw = tid % w;

    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    if (ic < c1) {
      src_it = it - 1;
    } else if (ic < c2) {
      src_it = it + 1;
    } else {
      src_it = it;
    }

    if (src_it >= 0 && src_it < t) {
      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
      input_grad[src_idx] = output_grad[tid];
    }
  }
}

template <typename T>
class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "This kernel only runs on GPU device.");
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");

    const int nt = input->dims()[0];
    const int c = input->dims()[1];
    const int h = input->dims()[2];
    const int w = input->dims()[3];

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());

    int pixelNum = nt * chw;
    int grid_dim = (pixelNum + 512 - 1) / 512;
    grid_dim = grid_dim > 8 ? 8 : grid_dim;

    KeTemporalShiftFw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
  }
};

template <typename T>
class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");

    const int nt = output_grad->dims()[0];
    const int c = output_grad->dims()[1];
    const int h = output_grad->dims()[2];
    const int w = output_grad->dims()[3];

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
    // Zero-initialize Input(X)@GRAD: entries shifted out of [0, T) receive
    // no gradient from the scatter in KeTemporalShiftBw.
    math::SetConstant<platform::CUDADeviceContext, T>()(
        ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
        static_cast<T>(0));

    int pixelNum = nt * chw;
    int grid_dim = (pixelNum + 512 - 1) / 512;
    grid_dim = grid_dim > 8 ? 8 : grid_dim;

    KeTemporalShiftBw<
        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
        shift_ratio);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
                        ops::TemporalShiftOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
                        ops::TemporalShiftGradOpCUDAKernel<float>,
                        ops::TemporalShiftGradOpCUDAKernel<double>);
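
Both kernels decode the flat index as tid -> (n, t, c, h, w) and re-encode the source entry through GetEntryIndex; the two operations are inverses. A quick host-side sanity check in Python (a sketch with hypothetical sizes, not part of the PR):

# Hypothetical sizes; mirrors the tid -> (n, t, c, h, w) decode above.
T, C, H, W = 3, 8, 4, 4
hw, chw, tchw = H * W, C * H * W, T * C * H * W

def decode(tid):
    return (tid // tchw, (tid % tchw) // chw, (tid % chw) // hw,
            (tid % hw) // W, tid % W)

def entry_index(n, t, c, h, w):
    # Same arithmetic as GetEntryIndex in temporal_shift_op.h.
    return n * tchw + t * chw + c * hw + h * W + w

for tid in range(2 * tchw):  # two batch entries
    assert entry_index(*decode(tid)) == tid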
paddle/fluid/operators/temporal_shift_op.h
@@ -0,0 +1,129 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cstring>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Flattened index of entry (in, it, ic, ih, iw) in an [N, T, C, H, W] tensor.
static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
                                           int iw, const int tchw,
                                           const int chw, const int hw,
                                           const int w) {
  return in * tchw + it * chw + ic * hw + ih * w + iw;
}

template <typename T>
class TemporalShiftKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");

    const int nt = input->dims()[0];
    const int c = input->dims()[1];
    const int h = input->dims()[2];
    const int w = input->dims()[3];

    // Channel split points: [0, c1) shifts by -1 along T, [c1, c2) by +1,
    // and [c2, c) is unshifted.
    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());

    int src_it = 0;
    for (int i = 0; i < output->numel(); i++) {
      int in = i / tchw;
      int it = (i % tchw) / chw;
      int ic = (i % chw) / hw;
      int ih = (i % hw) / w;
      int iw = i % w;

      if (ic < c1) {
        src_it = it - 1;
      } else if (ic < c2) {
        src_it = it + 1;
      } else {
        src_it = it;
      }

      if (src_it < 0 || src_it >= t) {
        output_data[i] = 0;
      } else {
        int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
        output_data[i] = input_data[src_idx];
      }
    }
  }
};

template <typename T>
class TemporalShiftGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");

    const int nt = output_grad->dims()[0];
    const int c = output_grad->dims()[1];
    const int h = output_grad->dims()[2];
    const int w = output_grad->dims()[3];

    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;

    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
    memset(input_grad_data, 0, input_grad->numel() * sizeof(T));

    int src_it = 0;
    for (int i = 0; i < output_grad->numel(); i++) {
      int in = i / tchw;
      int it = (i % tchw) / chw;
      int ic = (i % chw) / hw;
      int ih = (i % hw) / w;
      int iw = i % w;

      if (ic < c1) {
        src_it = it - 1;
      } else if (ic < c2) {
        src_it = it + 1;
      } else {
        src_it = it;
      }

      if (src_it >= 0 && src_it < t) {
        int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
        input_grad_data[src_idx] = output_grad_data[i];
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
@@ -0,0 +1,81 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import unittest
import numpy as np
from op_test import OpTest

from paddle.fluid import core


def temporal_shift(x, seg_num, shift_ratio):
    # Reference implementation: reshape to [N, T, C, H, W], zero-pad T by
    # one frame on each side, then slice and re-concatenate the channels.
    shape = x.shape
    reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
    pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
                   'constant')
    c1 = int(shape[1] * shift_ratio)
    c2 = int(shape[1] * 2 * shift_ratio)
    slice1 = pad_x[:, :seg_num, :c1, :, :]
    slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :]
    slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :]
    concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
    return concat_x.reshape(shape)


class TestTemporalShift(OpTest):
    def setUp(self):
        self.initTestCase()
        self.op_type = 'temporal_shift'
        x = np.random.random(self.x_shape).astype('float32')

        self.attrs = {
            "seg_num": self.seg_num,
            "shift_ratio": self.shift_ratio,
        }

        self.inputs = {"X": x}

        output = temporal_shift(x, self.seg_num, self.shift_ratio)
        self.outputs = {"Out": output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

    def initTestCase(self):
        self.x_shape = (6, 4, 4, 4)
        self.seg_num = 3
        self.shift_ratio = 0.25


class TestTemporalShift2(TestTemporalShift):
    def initTestCase(self):
        self.x_shape = (4, 9, 7, 7)
        self.seg_num = 2
        self.shift_ratio = 0.2


class TestTemporalShift3(TestTemporalShift):
    def initTestCase(self):
        self.x_shape = (3, 10, 5, 5)
        self.seg_num = 1
        self.shift_ratio = 0.3


if __name__ == "__main__":
    unittest.main()
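
Once the op is registered it is typically reached from Python through a layers wrapper. A minimal usage sketch, assuming a fluid.layers.temporal_shift wrapper with (x, seg_num, shift_ratio) parameters exists (it is not shown in this diff):

import numpy as np
import paddle.fluid as fluid

# Assumes fluid.layers.temporal_shift wraps the temporal_shift op.
x = fluid.layers.data(name='x', shape=[4, 4, 4], dtype='float32')
out = fluid.layers.temporal_shift(x, seg_num=2, shift_ratio=0.2)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
# Batch dim is N*T; 6 frames with seg_num=2 means N=3 clips of T=2 frames.
feed_x = np.random.random((6, 4, 4, 4)).astype('float32')
res, = exe.run(fluid.default_main_program(),
               feed={'x': feed_x},
               fetch_list=[out])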