parent de2db11735
commit ae7d22862b
paddle/fluid/operators/group_norm_op.cc
@@ -0,0 +1,162 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/group_norm_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;

class GroupNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Y"),
                   "Output(Y) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Mean"),
                   "Output(Mean) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Variance"),
                   "Output(Variance) of GroupNormOp should not be null.");

    auto x_dim = ctx->GetInputDim("X");
    auto channel_num = x_dim[1];
    auto batch_size = x_dim[0];
    auto groups = ctx->Attrs().Get<int>("groups");
    PADDLE_ENFORCE_LE(
        groups, channel_num,
        "'groups' must be less than or equal to the number of channels.");
    PADDLE_ENFORCE_GE(groups, 1,
                      "'groups' must be greater than or equal to 1.");

    if (ctx->HasInput("Scale")) {
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], channel_num);
    }
    if (ctx->HasInput("Bias")) {
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], channel_num);
    }

    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
    ctx->SetOutputDim("Mean", {batch_size, groups});
    ctx->SetOutputDim("Variance", {batch_size, groups});
    ctx->ShareLoD("X", "Y");
  }
};

class GroupNormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input tensor.");
    AddInput("Scale",
             "Scale is a 1-dimensional tensor of size C "
             "that is applied to the output.")
        .AsDispensable();
    AddInput("Bias",
             "Bias is a 1-dimensional tensor of size C "
             "that is applied to the output.")
        .AsDispensable();
    AddOutput("Y", "Result after normalization.");
    AddOutput("Mean", "Mean of each group.").AsIntermediate();
    AddOutput("Variance", "Variance of each group.").AsIntermediate();

    AddAttr<float>("epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &epsilon) {
          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 1.0f,
                         "'epsilon' should be between 0.0 and 1.0.");
        });
    AddAttr<int>("groups",
                 "The number of groups into which the channels are divided.")
        .AddCustomChecker([](const int &groups) {
          PADDLE_ENFORCE_GT(groups, 0, "'groups' should be greater than zero.");
        });

    AddComment(R"DOC(
Group Normalization

Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`_
)DOC");
  }
};
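
For reference, the DOC block above only points at the paper, so here is a brief editorial summary (a sketch, not text from this commit). The C channels of every sample are split into `groups` groups, and for each (sample, group) slice with m elements the op computes

    \mu = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
    \sigma^2 = \frac{1}{m}\sum_{i=1}^{m} x_i^2 - \mu^2, \qquad
    y_i = \gamma_c \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_c,

where \gamma (Scale) and \beta (Bias) are the optional per-channel parameters of size C, and the Mean and Variance outputs hold \mu and \sigma^2 per (sample, group), i.e. with shape [batch_size, groups].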

class GroupNormGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    // check input
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of GroupNormGradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Mean"),
                   "Input(Mean) of GroupNormGradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Variance"),
                   "Input(Variance) of GroupNormGradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                   "Input(Y@GRAD) of GroupNormGradOp should not be null.");

    // check output
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }
    if (ctx->HasOutput(framework::GradVarName("Scale"))) {
      ctx->SetOutputDim(framework::GradVarName("Scale"),
                        ctx->GetInputDim("Scale"));
    }
    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
      ctx->SetOutputDim(framework::GradVarName("Bias"),
                        ctx->GetInputDim("Bias"));
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto *var = ctx.InputVar(framework::GradVarName("Y"));
    if (var == nullptr) {
      PADDLE_THROW("can't find Y@GRAD");
    }
    const Tensor *t = nullptr;
    if (var->IsType<Tensor>()) {
      t = &var->Get<Tensor>();
    } else if (var->IsType<LoDTensor>()) {
      t = &var->Get<LoDTensor>();
    }
    if (t == nullptr) {
      PADDLE_THROW("can't find Y@GRAD");
    }
    return framework::OpKernelType(framework::ToDataType(t->type()),
                                   ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp);
REGISTER_OP_CPU_KERNEL(
    group_norm, ops::GroupNormKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GroupNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    group_norm_grad,
    ops::GroupNormGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GroupNormGradKernel<paddle::platform::CPUDeviceContext, double>);

File diff suppressed because it is too large

paddle/fluid/operators/group_norm_op.h
@@ -0,0 +1,197 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;

template <typename DeviceContext, typename T>
class GroupNormKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    auto* scale = ctx.Input<Tensor>("Scale");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* x = ctx.Input<Tensor>("X");

    auto* y = ctx.Output<Tensor>("Y");
    auto* mean = ctx.Output<Tensor>("Mean");
    auto* var = ctx.Output<Tensor>("Variance");
    const auto groups = ctx.Attr<int>("groups");

    const auto x_dims = x->dims();
    // Channels per group, rounded up; the last group may hold fewer channels.
    const int group_size = (x_dims[1] - 1) / groups + 1;

    y->mutable_data<T>(ctx.GetPlace());
    mean->mutable_data<T>(ctx.GetPlace());
    var->mutable_data<T>(ctx.GetPlace());

    auto* x_data = x->data<T>();
    auto* y_data = y->data<T>();
    auto* mean_data = mean->data<T>();
    auto* var_data = var->data<T>();

    const T* scale_data = nullptr;
    if (scale) scale_data = scale->data<T>();
    const T* bias_data = nullptr;
    if (bias) bias_data = bias->data<T>();

    int imsize = x_dims[2] * x_dims[3];
    auto* iter_x_data = x_data;
    auto* iter_y_data = y_data;
    for (int bid = 0; bid < x_dims[0]; bid++)
      for (int gid = 0; gid < groups; gid++) {
        T x_mean = 0, x_var = 0;
        int number = std::min(group_size,
                              static_cast<int>(x_dims[1] - gid * group_size));
        auto* tmp = iter_x_data;
        // First pass: accumulate the sum and the sum of squares of the group.
        for (int cid = 0; cid < number; cid++) {
          for (int imid = 0; imid < imsize; imid++, iter_x_data++) {
            x_mean += iter_x_data[0];
            x_var += iter_x_data[0] * iter_x_data[0];
          }
        }
        x_mean /= number * imsize;
        x_var /= number * imsize;
        x_var = x_var - x_mean * x_mean;
        T var_inv = 1.0 / sqrt(x_var + epsilon);
        mean_data[bid * groups + gid] = x_mean;
        var_data[bid * groups + gid] = x_var;
        // Second pass: normalize and apply the optional per-channel scale/bias.
        for (int cid = 0; cid < number; cid++) {
          for (int imid = 0; imid < imsize; imid++, tmp++, iter_y_data++) {
            T val = (tmp[0] - x_mean) * var_inv;
            if (scale_data) val *= scale_data[gid * group_size + cid];
            if (bias_data) val += bias_data[gid * group_size + cid];
            iter_y_data[0] = val;
          }
        }
      }
  }
};

template <typename DeviceContext, typename T>
class GroupNormGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
    auto* x = ctx.Input<Tensor>("X");
    auto* mean = ctx.Input<Tensor>("Mean");
    auto* var = ctx.Input<Tensor>("Variance");
    auto* scale = ctx.Input<Tensor>("Scale");
    auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto groups = ctx.Attr<int>("groups");

    // init output
    auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    const auto& x_dims = x->dims();
    const int group_size = (x_dims[1] - 1) / groups + 1;

    // TODO(liangdun): need to check d_x is null
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    T* d_x_data = nullptr;
    if (d_x) {
      d_x->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, d_x, static_cast<T>(0));
      d_x_data = d_x->data<T>();
    }

    auto* x_data = x->data<T>();
    auto* y_data = d_y->data<T>();
    auto* mean_data = mean->data<T>();
    auto* var_data = var->data<T>();
    T* d_scale_data = nullptr;
    if (d_scale) {
      d_scale->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, d_scale, static_cast<T>(0));
      d_scale_data = d_scale->data<T>();
    }
    T* d_bias_data = nullptr;
    if (d_bias) {
      d_bias->mutable_data<T>(ctx.GetPlace());
      set_zero(dev_ctx, d_bias, static_cast<T>(0));
      d_bias_data = d_bias->data<T>();
    }

    const T* scale_data = nullptr;
    if (scale) scale_data = scale->data<T>();

    int imsize = x_dims[2] * x_dims[3];
    auto* iter_x_data = x_data;
    auto* iter_d_x_data = d_x_data;
    auto* iter_y_data = y_data;
    for (int bid = 0; bid < x_dims[0]; bid++)
      for (int gid = 0; gid < groups; gid++) {
        T x_mean = mean_data[bid * groups + gid];
        T x_var = var_data[bid * groups + gid];
        T var_inv = 1.0 / sqrt(x_var + epsilon);
        int number = std::min(group_size,
                              static_cast<int>(x_dims[1] - gid * group_size));
        auto* tmp = iter_x_data;
        auto* tmp2 = iter_d_x_data;
        T d_var_inv = 0, d_x_mean = 0;
        // First pass: accumulate d_bias, d_scale, the gradients w.r.t.
        // 1 / sqrt(var + epsilon) and the mean, and the direct term of d_x.
        for (int cid = 0; cid < number; cid++) {
          for (int imid = 0; imid < imsize;
               imid++, tmp++, iter_y_data++, iter_d_x_data++) {
            T val = (tmp[0] - x_mean) * var_inv;
            T dval = iter_y_data[0];
            if (d_bias_data) d_bias_data[gid * group_size + cid] += dval;
            if (d_scale_data)
              d_scale_data[gid * group_size + cid] += val * dval;
            if (scale_data) dval = scale_data[gid * group_size + cid] * dval;

            d_var_inv += (tmp[0] - x_mean) * dval;
            T d_tmp = dval * var_inv;
            if (d_x_data) iter_d_x_data[0] += d_tmp;
            d_x_mean -= d_tmp;
          }
        }

        T d_x_var =
            -1.0 / (2 * (x_var + epsilon) * sqrt(x_var + epsilon)) * d_var_inv;
        d_x_mean -= 2 * d_x_var * x_mean;
        d_x_var /= number * imsize;
        d_x_mean /= number * imsize;

        // Rewind to the start of this group's d_x before adding the
        // mean and variance contributions.
        iter_d_x_data = tmp2;

        // Second pass: distribute the mean and variance gradients over the group.
        if (d_x_data) {
          for (int cid = 0; cid < number; cid++) {
            for (int imid = 0; imid < imsize;
                 imid++, iter_x_data++, iter_d_x_data++) {
              iter_d_x_data[0] += d_x_mean;
              iter_d_x_data[0] += iter_x_data[0] * 2 * d_x_var;
            }
          }
        }
      }
  }
};
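
To make the accumulators in the kernel above easier to follow, here is a hedged sketch of the gradient formulas it implements (editorial, not part of the commit; symbols mirror the forward summary, with \hat{x}_i = (x_i - \mu)/\sqrt{\sigma^2 + \epsilon}, m = number * imsize, and \partial L/\partial \hat{x}_i = \gamma_c \, \partial L/\partial y_i):

    \frac{\partial L}{\partial \beta_c} = \sum_i \frac{\partial L}{\partial y_i}, \qquad
    \frac{\partial L}{\partial \gamma_c} = \sum_i \hat{x}_i \, \frac{\partial L}{\partial y_i},

    \frac{\partial L}{\partial \sigma^2} = -\frac{1}{2\,(\sigma^2 + \epsilon)^{3/2}} \sum_i (x_i - \mu)\,\frac{\partial L}{\partial \hat{x}_i}, \qquad
    \frac{\partial L}{\partial \mu} = -\frac{1}{\sqrt{\sigma^2 + \epsilon}} \sum_i \frac{\partial L}{\partial \hat{x}_i} - 2\mu\,\frac{\partial L}{\partial \sigma^2},

    \frac{\partial L}{\partial x_i} = \frac{1}{\sqrt{\sigma^2 + \epsilon}}\,\frac{\partial L}{\partial \hat{x}_i} + \frac{1}{m}\,\frac{\partial L}{\partial \mu} + \frac{2 x_i}{m}\,\frac{\partial L}{\partial \sigma^2}.

In the code, d_var_inv is \sum_i (x_i - \mu)\,\partial L/\partial \hat{x}_i, and d_x_var and d_x_mean correspond to \partial L/\partial \sigma^2 and \partial L/\partial \mu before the division by m.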

}  // namespace operators
}  // namespace paddle

test_group_norm_op.py
@@ -0,0 +1,143 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
import numpy as np

from operator import mul
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest

from testsuite import create_op


def group_norm_naive(x, scale, bias, epsilon, groups):
    N, C, H, W = x.shape
    G = groups
    x = x.reshape((N * G, -1))
    mean = np.mean(x, axis=1, keepdims=True)
    var = np.var(x, axis=1, keepdims=True)
    output = (x - mean) / np.sqrt(var + epsilon)
    output = output.reshape((N, C, H, W)) * scale.reshape(
        (-1, 1, 1)) + bias.reshape((-1, 1, 1))
    return output, mean.reshape((N, G)), var.reshape((N, G))
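
A quick, hedged sanity check of the NumPy reference above (illustrative only, not part of the test file): with unit scale and zero bias, every (sample, group) slice of the output should come out with mean close to 0 and variance close to 1, and the Mean/Variance outputs have shape (N, G).

    import numpy as np

    x = np.random.random((2, 4, 3, 3)).astype(np.float32)
    y, mean, var = group_norm_naive(
        x, np.ones([4], np.float32), np.zeros([4], np.float32), 1e-5, 2)
    print(y.reshape((2 * 2, -1)).mean(axis=1))  # ~0 for each of the N*G slices
    print(y.reshape((2 * 2, -1)).var(axis=1))   # ~1 for each of the N*G slices
    print(mean.shape, var.shape)                # (2, 2) (2, 2)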


class TestGroupNormOp(OpTest):
    def setUp(self):
        self.op_type = "group_norm"
        self.data_format = "NCHW"
        self.dtype = np.float32
        self.shape = (2, 4, 3, 3)
        self.attrs = {'epsilon': 1e-5, 'groups': 2}
        self.compare_between_place = False
        self.init_test_case()

        input = np.random.random(self.shape).astype(self.dtype)
        scale = np.random.random([self.shape[1]]).astype(self.dtype)
        bias = np.random.random([self.shape[1]]).astype(self.dtype)
        output, mean, var = group_norm_naive(
            input, scale, bias, self.attrs['epsilon'], self.attrs['groups'])

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(input),
            'Scale': OpTest.np_dtype_to_fluid_dtype(scale),
            'Bias': OpTest.np_dtype_to_fluid_dtype(bias)
        }
        self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}

    def test_check_output(self):
        atol = 1e-4
        place = core.CPUPlace()
        self.check_output_with_place(place, atol=atol)
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=atol)

    def do_compare_between_place(self):
        if not core.is_compiled_with_cuda(): return
        place = core.CPUPlace()
        place2 = core.CUDAPlace(0)
        self.scope = core.Scope()
        op_inputs = self.inputs if hasattr(self, "inputs") else dict()
        op_outputs = self.outputs if hasattr(self, "outputs") else dict()
        op_attrs = self.attrs if hasattr(self, "attrs") else dict()
        self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
                            op_attrs)
        inputs_to_check = set(['X', 'Scale', 'Bias'])
        output_names = 'Y'
        cpu_grads = self._get_gradient(inputs_to_check, place, output_names,
                                       None)
        gpu_grads = self._get_gradient(inputs_to_check, place2, output_names,
                                       None)
        self._assert_is_close(cpu_grads, gpu_grads, inputs_to_check, 0.005,
                              "Gradient Check On %s" % str(place))

    def test_check_grad(self):
        if self.compare_between_place:
            self.do_compare_between_place()
            return
        place = core.CPUPlace()
        self.check_grad_with_place(
            place, set(['X', 'Scale', 'Bias']), 'Y', max_relative_error=0.01)
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                set(['X', 'Scale', 'Bias']),
                'Y',
                max_relative_error=0.01)

    def init_test_case(self):
        pass


class TestGroupNormOp1(TestGroupNormOp):
    def init_test_case(self):
        self.attrs['groups'] = 1


class TestGroupNormOp2(TestGroupNormOp):
    def init_test_case(self):
        self.attrs['groups'] = 4


class TestGroupNormOpBigEps1(TestGroupNormOp):
    def init_test_case(self):
        self.attrs['groups'] = 1
        self.attrs['epsilon'] = 0.5


class TestGroupNormOpBigEps2(TestGroupNormOp):
    def init_test_case(self):
        self.attrs['groups'] = 4
        self.attrs['epsilon'] = 0.5


class TestGroupNormOpBigEps3(TestGroupNormOp):
    def init_test_case(self):
        self.attrs['epsilon'] = 0.5


class TestGroupNormOpLargeData(TestGroupNormOp):
    def init_test_case(self):
        self.shape = (2, 32, 64, 64)
        self.attrs['groups'] = 8
        self.compare_between_place = True


if __name__ == '__main__':
    unittest.main()