pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)
(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)
优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)
备注:
1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"
torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.
模态字典(state_dict)的保存(model是一个网络结构类的对象)
1.1)仅保存学习到的参数,用以下命令
torch.save(model.state_dict(), PATH)
1.2)加载model.state_dict,用以下命令
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名
2.1)保存整个model的状态,用以下命令
torch.save(model,PATH)
2.2)加载整个model的状态,用以下命令:
# Model class must be defined somewhere model = torch.load(PATH) model.eval()
state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项
如何仅加载某一层的训练的到的参数(某一层的state)
If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)
for param in list(model.pretrained.parameters()): param.requires_grad = False
注意: requires_grad的操作对象是tensor.
疑问:能否直接对某个层直接之用requires_grad呢"htmlcode">
#-*-coding:utf-8-*- import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim # define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass,self).__init__() self.conv1 = nn.Conv2d(3,6,5) self.pool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(6,16,5) self.fc1 = nn.Linear(16*5*5,120) self.fc2 = nn.Linear(120,84) self.fc3 = nn.Linear(84,10) def forward(self,x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1,16*5*5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # initial model model = TheModelClass() #initialize the optimizer optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9) # print the model's state_dict print("model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor,'\t',model.state_dict()[param_tensor].size()) print("\noptimizer's state_dict") for var_name in optimizer.state_dict(): print(var_name,'\t',optimizer.state_dict()[var_name]) print("\nprint particular param") print('\n',model.conv1.weight.size()) print('\n',model.conv1.weight) print("------------------------------------") torch.save(model.state_dict(),'./model_state_dict.pt') # model_2 = TheModelClass() # model_2.load_state_dict(torch.load('./model_state_dict')) # model.eval() # print('\n',model_2.conv1.weight) # print((model_2.conv1.weight == model.conv1.weight).size()) ## 仅仅加载某一层的参数 conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight'] print(conv1_weight_state==model.conv1.weight) model_2 = TheModelClass() model_2.load_state_dict(torch.load('./model_state_dict.pt')) model_2.conv1.requires_grad=False print(model_2.conv1.requires_grad) print(model_2.conv1.bias.requires_grad)
以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
稳了!魔兽国服回归的3条重磅消息!官宣时间再确认!
昨天有一位朋友在大神群里分享,自己亚服账号被封号之后居然弹出了国服的封号信息对话框。
这里面让他访问的是一个国服的战网网址,com.cn和后面的zh都非常明白地表明这就是国服战网。
而他在复制这个网址并且进行登录之后,确实是网易的网址,也就是我们熟悉的停服之后国服发布的暴雪游戏产品运营到期开放退款的说明。这是一件比较奇怪的事情,因为以前都没有出现这样的情况,现在突然提示跳转到国服战网的网址,是不是说明了简体中文客户端已经开始进行更新了呢?
更新日志
- 好薇2024《兵哥哥》1:124K黄金母盘[WAV+CUE]
- 胡歌.2006-珍惜(EP)【步升大风】【FLAC分轨】
- 洪荣宏.2014-拼乎自己看【华特】【WAV+CUE】
- 伊能静.1999-从脆弱到勇敢1987-1996精选2CD【华纳】【WAV+CUE】
- 刘亮鹭《汽车DJ玩主》[WAV+CUE][1.1G]
- 张杰《最接近天堂的地方》天娱传媒[WAV+CUE][1.1G]
- 群星《2022年度发烧天碟》无损黑胶碟 2CD[WAV+CUE][1.4G]
- 罗文1983-罗文甄妮-射雕英雄传(纯银AMCD)[WAV+CUE]
- 群星《亚洲故事香港纯弦》雨果UPMAGCD2024[低速原抓WAV+CUE]
- 群星《经典咏流传》限量1:1母盘直刻[低速原抓WAV+CUE]
- 庾澄庆1993《老实情歌》福茂唱片[WAV+CUE][1G]
- 许巍《在别处》美卡首版[WAV+CUE][1G]
- 林子祥《单手拍掌》华纳香港版[WAV+CUE][1G]
- 郑秀文.1997-我们的主题曲【华纳】【WAV+CUE】
- 群星.2001-生命因爱动听电影原创音乐AVCD【MEDIA】【WAV+CUE】