# Demo: training a PyTorch model on a custom dataset

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn
import os
import pandas as pd
from torchvision.io import decode_image


device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
#  Inverse of one-hot encoding (one-hot labels are optional in this demo).
def arc_one_hot(x, list=None):
    """Decode one-hot rows back into their class values.

    Computes ``x @ list``: each one-hot row selects the entry of ``list``
    located at the position of its single 1.

    Args:
        x: One-hot tensor whose last dimension is the number of classes
            (a single vector or a batch of row vectors).
        list: Optional 1-D tensor holding the value of each class. When
            omitted, the indices ``0 .. x.shape[-1] - 1`` are used. The name
            shadows the builtin but is kept because it is part of the
            existing keyword interface.

    Returns:
        Tensor of decoded values with the same dtype and device as ``x``
        (a 0-d tensor for a single vector, one value per row for a batch).
    """
    if list is None:
        # Built per call from ``x`` so the class count is not hard-coded and
        # no tensor is allocated on a device at import time.
        values = torch.arange(x.shape[-1], dtype=x.dtype, device=x.device)
    else:
        # Align with ``x`` so the matmul cannot fail on a device/dtype mismatch.
        values = list.to(device=x.device, dtype=x.dtype)
    return x @ values
#  1. Custom dataset
class CustomImageDataset(Dataset):
    """Image dataset described by a header-less CSV of ``filename,label`` rows.

    Each item is ``(image, one_hot_label)``: the image file named in column 0
    is decoded from ``img_dir`` and the value in column 1 is turned into a
    one-hot float vector of length ``num_classes``.

    Args:
        annotations_file: Path (or file-like object) of the CSV annotations.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the decoded image tensor.
        target_transform: Optional callable applied to the raw label before
            it is one-hot encoded.
        num_classes: Length of the one-hot label vector (default 10).
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None, num_classes=10):
        # header=None: the first CSV row is data; without it pandas would
        # consume that row as the column names.
        self.img_labels = pd.read_csv(annotations_file, header=None)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.num_classes = num_classes

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.img_labels)

    def _one_hot(self, label):
        """Return ``label`` as a one-hot float vector of length ``num_classes``."""
        encoded = torch.zeros(self.num_classes, dtype=torch.float)
        return encoded.scatter_(dim=0, index=torch.tensor(label), value=1)

    def __getitem__(self, idx):
        """Return the ``(image, one_hot_label)`` pair stored at row ``idx``."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Convert to float and divide by 255; a float image is required for training.
        image = decode_image(img_path).float().div(255)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        # One-hot encode directly instead of building a new Lambda wrapper per item.
        return image, self._one_hot(label)
    

# NOTE: check whether the CSV file starts with a title row before using it.
csv_path = '/Users/mnist_test_cus_data/imglist_train.csv'
img_dir = '/Users/mnist_test_cus_data/imgs_train/'
batch_size = 64

# Training set backed by the CSV annotations and image folder; no extra transforms.
mydataset = CustomImageDataset(csv_path, img_dir)

# Shuffled mini-batches. num_workers stays at 0 because num_workers=4 raised an error on macOS.
mydataloader = DataLoader(dataset=mydataset, batch_size=batch_size, shuffle=True, num_workers=0)

# Report the number of batches, then the number of samples.
print(len(mydataloader), len(mydataloader.dataset), sep="\n")

# Download (when missing) the FashionMNIST evaluation split from the open datasets.
# Integer class targets are converted to one-hot float vectors of length 10.
_one_hot_target = Lambda(
    lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
    target_transform=_one_hot_target,
)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)

# Peek at the first evaluation batch (if any) to report tensor shapes and the label dtype.
_first_batch = next(iter(test_dataloader), None)
if _first_batch is not None:
    X, y = _first_batch
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
print(len(mydataloader))


#  2. Visualise the data
def showdata():
    """Plot nine random training samples, then one image taken from a batch.

    Draws a 3x3 grid of randomly chosen items from ``mydataset`` titled with
    their class names, prints the shape of the last plotted image (raw and
    squeezed), then fetches one batch from ``mydataloader``, prints its
    feature/label shapes and displays its first image and label.
    """
    labels_map = dict(enumerate((
        "T-Shirt",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle Boot",
    )))
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    last_img = None
    for cell in range(cols * rows):
        pick = torch.randint(len(mydataset), size=(1,)).item()
        img, one_hot = mydataset[pick]
        figure.add_subplot(rows, cols, cell + 1)
        # Decode the one-hot label after moving it to `device`.
        class_id = arc_one_hot(one_hot.to(device)).item()
        plt.title(labels_map[class_id])
        plt.axis("off")
        plt.imshow(img.squeeze(), cmap="gray")
        last_img = img
    plt.show()
    print(last_img.shape)
    print('------')
    print(last_img.squeeze().shape)

    # Show the first image and label of one batch drawn from the training loader.
    features, targets = next(iter(mydataloader))
    print(f"Feature batch shape: {features.size()}")
    print(f"Labels batch shape: {targets.size()}")
    plt.imshow(features[0].squeeze(), cmap="gray")
    plt.show()
    print(f"Label: {targets[0]}")


# 3. Model definition

# Report which device is in use.
print("Using {} device".format(device))

# Model: a small fully connected classifier.
class NeuralNetwork(nn.Module):
    """Multi-layer perceptron: 28*28 inputs -> 512 -> 512 -> 10 logits."""

    def __init__(self):
        super().__init__()
        # Flattens the non-batch dimensions of the input into one.
        self.flatten = nn.Flatten()
        hidden = 512
        layers = [
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class scores (logits) for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))

# Build the network, move it to the selected device and print its layout.
model = NeuralNetwork()
model = model.to(device)
print(model)

# 4. Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()  # cross-entropy
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)

# 5. Training
def train(dataloader, model, loss_fn, optimizer):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; its ``dataset``
            length is used for progress reporting.
        model: Module to optimise; switched to training mode first.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a tensor.
        optimizer: Optimizer that owns the model parameters.

    Prints the dataset size once and the loss every 100 batches.
    """
    total = len(dataloader.dataset)
    print(f"size={total}")
    model.train()  # put the module in training mode
    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass; the one-hot targets are passed to the loss as they are,
        # no decoding back to class indices is needed.
        batch_loss = loss_fn(model(inputs), targets)

        # Backward pass: compute gradients, update parameters, clear gradients.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 100 != 0:
            continue
        # Progress line every 100 batches.
        seen = (step + 1) * len(inputs)
        print(f"loss: {batch_loss.item():>7f}  [{seen:>5d}/{total:>5d}]")
# 6. Evaluation
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    Args:
        dataloader: Iterable of ``(inputs, one_hot_targets)`` batches.
        model: Module to evaluate; switched to eval mode first.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a tensor.

    The loss is averaged over the number of batches, the accuracy over the
    number of samples in ``dataloader.dataset``.
    """
    sample_count = len(dataloader.dataset)
    batch_count = len(dataloader)
    model.eval()
    loss_sum = 0
    hits = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            # The loss takes the one-hot targets directly.
            loss_sum += loss_fn(logits, targets).item()
            # For accuracy, decode the one-hot targets to class values and count
            # matches: the boolean mask is cast to float, summed, and read out
            # as a Python number.
            class_ids = arc_one_hot(targets)
            matches = logits.argmax(1) == class_ids
            hits += matches.type(torch.float).sum().item()
    avg_loss = loss_sum / batch_count
    accuracy = hits / sample_count
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")

#  7. Train and evaluate
def do_train():
    """Train for five epochs, evaluating after each one, then save the model."""
    epochs = 5
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}\n-------------------------------")
        train(mydataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
    print("Done!")
    __save_model__()

# for var_name in model.state_dict():
#     print(var_name, "\t", model.state_dict()[var_name])

# for var_name in optimizer.state_dict():
#     print(var_name, "\t", optimizer.state_dict()[var_name])

#  8. Save the model
def __save_model__():
    """Write the global model's state dict to ``model.pth``."""
    path = "model.pth"
    state = model.state_dict()
    torch.save(state, path)
    print(f"Saved PyTorch Model State to {path}")
#  9. Load the model
def load_model():
    """Create a fresh ``NeuralNetwork`` on ``device`` and load ``model.pth`` into it.

    Returns:
        The network with the saved parameters loaded. The caller decides
        whether to switch it to eval mode.
    """
    model = NeuralNetwork().to(device)
    # map_location remaps tensors that were saved on another device, so a
    # checkpoint written on an accelerator still loads on a machine without it.
    state = torch.load("model.pth", map_location=device, weights_only=True)
    model.load_state_dict(state)
    return model
#  10. Try the saved model on one sample
def test_model():
    """Load the saved model and print its prediction for the first test sample.

    Prints the decoded label tensor, the predicted and actual class indices,
    and finally the corresponding class names.
    """
    model = load_model()
    classes = (
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    )

    model.eval()
    sample = test_data[0]
    x, y = sample[0], sample[1]
    with torch.no_grad():
        scores = model(x.to(device))
        # Decode the one-hot label on `device`, then cast it to an integer tensor.
        y = arc_one_hot(y.to(device)).int()
        print(y)

        guess = scores[0].argmax(0)
        print(f"Predicted: {guess}, Actual: {y}")
        predicted, actual = classes[guess], classes[y]
        print(f'Predicted: "{predicted}", Actual: "{actual}"')

def do_test_model():
    """Evaluate the saved checkpoint on one test sample.

    test_model() loads ``model.pth`` itself, so the extra load_model() call
    whose result was discarded has been dropped.
    """
    test_model()
def do_train_model():
    """Preview the data, run training, then try the saved model on one sample."""
    for stage in (showdata, do_train, test_model):
        stage()
def main():
    """Script entry point: run the full training pipeline."""
    # Call do_test_model() here instead to only evaluate a previously saved model.
    do_train_model()


if __name__ == "__main__":
    main()

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注