Python 官方文档:入门教程 => 点击学习
目录前言使用步骤:常见规则使用config文件传入超参数argparse中action的可选参数store_true前言 argparse是深度学习项目调参时常用的python标准库
argparse是深度学习项目调参时常用的python标准库,使用argparse后,我们在命令行输入的参数就可以以这种形式Python filename.py --lr 1e-4 --batch_size 32来完成对常见超参数的设置。,一般使用时可以归纳为以下三个步骤
import argparse
parser = argparse.ArgumentParser() # 创建一个解析对象
parser.add_argument() # 向该对象中添加你要关注的命令行参数和选项
args = parser.parse_args() # 调用parse_args()方法进行解析
为了使代码更加简洁和模块化,可以将有关超参数的操作写在config.py,然后在train.py或者其他文件导入就可以。具体的config.py可以参考如下内容。
import argparse
def get_options(parser=argparse.ArgumentParser()):
parser.add_argument('--workers', type=int, default=0,
help='number of data loading workers, you had better put it '
'4 times of your gpu')
parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=64')
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')
parser.add_argument('--seed', type=int, default=118, help="random seed")
parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
parser.add_argument('--checkpoint_path',type=str,default='',
help='Path to load a previous trained model if not empty (default empty)')
parser.add_argument('--output',action='store_true',default=True,help="shows output")
opt = parser.parse_args()
if opt.output:
print(f'num_workers: {opt.workers}')
print(f'batch_size: {opt.batch_size}')
print(f'epochs (niters) : {opt.niter}')
print(f'learning rate : {opt.lr}')
print(f'manual_seed: {opt.seed}')
print(f'cuda enable: {opt.cuda}')
print(f'checkpoint_path: {opt.checkpoint_path}')
return opt
if __name__ == '__main__':
opt = get_options()
$ python config.py
num_workers: 0
batch_size: 4
epochs (niters) : 10
learning rate : 3e-05
manual_seed: 118
cuda enable: True
checkpoint_path:
随后在train.py等其他文件,我们就可以使用下面的这样的结构来调用参数。
# 导入必要库
...
import config
opt = config.get_options()
manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niters
checkpoint_path = opt.checkpoint_path
# 随机数的设置,保证复现结果
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
...
if __name__ == '__main__':
set_seed(manual_seed)
for epoch in range(niters):
train(model,lr,batch_size,num_workers,checkpoint_path)
val(model,lr,batch_size,num_workers,checkpoint_path)
# test.py
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--test_action", action='store_true')
args = parser.parse_args()
action_val = args.test_action
print(action_val)
以上面的代码为例,若触发 test_action,则为 True, 否则为 False:
若在上面的代码中加入default,设为 False 时:
parser.add_argument("--test_action", default='False', action='store_true')
default 设为 True 时:
parser.add_argument("--test_action", default='True', action='store_true')
参考:https://www.jb51.net/article/250215.htm
以上就是python深度学习标准库使用argparse调参的详细内容,更多关于python标准库argparse调参的资料请关注编程网其它相关文章!
--结束END--
本文标题: python深度学习标准库使用argparse调参
本文链接: https://www.lsjlt.com/news/118482.html(转载时请注明来源链接)
有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341
下载Word文档到电脑,方便收藏和打印~
2024-03-01
2024-03-01
2024-03-01
2024-02-29
2024-02-29
2024-02-29
2024-02-29
2024-02-29
2024-02-29
2024-02-29
回答
回答
回答
回答
回答
回答
回答
回答
回答
回答
0