这篇文章主要介绍PyTorch 6中batch_train批训练操作的示例分析,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!看代码吧~import torchimport torch.utils.
这篇文章主要介绍PyTorch 6中batch_train批训练操作的示例分析,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!
import torchimport torch.utils.data as Datatorch.manual_seed(1) # reproducible# BATCH_SIZE = 5 BATCH_SIZE = 8 # 每次使用8个数据同时传入网路x = torch.linspace(1, 10, 10) # this is x data (torch tensor)y = torch.linspace(10, 1, 10) # this is y data (torch tensor)torch_dataset = Data.TensorDataset(x, y)loader = Data.DataLoader( dataset=torch_dataset, # torch TensorDataset fORMat batch_size=BATCH_SIZE, # mini batch size shuffle=False, # 设置不随机打乱数据 random shuffle for training num_workers=2, # 使用两个进程提取数据,subprocesses for loading data)def show_batch(): for epoch in range(3): # 全部的数据使用3遍,train entire dataset 3 times for step, (batch_x, batch_y) in enumerate(loader): # for each training step # train your data... print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())if __name__ == '__main__': show_batch()
Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]Epoch: 0 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]Epoch: 1 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.]Epoch: 2 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
补充:pytorch批训练bug
在进行pytorch神经网络批训练的时候,有时会出现报错
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>
检查(重点!!!!!):
train_dataset = Data.TensorDataset(train_x, train_y)
train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable
可以这样将数据变为tensor类:
train_x = torch.FloatTensor(train_x)
train_loader = Data.DataLoader( dataset=train_dataset, batch_size=batch_size, shuffle=True )
实例化一个DataLoader对象
for epoch in range(epochs): for step, (batch_x, batch_y) in enumerate(train_loader): batch_x, batch_y = Variable(batch_x), Variable(batch_y)
这样就可以批训练了
需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable
以上是“pytorch 6中batch_train批训练操作的示例分析”这篇文章的所有内容,感谢各位的阅读!希望分享的内容对大家有帮助,更多相关知识,欢迎关注编程网精选频道!
--结束END--
本文标题: pytorch 6中batch_train批训练操作的示例分析
本文链接: https://www.lsjlt.com/news/278488.html(转载时请注明来源链接)
有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341
下载Word文档到电脑,方便收藏和打印~
2024-05-15
2024-05-15
2024-05-15
2024-05-15
2024-05-15
2024-05-15
2024-05-15
2024-05-15
2024-05-15
2024-05-15
回答
回答
回答
回答
回答
回答
回答
回答
回答
回答
0