-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
115 lines (92 loc) · 3.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Utility functions for YAML/CSV I/O, project file copying, logging, checkpoint saving and seeding.
"""
import os
import shutil
import glob
import yaml
import csv
import logging
import random
import numpy as np
import torch
# path separator of the current platform, used below when building and splitting paths
sep = os.sep
def load_yaml(file_path='./config.yaml'):
    """
    read a YAML file and return its parsed content
    :param file_path: path of the YAML file, defaults to ./config.yaml
    :return: the object produced by yaml.safe_load for the file
    """
    with open(file_path) as stream:
        return yaml.safe_load(stream)
def save_yaml_file(file_path, data: dict):
    """
    dump a dict to a YAML file, overwriting any existing file
    :param file_path: destination path of the YAML file
    :param data: content handed to yaml.safe_dump
    """
    dump_options = {'encoding': 'utf-8', 'allow_unicode': True}
    with open(file_path, "w") as stream:
        yaml.safe_dump(data, stream, **dump_options)
def save_load_version_files(path, file_patterns, pass_dirs=None):
    """
    snapshot the files of the current working directory into two places:
    'runs/latest_project' and the 'project' folder below path
    :param path: directory that receives the 'project' snapshot folder
    :param file_patterns: glob patterns of the files to copy, forwarded to copy_files
    :param pass_dirs: directories to skip, forwarded to copy_files;
        defaults to ['.', '_', 'runs', 'results']
    """
    skipped_dirs = ['.', '_', 'runs', 'results'] if pass_dirs is None else pass_dirs
    source_root = f'.{sep}'
    # save latest version files
    copy_files(source_root, 'runs/latest_project', file_patterns, skipped_dirs)
    snapshot_dir = os.path.join(path, 'project')
    copy_files(source_root, snapshot_dir, file_patterns, skipped_dirs)
def save_csv(file_path, data: list):
    """
    write rows to a CSV file, overwriting any existing file
    :param file_path: destination path of the CSV file
    :param data: list of rows, each row being an iterable of cell values
    """
    with open(file_path, 'w', newline='') as handle:
        csv.writer(handle, lineterminator='\n').writerows(data)
def copy_files(root_dir, target_dir, file_patterns, pass_dirs=('.git',)):
    """
    copy the files matching any pattern from root_dir to target_dir,
    keeping the sub-directory layout of root_dir
    :param root_dir: directory tree to copy from, relative or absolute
    :param target_dir: destination directory, created if missing
    :param file_patterns: glob patterns applied inside every visited directory, like ['*.py']
    :param pass_dirs: names of sub-directories directly below root_dir to skip (with
        everything inside them); a sub-directory is also skipped when its first
        character is listed, so '.' skips every hidden directory
    """
    os.makedirs(target_dir, exist_ok=True)
    for root, dirs, _ in os.walk(root_dir):
        # relpath yields the path below root_dir for relative and absolute roots alike
        rel_dir = os.path.relpath(root, root_dir)
        if rel_dir == os.curdir:
            rel_dir = ''
            # prune excluded top-level directories in place so os.walk never descends into them
            dirs[:] = [name for name in dirs
                       if name not in pass_dirs and name[0] not in pass_dirs]
        target_path = os.path.join(target_dir, rel_dir)
        os.makedirs(target_path, exist_ok=True)
        # escape the directory part so characters like '[' in a folder name are not wildcards
        escaped_root = glob.escape(root)
        for file_pattern in file_patterns:
            for file in sorted(glob.glob(os.path.join(escaped_root, file_pattern))):
                # a pattern such as '*' also matches directories, which copyfile cannot copy
                if os.path.isfile(file):
                    shutil.copyfile(file, os.path.join(target_path, os.path.basename(file)))
def save_model_state_dict(file_path, epoch=None, net=None, optimizer=None):
    """
    save a training checkpoint with torch.save
    :param file_path: destination path of the checkpoint
    :param epoch: value stored under 'epoch' as given
    :param net: object providing state_dict(), stored under 'model'; None stores None
    :param optimizer: object providing state_dict(), stored under 'optimizer'; None stores None
    """
    # compare against None explicitly: container modules such as an empty
    # torch.nn.Sequential define __len__ and are falsy, yet still have a state_dict
    state_dict = {
        'epoch': epoch,
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'model': net.state_dict() if net is not None else None,
    }
    torch.save(state_dict, file_path)
def get_logger(filename):
    """
    configure logging to a file and return this module's logger, which also echoes to the console
    :param filename: log file handed to logging.basicConfig
    :return: the logger named after this module
    """
    # basicConfig is a no-op when the root logger already has handlers
    logging.basicConfig(filename=filename, level=logging.INFO)
    logger = logging.getLogger(__name__)
    # getLogger returns the same object on every call, so attach the console handler
    # only once; otherwise every record is printed once per get_logger call.
    # The exact type is compared because FileHandler subclasses StreamHandler.
    if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
        logger.addHandler(logging.StreamHandler())
    return logger
def get_filename_list(dir_path, pattern='*', ext='*'):
    """
    find all files with the given extension under a directory, recursively
    :param dir_path: directory path
    :param pattern: filename pattern (without extension) for searching
    :param ext: extension name, like wav, png...
    :return: file path list, sorted within each directory
    """
    filename_list = []
    for root, _, _ in os.walk(dir_path):
        # escape the directory part so characters like '[' in a folder name are
        # matched literally; only pattern and ext act as wildcards
        file_path_pattern = os.path.join(glob.escape(root), f'{pattern}.{ext}')
        filename_list.extend(sorted(glob.glob(file_path_pattern)))
    return filename_list
def set_type(value):
    """
    convert the strings 'true' / 'false' (any letter case) to booleans
    :param value: string to convert
    :return: True or False for a boolean word, otherwise value unchanged
    """
    boolean_words = {'true': True, 'false': False}
    return boolean_words.get(value.lower(), value)
def setup_seed(seed):
    """
    seed the torch, numpy and python random generators with one value
    and switch cuDNN to deterministic mode
    :param seed: seed handed to every generator
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
    # manual check when run as a script: list the .py files below ../Fastorch
    python_sources = get_filename_list('../Fastorch', ext='py')
    print(python_sources)