训练模型
训练模型我们一般分为: 加载数据, 实现网络, 定义损失函数和优化器, 计算准确率, 训练, 在测试集上测试。
加载数据
我们使用datasets 来下载训练和测试数据集,并使用transforms来转换成tensor的结构,最后返回DataLoader的批量数据。
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
def load_data(batch_size, resize):
    """Return (train_loader, test_loader) for FashionMNIST.

    Args:
        batch_size: mini-batch size used by both loaders.
        resize: target side length for a square Resize; falsy keeps 28x28.
    """
    # Build the transform list up front instead of mutating
    # Compose.transforms after construction (same resulting pipeline:
    # Resize first, then ToTensor).
    ops = []
    if resize:
        ops.append(transforms.Resize((resize, resize)))
    ops.append(transforms.ToTensor())
    trans = transforms.Compose(ops)
    mnist_train = datasets.FashionMNIST("./dataset", train=True, download=True, transform=trans)
    mnist_test = datasets.FashionMNIST("./dataset", train=False, download=True, transform=trans)
    # Train loader shuffles every epoch; the test loader keeps a fixed order.
    return (DataLoader(mnist_train, batch_size, shuffle=True, num_workers=4),
            DataLoader(mnist_test, batch_size, shuffle=False, num_workers=4))
如果我们的数据集是图像,我们可以使用matplotlib 来打印出图像的内容:
def print_image():
    """Show the first 12 training images in a 2x6 grid with class names."""
    train_loader, test_loader = load_data(64, 64)
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    images, labels = next(iter(train_loader))
    plt.figure(figsize=(12, 8))
    for i in range(12):
        plt.subplot(2, 6, i + 1)
        # BUG FIX: load_data(64, 64) resizes images to 64x64, so the old
        # hard-coded reshape(28, 28) would raise at runtime. Squeeze the
        # channel dimension instead, which works for any image size.
        plt.imshow(images[i].squeeze().numpy())
        plt.title(f"{classes[labels[i]]}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()
class Animator:
    """Incrementally draw and redraw line plots (e.g. during training).

    Intended for Jupyter notebooks: each `add` call clears the cell output
    and re-renders the figure, producing a live-updating plot.
    """

    def __init__(self, xlabel=None, ylabel=None, legend=None,
                 xlim=None, ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # BUG FIX: `backend_inline` was an undefined name (never imported).
        # Import it here so the module can still be imported outside a
        # notebook environment.
        from matplotlib_inline import backend_inline
        if legend is None:
            legend = []
        backend_inline.set_matplotlib_formats('svg')
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            # Normalize to a list so self.axes[0] works for a single subplot.
            self.axes = [self.axes, ]
        # Capture the axis configuration in a closure so it can be
        # re-applied after every cla() in add().
        self.config_axes = lambda: self.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def set_axes(self, axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
        """Apply labels, limits, scales, legend and grid to `axes`."""
        axes.set_xlabel(xlabel)
        axes.set_ylabel(ylabel)
        axes.set_xlim(xlim)
        axes.set_ylim(ylim)
        axes.set_xscale(xscale)
        axes.set_yscale(yscale)
        if legend:
            axes.legend(legend)
        axes.grid()

    def add(self, x, y):
        """Append data point(s) and redraw the figure.

        `y` may be a scalar or a sequence (one value per curve); a scalar
        `x` is broadcast to every curve. None entries are skipped.
        """
        # BUG FIX: `display` was an undefined name (never imported).
        from IPython import display
        if not hasattr(y, '__len__'):
            y = [y]
        n = len(y)
        if not hasattr(x, '__len__'):
            x = [x] * n
        # Lazily allocate one point list per curve on first use.
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        # Redraw from scratch: clear, replot every curve, re-apply axes.
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)
实现网络
这里我们以AlexNet 为例: 使用nn.Sequential 来把所有的网络层都放在一起,包括卷积层,pooling层,和全连接层。
# AlexNet-style architecture assembled from explicit stages, then fused
# into a single nn.Sequential (identical layer sequence to listing them
# inline).
feature_layers = [
    # Stage 1: large-kernel stride-4 conv, then overlapping max-pool.
    nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 2: 5x5 conv with padding to keep the spatial size.
    nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Stage 3: three stacked 3x3 convs before the final pooling.
    nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
]
classifier_layers = [
    nn.Flatten(),
    nn.Linear(256 * 5 * 5, 4096), nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(4096, 10),
]
net = nn.Sequential(*feature_layers, *classifier_layers)

# Trace one dummy 224x224 grayscale image through each layer and report
# the per-layer output shapes.
X = torch.randn(1, 1, 224, 224)
for blk in net:
    X = blk(X)
    print(blk.__class__.__name__, 'output shape:\t', X.shape)
训练
class Accumulator:
    """Maintain running sums over a fixed number of scalar quantities."""

    def __init__(self, n):
        # One float slot per tracked quantity.
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        # Add each argument to its matching slot; zip means extra slots
        # keep their current value when fewer arguments are supplied.
        for i, (current, delta) in enumerate(zip(self.data, args)):
            self.data[i] = current + float(delta)

    def reset(self):
        # Zero every slot, preserving the slot count.
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        return self.data[idx]
def accuracy(y_hat, y):
    """Return the number of correct predictions as a float.

    `y_hat` may be raw logits/probabilities of shape [N, C] (argmax is
    taken per row) or already-predicted class indices of shape [N].
    """
    preds = y_hat
    # Multi-column input means per-class scores: reduce to class indices.
    if preds.dim() > 1 and preds.shape[1] > 1:
        preds = preds.argmax(dim=1)
    # Cast to the label dtype before comparing to avoid dtype mismatches.
    matches = preds.to(y.dtype) == y
    # True counts as 1, False as 0.
    return float(matches.sum())
def evaluate_accuracy_gpu(net, data_iter, device=None):
    """Return the classification accuracy of `net` over `data_iter`.

    When `device` is not given, it defaults to the device of the model's
    first parameter.
    """
    if isinstance(net, nn.Module):
        net.eval()  # inference mode: disable dropout etc.
        if not device:
            device = next(iter(net.parameters())).device
    metric = Accumulator(2)  # (num correct, num samples)
    with torch.no_grad():
        for features, targets in data_iter:
            # Some pipelines yield a list of tensors per batch.
            if isinstance(features, list):
                features = [f.to(device) for f in features]
            else:
                features = features.to(device)
            targets = targets.to(device)
            metric.add(accuracy(net(features), targets), targets.numel())
    return metric[0] / metric[1]
# train_loader, test_loader = load_data(128, 224)
# loss = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(net.parameters(), lr=lr)
def train_and_eval(net, batch_size=128, resize=224, lr=0.05, num_epochs=10):
    """Train `net` on FashionMNIST and report per-epoch loss and accuracy.

    Args:
        net: the model (nn.Module); weights are re-initialized in place.
        batch_size: mini-batch size for both loaders.
        resize: square size passed to load_data (falsy skips resizing).
        lr: SGD learning rate.
        num_epochs: number of passes over the training set.
    """
    def init_weights(m):
        # Xavier init for every fully-connected and conv layer.
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    train_loader, test_loader = load_data(batch_size, resize)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs],
                        legend=['train loss', 'train acc', 'test acc'])
    total_start = time.time()
    num_batches = len(train_loader)
    # Plot ~5 times per epoch. BUG FIX: guard against loaders with fewer
    # than 5 batches, where num_batches // 5 == 0 caused a ZeroDivisionError.
    plot_every = max(1, num_batches // 5)
    for epoch in range(num_epochs):
        metric = Accumulator(3)  # summed loss, num correct, num samples
        net.train()
        epoch_start = time.time()
        for i, (images, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            loss_value = loss(outputs, labels)
            loss_value.backward()
            optimizer.step()
            with torch.no_grad():
                # .item() extracts a detached Python float, so accumulated
                # losses do not keep the autograd graph alive.
                metric.add(loss_value.item() * images.shape[0],
                           accuracy(outputs, labels), images.shape[0])
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        epoch_time = time.time() - epoch_start
        test_acc = evaluate_accuracy_gpu(net, test_loader, device)
        animator.add(epoch + 1, (None, None, test_acc))
        # BUG FIX: the adjacent f-strings previously had no separator,
        # printing e.g. "time 2.34son cuda:0".
        print(f'epoch {epoch + 1}, '
              f'loss {train_l:.3f}, '
              f'train acc {train_acc:.3f}, '
              f'test acc {test_acc:.3f}, '
              f'time {epoch_time:.2f}s '
              f'on {device}')
    total_time = time.time() - total_start
    # Approximate throughput using the last epoch's sample count.
    total_samples = metric[2] * num_epochs
    print(f'Total training time: {total_time:.2f}s')
    print(f'{total_samples / total_time:.1f} examples/sec on {device}')