parent
ac5371b38f
commit
50d7062fed
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
eval.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from src.pet_dataset import create_dataset
|
||||||
|
from src.config import config_ascend, config_gpu
|
||||||
|
from src.tnt import tnt_b
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||||
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||||
|
parser.add_argument('--platform', type=str, default=None, help='run platform')
|
||||||
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config_platform = None
|
||||||
|
if args_opt.platform == "Ascend":
|
||||||
|
config_platform = config_ascend
|
||||||
|
device_id = int(os.getenv('DEVICE_ID'))
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
|
||||||
|
device_id=device_id, save_graphs=False)
|
||||||
|
elif args_opt.platform == "GPU":
|
||||||
|
config_platform = config_gpu
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE,
|
||||||
|
device_target="GPU", save_graphs=False)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported platform.")
|
||||||
|
|
||||||
|
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||||
|
|
||||||
|
net = tnt_b(num_class=config_platform.num_classes)
|
||||||
|
|
||||||
|
if args_opt.checkpoint_path:
|
||||||
|
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||||
|
load_param_into_net(net, param_dict)
|
||||||
|
net.set_train(False)
|
||||||
|
|
||||||
|
if args_opt.platform == "Ascend":
|
||||||
|
net.to_float(mstype.float16)
|
||||||
|
for _, cell in net.cells_and_names():
|
||||||
|
if isinstance(cell, nn.Dense):
|
||||||
|
cell.to_float(mstype.float32)
|
||||||
|
|
||||||
|
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||||
|
do_train=False,
|
||||||
|
config=config_platform,
|
||||||
|
platform=args_opt.platform,
|
||||||
|
batch_size=config_platform.batch_size)
|
||||||
|
step_size = dataset.get_dataset_size()
|
||||||
|
|
||||||
|
model = Model(net, loss_fn=loss, metrics={'acc'})
|
||||||
|
res = model.eval(dataset)
|
||||||
|
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
After Width: | Height: | Size: 147 KiB |
@ -0,0 +1,22 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""hub config."""
|
||||||
|
from src.tnt import tnt_b
|
||||||
|
|
||||||
|
|
||||||
|
def create_network(name, *args, **kwargs):
|
||||||
|
if name == 'TNT-B':
|
||||||
|
return tnt_b(*args, **kwargs)
|
||||||
|
raise NotImplementedError(f"{name} is not implemented in the repo")
|
@ -0,0 +1,128 @@
|
|||||||
|
# Contents
|
||||||
|
|
||||||
|
- [TNT Description](#tnt-description)
|
||||||
|
- [Model Architecture](#model-architecture)
|
||||||
|
- [Dataset](#dataset)
|
||||||
|
- [Environment Requirements](#environment-requirements)
|
||||||
|
- [Script Description](#script-description)
|
||||||
|
- [Script and Sample Code](#script-and-sample-code)
|
||||||
|
- [Training Process](#training-process)
|
||||||
|
- [Evaluation Process](#evaluation-process)
|
||||||
|
- [Evaluation](#evaluation)
|
||||||
|
- [Model Description](#model-description)
|
||||||
|
- [Performance](#performance)
|
||||||
|
- [Training Performance](#evaluation-performance)
|
||||||
|
- [Inference Performance](#evaluation-performance)
|
||||||
|
- [Description of Random Situation](#description-of-random-situation)
|
||||||
|
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||||
|
|
||||||
|
## [TNT Description](#contents)
|
||||||
|
|
||||||
|
The TNT (Transformer in Transformer) network is a pure transformer model for visual recognition. TNT treats an image as a sequence of patches and treats a patch as a sequence of pixels. TNT block utilizes a outer transformer block to process the sequence of patches and an inner transformer block to process the sequence of pixels.
|
||||||
|
|
||||||
|
[Paper](https://arxiv.org/abs/2103.00112): Kai Han, An Xiao, Enhua Wu, Jianyuan Guo, Chunjing Xu, Yunhe Wang. Transformer in Transformer. preprint 2021.
|
||||||
|
|
||||||
|
## [Model architecture](#contents)
|
||||||
|
|
||||||
|
The overall network architecture of TNT is show below:
|
||||||
|
![](./fig/tnt.PNG)
|
||||||
|
|
||||||
|
## [Dataset](#contents)
|
||||||
|
|
||||||
|
Dataset used: [Oxford-IIIT Pet](https://www.robots.ox.ac.uk/~vgg/data/pets/)
|
||||||
|
|
||||||
|
- Dataset size: 7049 colorful images in 1000 classes
|
||||||
|
- Train: 3680 images
|
||||||
|
- Test: 3369 images
|
||||||
|
- Data format: RGB images.
|
||||||
|
- Note: Data will be processed in src/dataset.py
|
||||||
|
|
||||||
|
## [Environment Requirements](#contents)
|
||||||
|
|
||||||
|
- Hardware(Ascend/GPU)
|
||||||
|
- Prepare hardware environment with Ascend or GPU. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||||
|
- Framework
|
||||||
|
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||||
|
- For more information, please check the resources below£º
|
||||||
|
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||||
|
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||||
|
|
||||||
|
## [Script description](#contents)
|
||||||
|
|
||||||
|
### [Script and sample code](#contents)
|
||||||
|
|
||||||
|
```python
|
||||||
|
TNT
|
||||||
|
├── eval.py # inference entry
|
||||||
|
├── fig
|
||||||
|
│ └── tnt.png # the illustration of TNT network
|
||||||
|
├── readme.md # Readme
|
||||||
|
└── src
|
||||||
|
├── config.py # config of model and data
|
||||||
|
├── pet_dataset.py # dataset loader
|
||||||
|
└── tnt.py # TNT network
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Training process](#contents)
|
||||||
|
|
||||||
|
To Be Done
|
||||||
|
|
||||||
|
## [Eval process](#contents)
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
After installing MindSpore via the official website, you can start evaluation as follows:
|
||||||
|
|
||||||
|
### Launch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# infer example
|
||||||
|
GPU: python eval.py --model tnt-b --dataset_path ~/Pets/test.mindrecord --platform GPU --checkpoint_path [CHECKPOINT_PATH]
|
||||||
|
```
|
||||||
|
|
||||||
|
> checkpoint can be downloaded at https://www.mindspore.cn/resources/hub.
|
||||||
|
|
||||||
|
### Result
|
||||||
|
|
||||||
|
```bash
|
||||||
|
result: {'acc': 0.95} ckpt= ./tnt-b-pets.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
## [Model Description](#contents)
|
||||||
|
|
||||||
|
### [Performance](#contents)
|
||||||
|
|
||||||
|
#### Evaluation Performance
|
||||||
|
|
||||||
|
##### TNT on ImageNet2012
|
||||||
|
|
||||||
|
| Parameters | | |
|
||||||
|
| -------------------------- | -------------------------------------- |---------------------------------- |
|
||||||
|
| Model Version | TNT-B |TNT-S|
|
||||||
|
| uploaded Date | 21/03/2021 (month/day/year) | 21/03/2021 (month/day/year) |
|
||||||
|
| MindSpore Version | 1.1 | 1.1 |
|
||||||
|
| Dataset | ImageNet2012 | ImageNet2012|
|
||||||
|
| Input size | 224x224 | 224x224|
|
||||||
|
| Parameters (M) | 86.4 | 23.8 |
|
||||||
|
| FLOPs (M) | 14.1 | 5.2 |
|
||||||
|
| Accuracy (Top1) | 82.8 | 81.3 |
|
||||||
|
|
||||||
|
###### TNT on Oxford-IIIT Pet
|
||||||
|
|
||||||
|
| Parameters | | |
|
||||||
|
| -------------------------- | -------------------------------------- |---------------------------------- |
|
||||||
|
| Model Version | TNT-B |TNT-S|
|
||||||
|
| uploaded Date | 21/03/2021 (month/day/year) | 21/03/2021 (month/day/year) |
|
||||||
|
| MindSpore Version | 1.1 | 1.1 |
|
||||||
|
| Dataset | Oxford-IIIT Pet | Oxford-IIIT Pet|
|
||||||
|
| Input size | 384x384 | 384x384|
|
||||||
|
| Parameters (M) | 86.4 | 23.8 |
|
||||||
|
| Accuracy (Top1) | 95.0 | 94.7 |
|
||||||
|
|
||||||
|
## [Description of Random Situation](#contents)
|
||||||
|
|
||||||
|
In dataset.py, we set the seed inside "create_dataset" function. We also use random seed in train.py.
|
||||||
|
|
||||||
|
## [ModelZoo Homepage](#contents)
|
||||||
|
|
||||||
|
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
@ -0,0 +1,54 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
network config setting, will be used in train.py and eval.py
|
||||||
|
"""
|
||||||
|
from easydict import EasyDict as ed
|
||||||
|
|
||||||
|
config_ascend = ed({
|
||||||
|
"num_classes": 37,
|
||||||
|
"image_height": 384,
|
||||||
|
"image_width": 384,
|
||||||
|
"batch_size": 50,
|
||||||
|
"epoch_size": 300,
|
||||||
|
"warmup_epochs": 5,
|
||||||
|
"lr": 1e-3,
|
||||||
|
"momentum": 0.9,
|
||||||
|
"weight_decay": 0.05,
|
||||||
|
"label_smooth": 0.1,
|
||||||
|
"loss_scale": 1024,
|
||||||
|
"save_checkpoint": True,
|
||||||
|
"save_checkpoint_epochs": 1,
|
||||||
|
"keep_checkpoint_max": 200,
|
||||||
|
"save_checkpoint_path": "./checkpoint",
|
||||||
|
})
|
||||||
|
|
||||||
|
config_gpu = ed({
|
||||||
|
"num_classes": 37,
|
||||||
|
"image_height": 384,
|
||||||
|
"image_width": 384,
|
||||||
|
"batch_size": 50,
|
||||||
|
"epoch_size": 300,
|
||||||
|
"warmup_epochs": 5,
|
||||||
|
"lr": 1e-3,
|
||||||
|
"momentum": 0.9,
|
||||||
|
"weight_decay": 0.05,
|
||||||
|
"label_smooth": 0.1,
|
||||||
|
"loss_scale": 1024,
|
||||||
|
"save_checkpoint": True,
|
||||||
|
"save_checkpoint_epochs": 1,
|
||||||
|
"keep_checkpoint_max": 500,
|
||||||
|
"save_checkpoint_path": "./checkpoint",
|
||||||
|
})
|
@ -0,0 +1,97 @@
|
|||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
create train or eval dataset.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset.engine as de
|
||||||
|
import mindspore.dataset.transforms.py_transforms as py_transforms
|
||||||
|
import mindspore.dataset.transforms.c_transforms as c_transforms
|
||||||
|
import mindspore.dataset.vision.py_transforms as py_vision
|
||||||
|
from mindspore.dataset.vision import Inter
|
||||||
|
|
||||||
|
def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=1):
|
||||||
|
"""
|
||||||
|
create a train or eval dataset
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_path(string): the path of dataset.
|
||||||
|
do_train(bool): whether dataset is used for train or eval.
|
||||||
|
repeat_num(int): the repeat times of dataset. Default: 1
|
||||||
|
batch_size(int): the batch size of dataset. Default: 32
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dataset
|
||||||
|
"""
|
||||||
|
if platform == "Ascend":
|
||||||
|
rank_size = int(os.getenv("RANK_SIZE"))
|
||||||
|
rank_id = int(os.getenv("RANK_ID"))
|
||||||
|
if rank_size == 1:
|
||||||
|
ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||||
|
else:
|
||||||
|
ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||||
|
num_shards=rank_size, shard_id=rank_id)
|
||||||
|
elif platform == "GPU":
|
||||||
|
if do_train:
|
||||||
|
from mindspore.communication.management import get_rank, get_group_size
|
||||||
|
ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||||
|
num_shards=get_group_size(), shard_id=get_rank())
|
||||||
|
else:
|
||||||
|
ds = de.MindDataset(dataset_path, num_parallel_workers=8, shuffle=False)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported platform.")
|
||||||
|
|
||||||
|
resize_height = config.image_height
|
||||||
|
resize_width = config.image_width
|
||||||
|
buffer_size = 1000
|
||||||
|
|
||||||
|
# define map operations
|
||||||
|
random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(resize_height, resize_width),
|
||||||
|
scale=(0.08, 1.0), ratio=(3./4., 4./3.),
|
||||||
|
interpolation=Inter.BICUBIC)
|
||||||
|
random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
|
||||||
|
color_jitter = 0.4
|
||||||
|
adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
|
||||||
|
random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
|
||||||
|
contrast=adjust_range,
|
||||||
|
saturation=adjust_range)
|
||||||
|
|
||||||
|
decode_p = py_vision.Decode()
|
||||||
|
resize_p = py_vision.Resize(int(resize_height), interpolation=Inter.BICUBIC)
|
||||||
|
center_crop_p = py_vision.CenterCrop(resize_height)
|
||||||
|
totensor = py_vision.ToTensor()
|
||||||
|
normalize_p = py_vision.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||||
|
|
||||||
|
if do_train:
|
||||||
|
trans = py_transforms.Compose([decode_p, random_resize_crop_bicubic, random_horizontal_flip_op,
|
||||||
|
random_color_jitter_op, totensor, normalize_p])
|
||||||
|
else:
|
||||||
|
trans = py_transforms.Compose([decode_p, resize_p, center_crop_p, totensor, normalize_p])
|
||||||
|
|
||||||
|
type_cast_op = c_transforms.TypeCast(mstype.int32)
|
||||||
|
|
||||||
|
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
|
||||||
|
ds = ds.map(input_columns="label_list", operations=type_cast_op, num_parallel_workers=8)
|
||||||
|
|
||||||
|
# apply shuffle operations
|
||||||
|
ds = ds.shuffle(buffer_size=buffer_size)
|
||||||
|
|
||||||
|
# apply batch operations
|
||||||
|
ds = ds.batch(batch_size, drop_remainder=True)
|
||||||
|
|
||||||
|
# apply dataset repeat operation
|
||||||
|
ds = ds.repeat(repeat_num)
|
||||||
|
return ds
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue