1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
# Handwritten-digit (MNIST) recognition with PyTorch.
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
# `transforms` was previously imported from mpmath by mistake (an unrelated
# package); the torchvision module of that name is the intended one.
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

BATCH_SIZE = 128


# 1. Prepare the dataset
def get_dataloader(train=True):
    """Build a shuffled DataLoader over the MNIST dataset.

    :param train: True for the training split, False for the test split.
    :return: a DataLoader yielding (image_batch, label_batch) pairs of at most
             BATCH_SIZE samples, with images converted by ToTensor and normalised.
    """
    transform_fn = Compose([
        ToTensor(),
        # mean and std need one entry per image channel; MNIST images have a
        # single channel, so both are 1-tuples. The values are presumably the
        # commonly quoted MNIST pixel mean/std.
        Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    # download=True fetches the data on first use instead of failing when
    # ./data does not exist yet; it is a no-op once the files are present.
    dataset = MNIST(root="./data", train=train, transform=transform_fn, download=True)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    return data_loader
# 2. Build the model
class MistMOdel(nn.Module):
    """Two-layer fully connected classifier for 1x28x28 digit images (10 classes)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, 28)  # flattened pixels -> 28 hidden units
        self.fc2 = nn.Linear(28, 10)           # hidden units -> one score per class

    def forward(self, input):
        """
        :param input: [batch_size, 1, 28, 28]
        :return: [batch_size, 10] log-probabilities (log_softmax over the last
                 dim), which is what F.nll_loss expects.
        """
        x = input.reshape(input.size(0), 1 * 28 * 28)  # flatten each image
        x = self.fc1(x)   # fully connected layer
        x = F.relu(x)     # activation; shape unchanged
        out = self.fc2(x)  # output layer (raw logits)
        # Convert logits to log-probabilities: the training/test loops use
        # F.nll_loss, which is only correct on log-softmax outputs.
        return F.log_softmax(out, dim=-1)
# Module-level model and optimizer, shared by train() and test().
model = MistMOdel()
optimizer = Adam(model.parameters(), lr=0.001)

# Model and optimizer need separate checkpoint files: saving both state dicts
# to one path lets the optimizer state overwrite the model weights.
MODEL_PATH = "./model.pkl"
OPTIMIZER_PATH = "./optimizer.pkl"

# Resume from a previous run only when both checkpoints actually exist
# (checking "./" is always true and would crash on a fresh start).
if os.path.exists(MODEL_PATH) and os.path.exists(OPTIMIZER_PATH):
    model.load_state_dict(torch.load(MODEL_PATH))
    optimizer.load_state_dict(torch.load(OPTIMIZER_PATH))


def train(epoch):
    """Run one training epoch over the MNIST training split.

    :param epoch: epoch index; only used in the progress printout.

    Every 100 batches the current loss is printed and the model and optimizer
    state dicts are saved to MODEL_PATH / OPTIMIZER_PATH.
    """
    model.train()
    data_loader = get_dataloader()
    for idx, (images, targets) in enumerate(data_loader):
        # Clear gradients from the previous step; otherwise they accumulate.
        optimizer.zero_grad()
        output = model(images)  # forward pass: predictions
        loss = F.nll_loss(output, targets)
        loss.backward()   # backpropagation
        optimizer.step()  # gradient update
        if idx % 100 == 0:
            print(epoch, idx, loss.item())
            # Checkpoint the model and optimizer.
            torch.save(model.state_dict(), MODEL_PATH)
            torch.save(optimizer.state_dict(), OPTIMIZER_PATH)
def test():
    """Evaluate the module-level `model` on the MNIST test split.

    Prints and returns (mean NLL loss per sample, accuracy in [0, 1]).
    """
    model.eval()
    test_dataloader = get_dataloader(train=False)
    total_loss = 0.0
    correct = 0
    total = 0
    # no_grad() must be *called*; the bare class is not a context manager.
    with torch.no_grad():
        for images, targets in test_dataloader:
            output = model(images)
            # Sum per-sample losses so the final mean is exact even when the
            # last batch is smaller than BATCH_SIZE.
            total_loss += F.nll_loss(output, targets, reduction="sum").item()
            correct += (output.argmax(dim=-1) == targets).sum().item()
            total += targets.size(0)
    # Guard against an empty loader to avoid ZeroDivisionError.
    denom = max(total, 1)
    mean_loss = total_loss / denom
    accuracy = correct / denom
    print("test loss:", mean_loss, "accuracy:", accuracy)
    return mean_loss, accuracy
if __name__ == '__main__':
    # To train for 3 epochs, uncomment:
    # for i in range(3):
    #     train(i)
    loader = get_dataloader(train=False)
    for images, labels in loader:
        # Print the shape of each label batch (torch.Tensor.size()).
        print(labels.size())
|
评论区
欢迎你留下宝贵的意见，昵称输入QQ号会显示QQ头像哦~