parent
643ff03fbc
commit
2b3ba40e6a
@ -0,0 +1,17 @@
|
||||
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
A set of tools for generating adversarial example on paddle platform
|
||||
"""
|
@ -0,0 +1,42 @@
|
||||
"""
|
||||
The base model of the model.
|
||||
"""
|
||||
from abc import ABCMeta
|
||||
#from advbox.base import Model
|
||||
import abc
|
||||
|
||||
abstractmethod = abc.abstractmethod
|
||||
|
||||
class Attack(object):
|
||||
"""
|
||||
Abstract base class for adversarial attacks. `Attack` represent an adversarial attack
|
||||
which search an adversarial example. subclass should implement the _apply() method.
|
||||
|
||||
Args:
|
||||
model(Model): an instance of the class advbox.base.Model.
|
||||
|
||||
"""
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def __call__(self, image_batch):
|
||||
"""
|
||||
Generate the adversarial sample.
|
||||
|
||||
Args:
|
||||
image_batch(list): The image and label tuple list.
|
||||
"""
|
||||
adv_img = self._apply(image_batch)
|
||||
return adv_img
|
||||
|
||||
@abstractmethod
|
||||
def _apply(self, image_batch):
|
||||
"""
|
||||
Search an adversarial example.
|
||||
|
||||
Args:
|
||||
image_batch(list): The image and label tuple list.
|
||||
"""
|
||||
raise NotImplementedError
|
@ -0,0 +1,36 @@
|
||||
"""
|
||||
This module provide the attack method for FGSM's implement.
|
||||
"""
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
from collections import Iterable
|
||||
from .base import Attack
|
||||
|
||||
class GradientSignAttack(Attack):
|
||||
"""
|
||||
This attack was originally implemented by Goodfellow et al. (2015) with the
|
||||
infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called
|
||||
the Fast Gradient Method.
|
||||
Paper link: https://arxiv.org/abs/1412.6572
|
||||
"""
|
||||
|
||||
def _apply(self, image_batch, epsilons=1000):
|
||||
pre_label = np.argmax(self.model.predict(image_batch))
|
||||
|
||||
min_, max_ = self.model.bounds()
|
||||
gradient = self.model.gradient(image_batch)
|
||||
gradient_sign = np.sign(gradient) * (max_ - min_)
|
||||
|
||||
if not isinstance(epsilons, Iterable):
|
||||
epsilons = np.linspace(0, 1, num = epsilons + 1)
|
||||
|
||||
for epsilon in epsilons:
|
||||
adv_img = image_batch[0][0].reshape(gradient_sign.shape) + epsilon * gradient_sign
|
||||
adv_img = np.clip(adv_img, min_, max_)
|
||||
adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
|
||||
#print("pre_label="+str(pre_label)+ " adv_label="+str(adv_label))
|
||||
if pre_label != adv_label:
|
||||
#print(epsilon, pre_label, adv_label)
|
||||
return adv_img
|
||||
|
||||
FGSM = GradientSignAttack
|
@ -0,0 +1,16 @@
|
||||
# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Paddle model for target of attack
|
||||
"""
|
@ -0,0 +1,91 @@
|
||||
"""
|
||||
The base model of the model.
|
||||
"""
|
||||
from abc import ABCMeta
|
||||
import abc
|
||||
|
||||
abstractmethod = abc.abstractmethod
|
||||
|
||||
class Model(object):
|
||||
|
||||
"""
|
||||
Base class of model to provide attack.
|
||||
|
||||
|
||||
Args:
|
||||
bounds(tuple): The lower and upper bound for the image pixel.
|
||||
channel_axis(int): The index of the axis that represents the color channel.
|
||||
preprocess(tuple): Two element tuple used to preprocess the input. First
|
||||
substract the first element, then divide the second element.
|
||||
"""
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self, bounds, channel_axis, preprocess=None):
|
||||
assert len(bounds) == 2
|
||||
assert channel_axis in [0, 1, 2, 3]
|
||||
|
||||
if preprocess is None:
|
||||
preprocess = (0, 1)
|
||||
self._bounds = bounds
|
||||
self._channel_axis = channel_axis
|
||||
self._preprocess = preprocess
|
||||
|
||||
def bounds(self):
|
||||
"""
|
||||
Return the upper and lower bounds of the model.
|
||||
"""
|
||||
return self._bounds
|
||||
|
||||
def channel_axis(self):
|
||||
"""
|
||||
Return the channel axis of the model.
|
||||
"""
|
||||
return self._channel_axis
|
||||
|
||||
def _process_input(self, input_):
|
||||
res = input_
|
||||
sub, div = self._preprocess
|
||||
if sub != 0:
|
||||
res = input_ - sub
|
||||
assert div != 0
|
||||
if div != 1:
|
||||
res /= div
|
||||
return res
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, image_batch):
|
||||
"""
|
||||
Calculate the prediction of the image batch.
|
||||
|
||||
Args:
|
||||
image_batch(numpy.ndarray): image batch of shape (batch_size, height, width, channels).
|
||||
|
||||
Return:
|
||||
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def num_classes(self):
|
||||
"""
|
||||
Determine the number of the classes
|
||||
|
||||
Return:
|
||||
int: the number of the classes
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def gradient(self, image_batch):
|
||||
"""
|
||||
Calculate the gradient of the cross-entropy loss w.r.t the image.
|
||||
|
||||
Args:
|
||||
image(numpy.ndarray): image with shape (height, width, channel)
|
||||
label(int): image label used to cal gradient.
|
||||
|
||||
Return:
|
||||
numpy.ndarray: gradient of the cross-entropy loss w.r.t the image with
|
||||
the shape (height, width, channel).
|
||||
"""
|
||||
raise NotImplementedError
|
@ -0,0 +1,106 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import numpy as np
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
from paddle.v2.fluid.framework import program_guard
|
||||
|
||||
from .base import Model
|
||||
|
||||
class PaddleModel(Model):
|
||||
"""
|
||||
Create a PaddleModel instance.
|
||||
When you need to generate a adversarial sample, you should construct an instance of PaddleModel.
|
||||
|
||||
Args:
|
||||
program(paddle.v2.fluid.framework.Program): The program of the model which generate the adversarial sample.
|
||||
input_name(string): The name of the input.
|
||||
logits_name(string): The name of the logits.
|
||||
predict_name(string): The name of the predict.
|
||||
cost_name(string): The name of the loss in the program.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
program,
|
||||
input_name,
|
||||
logits_name,
|
||||
predict_name,
|
||||
cost_name,
|
||||
bounds,
|
||||
channel_axis=3,
|
||||
preprocess=None):
|
||||
super(PaddleModel, self).__init__(
|
||||
bounds=bounds,
|
||||
channel_axis=channel_axis,
|
||||
preprocess=preprocess)
|
||||
|
||||
if preprocess is None:
|
||||
preprocess = (0, 1)
|
||||
|
||||
self._program = program
|
||||
self._place = fluid.CPUPlace()
|
||||
self._exe = fluid.Executor(self._place)
|
||||
|
||||
self._input_name = input_name
|
||||
self._logits_name = logits_name
|
||||
self._predict_name = predict_name
|
||||
self._cost_name = cost_name
|
||||
|
||||
# gradient
|
||||
loss = self._program.block(0).var(self._cost_name)
|
||||
param_grads = fluid.backward.append_backward(loss, parameter_list=[self._input_name])
|
||||
self._gradient = param_grads[0][1]
|
||||
|
||||
def predict(self, image_batch):
|
||||
"""
|
||||
Predict the label of the image_batch.
|
||||
|
||||
Args:
|
||||
image_batch(list): The image and label tuple list.
|
||||
Return:
|
||||
numpy.ndarray: predictions of the images with shape (batch_size, num_of_classes).
|
||||
"""
|
||||
feeder = fluid.DataFeeder(
|
||||
feed_list=[self._input_name, self._logits_name],
|
||||
place=self._place,
|
||||
program=self._program
|
||||
)
|
||||
predict_var = self._program.block(0).var(self._predict_name)
|
||||
predict = self._exe.run(
|
||||
self._program,
|
||||
feed=feeder.feed(image_batch),
|
||||
fetch_list=[predict_var]
|
||||
)
|
||||
return predict
|
||||
|
||||
def num_classes(self):
|
||||
"""
|
||||
Calculate the number of classes of the output label.
|
||||
|
||||
Return:
|
||||
int: the number of classes
|
||||
"""
|
||||
predict_var = self._program.block(0).var(self._predict_name)
|
||||
assert len(predict_var.shape) == 2
|
||||
return predict_var.shape[1]
|
||||
|
||||
def gradient(self, image_batch):
|
||||
"""
|
||||
Calculate the gradient of the loss w.r.t the input.
|
||||
|
||||
Args:
|
||||
image_batch(list): The image and label tuple list.
|
||||
Return:
|
||||
list: The list of the gradient of the image.
|
||||
"""
|
||||
feeder = fluid.DataFeeder(
|
||||
feed_list=[self._input_name, self._logits_name],
|
||||
place=self._place,
|
||||
program=self._program
|
||||
)
|
||||
|
||||
grad, = self._exe.run(
|
||||
self._program,
|
||||
feed=feeder.feed(image_batch),
|
||||
fetch_list=[self._gradient])
|
||||
return grad
|
@ -0,0 +1,32 @@
|
||||
################################################################################
|
||||
#
|
||||
# Copyright (c) 2017 Baidu.com, Inc. All Rights Reserved
|
||||
#
|
||||
################################################################################
|
||||
"""
|
||||
|
||||
A pure Paddlepaddle implementation of a neural network.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
from advbox import Model
|
||||
|
||||
def main():
|
||||
"""
|
||||
example main function
|
||||
"""
|
||||
model_dir = "./mnist_model"
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
program, feed_var_names, fetch_vars = fluid.io.load_inferfence_model(model_dir, exe)
|
||||
print(program)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,91 @@
|
||||
"""
|
||||
CNN on mnist data using fluid api of paddlepaddle
|
||||
"""
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
|
||||
def mnist_cnn_model(img):
|
||||
"""
|
||||
Mnist cnn model
|
||||
|
||||
Args:
|
||||
img(Varaible): the input image to be recognized
|
||||
|
||||
Returns:
|
||||
Variable: the label prediction
|
||||
"""
|
||||
#conv1 = fluid.nets.conv2d()
|
||||
conv_pool_1 = fluid.nets.simple_img_conv_pool(
|
||||
input=img,
|
||||
num_filters=20,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
conv_pool_2 = fluid.nets.simple_img_conv_pool(
|
||||
input=conv_pool_1,
|
||||
num_filters=50,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
logits = fluid.layers.fc(
|
||||
input=conv_pool_2,
|
||||
size=10,
|
||||
act='softmax')
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Train the cnn model on mnist datasets
|
||||
"""
|
||||
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
|
||||
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
|
||||
logits = mnist_cnn_model(img)
|
||||
cost = fluid.layers.cross_entropy(input=logits, label=label)
|
||||
avg_cost = fluid.layers.mean(x=cost)
|
||||
optimizer = fluid.optimizer.Adam(learning_rate=0.01)
|
||||
optimizer.minimize(avg_cost)
|
||||
|
||||
accuracy = fluid.evaluator.Accuracy(input=logits, label=label)
|
||||
|
||||
BATCH_SIZE = 50
|
||||
PASS_NUM = 3
|
||||
ACC_THRESHOLD = 0.98
|
||||
LOSS_THRESHOLD = 10.0
|
||||
train_reader = paddle.batch(
|
||||
paddle.reader.shuffle(
|
||||
paddle.dataset.mnist.train(), buf_size=500),
|
||||
batch_size=BATCH_SIZE)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
|
||||
exe.run(fluid.default_startup_program())
|
||||
|
||||
for pass_id in range(PASS_NUM):
|
||||
accuracy.reset(exe)
|
||||
for data in train_reader():
|
||||
loss, acc = exe.run(fluid.default_main_program(),
|
||||
feed=feeder.feed(data),
|
||||
fetch_list=[avg_cost] + accuracy.metrics)
|
||||
pass_acc = accuracy.eval(exe)
|
||||
print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" +
|
||||
str(pass_acc))
|
||||
# print loss, acc
|
||||
if loss < LOSS_THRESHOLD and pass_acc > ACC_THRESHOLD:
|
||||
# if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
|
||||
break
|
||||
# exit(0)
|
||||
|
||||
pass_acc = accuracy.eval(exe)
|
||||
print("pass_id=" + str(pass_id) + " pass_acc=" + str(pass_acc))
|
||||
fluid.io.save_params(exe, dirname='./mnist', main_program=fluid.default_main_program())
|
||||
print('train mnist done')
|
||||
exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,113 @@
|
||||
"""
|
||||
This attack was originally implemented by Goodfellow et al. (2015) with the
|
||||
infinity norm (and is known as the "Fast Gradient Sign Method"). This is therefore called
|
||||
the Fast Gradient Method.
|
||||
Paper link: https://arxiv.org/abs/1412.6572
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
|
||||
BATCH_SIZE = 50
|
||||
PASS_NUM = 1
|
||||
EPS = 0.3
|
||||
CLIP_MIN = -1
|
||||
CLIP_MAX = 1
|
||||
PASS_NUM = 1
|
||||
|
||||
def mnist_cnn_model(img):
|
||||
"""
|
||||
Mnist cnn model
|
||||
|
||||
Args:
|
||||
img(Varaible): the input image to be recognized
|
||||
|
||||
Returns:
|
||||
Variable: the label prediction
|
||||
"""
|
||||
#conv1 = fluid.nets.conv2d()
|
||||
conv_pool_1 = fluid.nets.simple_img_conv_pool(
|
||||
input=img,
|
||||
num_filters=20,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
conv_pool_2 = fluid.nets.simple_img_conv_pool(
|
||||
input=conv_pool_1,
|
||||
num_filters=50,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
logits = fluid.layers.fc(
|
||||
input=conv_pool_2,
|
||||
size=10,
|
||||
act='softmax')
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Generate adverserial example and evaluate accuracy on mnist using FGSM
|
||||
"""
|
||||
|
||||
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype='float32')
|
||||
# The gradient should flow
|
||||
images.stop_gradient = False
|
||||
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
|
||||
|
||||
predict = mnist_cnn_model(images)
|
||||
cost = fluid.layers.cross_entropy(input=predict, label=label)
|
||||
avg_cost = fluid.layers.mean(x=cost)
|
||||
|
||||
# Cal gradient of input
|
||||
params_grads = fluid.backward.append_backward_ops(avg_cost, parameter_list=['pixel'])
|
||||
# data batch
|
||||
train_reader = paddle.batch(
|
||||
paddle.reader.shuffle(
|
||||
paddle.dataset.mnist.train(), buf_size=500),
|
||||
batch_size=BATCH_SIZE)
|
||||
|
||||
accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
accuracy.reset(exe)
|
||||
#exe.run(fluid.default_startup_program())
|
||||
feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
|
||||
for pass_id in range(PASS_NUM):
|
||||
fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program())
|
||||
for data in train_reader():
|
||||
# cal gradient and eval accuracy
|
||||
ps, acc = exe.run(
|
||||
fluid.default_main_program(),
|
||||
feed=feeder.feed(data),
|
||||
fetch_list=[params_grads[0][1]]+accuracy.metrics)
|
||||
labels = []
|
||||
for idx, _ in enumerate(data):
|
||||
labels.append(data[idx][1])
|
||||
# generate adversarial example
|
||||
batch_num = ps.shape[0]
|
||||
new_data = []
|
||||
for i in range(batch_num):
|
||||
adv_img = np.reshape(data[0][0], (1, 28, 28)) + EPS * np.sign(ps[i])
|
||||
adv_img = np.clip(adv_img, CLIP_MIN, CLIP_MAX)
|
||||
#adv_imgs.append(adv_img)
|
||||
t = (adv_img, data[0][1])
|
||||
new_data.append(t)
|
||||
|
||||
# predict label
|
||||
predict_label, = exe.run(
|
||||
fluid.default_main_program(),
|
||||
feed=feeder.feed(new_data),
|
||||
fetch_list=[predict])
|
||||
adv_labels = np.argmax(predict_label, axis=1)
|
||||
batch_accuracy = np.mean(np.equal(labels, adv_labels))
|
||||
print "pass_id=" + str(pass_id) + " acc=" + str(acc)+ " adv_acc=" + str(batch_accuracy)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,94 @@
|
||||
"""
|
||||
FGSM demos on mnist using advbox tool.
|
||||
"""
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from advbox.models.paddle import PaddleModel
|
||||
from advbox.attacks.gradientsign import GradientSignAttack
|
||||
|
||||
def cnn_model(img):
|
||||
"""
|
||||
Mnist cnn model
|
||||
Args:
|
||||
img(Varaible): the input image to be recognized
|
||||
Returns:
|
||||
Variable: the label prediction
|
||||
"""
|
||||
#conv1 = fluid.nets.conv2d()
|
||||
conv_pool_1 = fluid.nets.simple_img_conv_pool(
|
||||
input=img,
|
||||
num_filters=20,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
conv_pool_2 = fluid.nets.simple_img_conv_pool(
|
||||
input=conv_pool_1,
|
||||
num_filters=50,
|
||||
filter_size=5,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
act='relu')
|
||||
|
||||
logits = fluid.layers.fc(
|
||||
input=conv_pool_2,
|
||||
size=10,
|
||||
act='softmax')
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Advbox demo which demonstrate how to use advbox.
|
||||
"""
|
||||
IMG_NAME = 'img'
|
||||
LABEL_NAME = 'label'
|
||||
|
||||
img = fluid.layers.data(name=IMG_NAME, shape=[1, 28, 28], dtype='float32')
|
||||
# gradient should flow
|
||||
img.stop_gradient = False
|
||||
label = fluid.layers.data(name=LABEL_NAME, shape=[1], dtype='int64')
|
||||
logits = cnn_model(img)
|
||||
cost = fluid.layers.cross_entropy(input=logits, label=label)
|
||||
avg_cost = fluid.layers.mean(x=cost)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
|
||||
BATCH_SIZE = 1
|
||||
train_reader = paddle.batch(
|
||||
paddle.reader.shuffle(
|
||||
paddle.dataset.mnist.train(), buf_size=500),
|
||||
batch_size=BATCH_SIZE)
|
||||
feeder = fluid.DataFeeder(
|
||||
feed_list=[IMG_NAME, LABEL_NAME],
|
||||
place=place,
|
||||
program=fluid.default_main_program()
|
||||
)
|
||||
|
||||
fluid.io.load_params(exe, "./mnist/", main_program=fluid.default_main_program())
|
||||
|
||||
# advbox demo
|
||||
m = PaddleModel(
|
||||
fluid.default_main_program(),
|
||||
IMG_NAME,
|
||||
LABEL_NAME,
|
||||
logits.name,
|
||||
avg_cost.name,
|
||||
(-1, 1)
|
||||
)
|
||||
att = GradientSignAttack(m)
|
||||
for data in train_reader():
|
||||
# fgsm attack
|
||||
adv_img = att(data)
|
||||
plt.imshow(n[0][0], cmap='Greys_r')
|
||||
plt.show()
|
||||
#np.save('adv_img', adv_img)
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in new issue