[Dy2stat] Add TSM as ProgramTranslator Unit Test. (#25008)

Add TSM as ProgramTranslator Unit Test. The TSM code is referred from PaddlePaddle/models#4229
fix-sync_batch_norm-hang-in-fleet
Huihuang Zheng 5 years ago committed by GitHub
parent 770c11a117
commit 9b5b726729
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,3 +4,5 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
set_tests_properties(test_tsm PROPERTIES TIMEOUT 900)

@ -0,0 +1,43 @@
MODEL:
name: "TSM"
format: "pkl"
num_classes: 400
seg_num: 8
seglen: 1
image_mean: [0.485, 0.456, 0.406]
image_std: [0.229, 0.224, 0.225]
num_layers: 50
topk: 5
TRAIN:
epoch: 1
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 4 #128
use_gpu: True
num_gpus: 1 #8
filelist: "./data/dataset/kinetics/train.list"
learning_rate: 0.01
learning_rate_decay: 0.1
decay_epochs: [40, 60]
l2_weight_decay: 1e-4
momentum: 0.9
total_videos: 8000 #239781
VALID:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 32 #128
filelist: "./data/dataset/kinetics/val.list"
TEST:
short_size: 256
target_size: 224
num_reader_threads: 12
buf_size: 1024
batch_size: 64
filelist: "./data/dataset/kinetics/test.list"

@ -0,0 +1,85 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import yaml
import logging
logger = logging.getLogger(__name__)
CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
import yaml
with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
create_attr_dict(yaml_config)
return yaml_config
def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
return
def merge_configs(cfg, sec, args_dict):
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
sec_dict = getattr(cfg, sec.upper())
for k, v in args_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except:
pass
return cfg
def print_configs(cfg, mode):
logger.info("---------------- {:>5} Arguments ----------------".format(
mode))
for sec, sec_items in cfg.items():
logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")
Loading…
Cancel
Save