[Dy2stat] Add TSM as ProgramTranslator Unit Test. (#25008)
Add TSM as ProgramTranslator Unit Test. The TSM code is referred from PaddlePaddle/models#4229fix-sync_batch_norm-hang-in-fleet
parent
770c11a117
commit
9b5b726729
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,43 @@
|
|||||||
|
MODEL:
|
||||||
|
name: "TSM"
|
||||||
|
format: "pkl"
|
||||||
|
num_classes: 400
|
||||||
|
seg_num: 8
|
||||||
|
seglen: 1
|
||||||
|
image_mean: [0.485, 0.456, 0.406]
|
||||||
|
image_std: [0.229, 0.224, 0.225]
|
||||||
|
num_layers: 50
|
||||||
|
topk: 5
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
epoch: 1
|
||||||
|
short_size: 256
|
||||||
|
target_size: 224
|
||||||
|
num_reader_threads: 12
|
||||||
|
buf_size: 1024
|
||||||
|
batch_size: 4 #128
|
||||||
|
use_gpu: True
|
||||||
|
num_gpus: 1 #8
|
||||||
|
filelist: "./data/dataset/kinetics/train.list"
|
||||||
|
learning_rate: 0.01
|
||||||
|
learning_rate_decay: 0.1
|
||||||
|
decay_epochs: [40, 60]
|
||||||
|
l2_weight_decay: 1e-4
|
||||||
|
momentum: 0.9
|
||||||
|
total_videos: 8000 #239781
|
||||||
|
|
||||||
|
VALID:
|
||||||
|
short_size: 256
|
||||||
|
target_size: 224
|
||||||
|
num_reader_threads: 12
|
||||||
|
buf_size: 1024
|
||||||
|
batch_size: 32 #128
|
||||||
|
filelist: "./data/dataset/kinetics/val.list"
|
||||||
|
|
||||||
|
TEST:
|
||||||
|
short_size: 256
|
||||||
|
target_size: 224
|
||||||
|
num_reader_threads: 12
|
||||||
|
buf_size: 1024
|
||||||
|
batch_size: 64
|
||||||
|
filelist: "./data/dataset/kinetics/test.list"
|
@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
#you may not use this file except in compliance with the License.
|
||||||
|
#You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
#Unless required by applicable law or agreed to in writing, software
|
||||||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
#See the License for the specific language governing permissions and
|
||||||
|
#limitations under the License.
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CONFIG_SECS = [
|
||||||
|
'train',
|
||||||
|
'valid',
|
||||||
|
'test',
|
||||||
|
'infer',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AttrDict(dict):
|
||||||
|
def __getattr__(self, key):
|
||||||
|
return self[key]
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if key in self.__dict__:
|
||||||
|
self.__dict__[key] = value
|
||||||
|
else:
|
||||||
|
self[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def parse_config(cfg_file):
|
||||||
|
"""Load a config file into AttrDict"""
|
||||||
|
import yaml
|
||||||
|
with open(cfg_file, 'r') as fopen:
|
||||||
|
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
|
||||||
|
create_attr_dict(yaml_config)
|
||||||
|
return yaml_config
|
||||||
|
|
||||||
|
|
||||||
|
def create_attr_dict(yaml_config):
|
||||||
|
from ast import literal_eval
|
||||||
|
for key, value in yaml_config.items():
|
||||||
|
if type(value) is dict:
|
||||||
|
yaml_config[key] = value = AttrDict(value)
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
value = literal_eval(value)
|
||||||
|
except BaseException:
|
||||||
|
pass
|
||||||
|
if isinstance(value, AttrDict):
|
||||||
|
create_attr_dict(yaml_config[key])
|
||||||
|
else:
|
||||||
|
yaml_config[key] = value
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def merge_configs(cfg, sec, args_dict):
|
||||||
|
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
|
||||||
|
sec_dict = getattr(cfg, sec.upper())
|
||||||
|
for k, v in args_dict.items():
|
||||||
|
if v is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if hasattr(sec_dict, k):
|
||||||
|
setattr(sec_dict, k, v)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def print_configs(cfg, mode):
|
||||||
|
logger.info("---------------- {:>5} Arguments ----------------".format(
|
||||||
|
mode))
|
||||||
|
for sec, sec_items in cfg.items():
|
||||||
|
logger.info("{}:".format(sec))
|
||||||
|
for k, v in sec_items.items():
|
||||||
|
logger.info(" {}:{}".format(k, v))
|
||||||
|
logger.info("-------------------------------------------------")
|
Loading…
Reference in new issue