{"id":2483,"date":"2025-05-14T16:11:47","date_gmt":"2025-05-14T08:11:47","guid":{"rendered":"https:\/\/thereisno.top\/?p=2483"},"modified":"2025-05-16T19:18:53","modified_gmt":"2025-05-16T11:18:53","slug":"torch%e8%87%aa%e5%ae%9a%e4%b9%89%e6%95%b0%e6%8d%ae%e9%9b%86%e6%a8%a1%e5%9e%8b%e8%ae%ad%e7%bb%83demo","status":"publish","type":"post","link":"https:\/\/thereisno.top\/?p=2483","title":{"rendered":"torch\u81ea\u5b9a\u4e49\u6570\u636e\u96c6\u6a21\u578b\u8bad\u7ec3demo"},"content":{"rendered":"\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch.utils.data import Dataset\nfrom torchvision import datasets\nfrom torchvision.transforms import ToTensor, Lambda\nimport matplotlib.pyplot as plt\nfrom torch.utils.data import DataLoader\nfrom torch import nn\nimport os\nimport pandas as pd\nfrom torchvision.io import decode_image\n\n\ndevice = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else \"cpu\"\n#  \u72ec\u70ed\u6570\u636e\u9006\u8f6c\u5316 ,<mark style=\"background-color:#fcb900\" class=\"has-inline-color\">\u672c\u4f8b\u72ec\u70ed\u5316\u975e\u5fc5\u987b<\/mark>\ndef arc_one_hot(x,list=torch.tensor(&#91;0,1,2,3,4,5,6,7,8,9],dtype=torch.float).to(device)):\n    return x@list\n#  1.\u521b\u5efa\u81ea\u5b9a\u4e49\u6570\u636e\u96c6\nclass CustomImageDataset(Dataset):\n    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):\n        self.img_labels = pd.read_csv(annotations_file, header=None) <mark style=\"background-color:#fcb900\" class=\"has-inline-color\"># \u6ce8\u610f\u9996\u884c\u9ed8\u8ba4\u4f1a\u88ab\u4f5c\u4e3a\u6807\u9898\u884c\u5ffd\u7565\uff0c\u6216\u8005\u8bbe\u7f6eheader=None <\/mark>\n        self.img_dir = img_dir\n        self.transform = transform\n        self.target_transform = target_transform\n\n    def __len__(self):\n        return len(self.img_labels)\n\n    def __getitem__(self, idx):\n        img_path = os.path.join(self.img_dir, self.img_labels.iloc&#91;idx, 0])\n  
      #print(img_path)\n        image = decode_image(img_path).float().div(255) <mark style=\"background-color:#fcb900\" class=\"has-inline-color\">#\u9700\u8981\u8f6c\u6210float\u7c7b\u578b\uff0c\u5426\u5219\u65e0\u6cd5\u8bad\u7ec3<\/mark>\n        #print(image.shape)\n        label = self.img_labels.iloc&#91;idx, 1]\n        #print(label)\n        if self.transform:\n            image = self.transform(image)\n        if self.target_transform:\n            label = self.target_transform(label)\n        #\u72ec\u70ed\u5316\n        # print(label)\n        new_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))\n        label = new_transform(label)\n        # print(\"------fuck\")\n        # print(label)\n        return image, label\n    \n\n<mark style=\"background-color:#fcb900\" class=\"has-inline-color\"># csv\u6ce8\u610f\u662f\u5426\u6709\u6807\u9898\u884c<\/mark>\ncsv_path='\/Users\/mnist_test_cus_data\/imglist_train.csv'\nimg_dir='\/Users\/mnist_test_cus_data\/imgs_train\/'\nbatch_size = 64\n# \u521b\u5efa\u81ea\u5b9a\u4e49\u6570\u636e\u96c6\u5b9e\u4f8b\nmydataset = CustomImageDataset(annotations_file=csv_path, img_dir=img_dir, transform=None, target_transform=None)\n\n# \u4f7f\u7528 DataLoader \u52a0\u8f7d\u6570\u636e\nmydataloader = DataLoader(mydataset, batch_size, shuffle=True, num_workers=0) #, num_workers=4 macos\u62a5\u9519\nprint(len(mydataloader))\nprint(len(mydataloader.dataset))\n# print(mydataset&#91;59999])\n# print(mydataset&#91;0]&#91;0])\n# print(mydataset&#91;0]&#91;1])\n# exit()\n\n# Download test data from open datasets.\ntest_data = datasets.FashionMNIST(\n    root=\"data\",\n    train=False,\n    download=True,\n    transform=ToTensor(),\n    target_transform = Lambda(lambda y: torch.zeros(10,dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))\n)\ntest_dataloader = DataLoader(test_data, batch_size=batch_size)\n\n# \u904d\u5386 DataLoader\n#for batch in 
mydataloader:\n#    images, labels = batch\n    #print(images.size(), labels.size())\n    #print(images)\n    #print(labels)\n\nfor X, y in test_dataloader:\n    print(f\"Shape of X &#91;N, C, H, W]: {X.shape}\")\n    print(f\"Shape of y: {y.shape} {y.dtype}\")\n\n    # print(X)\n    # print(y)\n    # print(arc_one_hot(y.to(device)))\n\n    break\nprint(len(mydataloader))\n# exit()\n\n\n#  2. \u53ef\u89c6\u5316\u6570\u636e\ndef showdata():\n    labels_map = {\n        0: \"T-Shirt\",\n        1: \"Trouser\",\n        2: \"Pullover\",\n        3: \"Dress\",\n        4: \"Coat\",\n        5: \"Sandal\",\n        6: \"Shirt\",\n        7: \"Sneaker\",\n        8: \"Bag\",\n        9: \"Ankle Boot\",\n    }\n    figure = plt.figure(figsize=(8, 8))\n    cols, rows = 3, 3\n    xxx=''\n    for i in range(1, cols * rows + 1):\n        sample_idx = torch.randint(len(mydataset), size=(1,)).item()\n        img, label = mydataset&#91;sample_idx]\n        figure.add_subplot(rows, cols, i)\n        <mark style=\"background-color:#fcb900\" class=\"has-inline-color\"># \u72ec\u70ed\u9006\u8f6c\u5316<\/mark>\n        label=arc_one_hot(label.to(device)).item()\n        plt.title(labels_map&#91;label])\n        plt.axis(\"off\")\n        xxx=img\n        plt.imshow(img.squeeze(), cmap=\"gray\")\n    plt.show()\n    print(xxx.shape)\n    print('------')\n    print(xxx.squeeze().shape)\n\n    # Display image and label.\n    train_features, train_labels = next(iter(mydataloader))\n    print(f\"Feature batch shape: {train_features.size()}\")\n    print(f\"Labels batch shape: {train_labels.size()}\")\n    img = train_features&#91;0].squeeze()\n    label = train_labels&#91;0]\n    plt.imshow(img, cmap=\"gray\")\n    plt.show()\n    print(f\"Label: {label}\")\n    # exit()\n\n\n# 3.\u5b9a\u4e49\u6a21\u578b\n\nprint(f\"Using {device} device\")\n\n# Define model\nclass NeuralNetwork(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.flatten = nn.Flatten() 
#\u7ef4\u5ea6\u5c55\u5e73\n        self.linear_relu_stack = nn.Sequential(\n            nn.Linear(28*28, 512),\n            nn.ReLU(),\n            nn.Linear(512, 512),\n            nn.ReLU(),\n            nn.Linear(512, 10)\n        )\n\n    def forward(self, x):\n        x = self.flatten(x)\n        logits = self.linear_relu_stack(x)\n        return logits\n\nmodel = NeuralNetwork().to(device)\nprint(model)\n\n# 4. \u5b9a\u4e49\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668\nloss_fn = nn.CrossEntropyLoss() #\u4ea4\u53c9\u71b5\noptimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n\n# 5. \u8bad\u7ec3\ndef train(dataloader, model, loss_fn, optimizer):\n    size = len(dataloader.dataset)\n    print(\"size=\"+str(size))\n    model.train() # \u542f\u7528 Batch Normalization \u548c Dropout\uff0c\u5f52\u4e00\u5316\uff0c\u968f\u673a\u4e22\u5f03\u795e\u7ecf\u5143\u9632\u6b62\u8fc7\u62df\u5408\uff0c\u6d4b\u8bd5\u65f6\u4e0d\u4e22\u5f03\n    for batch, (X, y) in enumerate(dataloader):\n        X, y = X.to(device), y.to(device)\n        # Compute prediction error\n        pred = model(X)\n        # print(\"---------->pred=\")\n        # print(pred)\n        # print(y)\n        # print(\"-----------------------&lt;\")\n        <mark style=\"background-color:#fcb900\" class=\"has-inline-color\"># \u4e0d\u9700\u8981\u72ec\u70ed\u9006\u8f6c\u5316\uff0c\u4ea4\u53c9\u71b5\u8ba1\u7b97\u8fc7\u7a0b\u5305\u542b\u4e86\u72ec\u70ed\u7f16\u7801\uff0c\u4f46\u662f\u4ecd\u7136\u53ef\u4ee5\u4f7f\u7528\u72ec\u70ed\u53c2\u6570<\/mark>\n        # y=arc_one_hot(y)\n        loss = loss_fn(pred, y)\n\n        # Backpropagation\n        <mark style=\"background-color:#fcb900\" class=\"has-inline-color\">loss.backward() # \u8ba1\u7b97\u68af\u5ea6\n        optimizer.step() # \u6839\u636e\u68af\u5ea6\u4f18\u5316\u53c2\u6570\n        optimizer.zero_grad() # \u68af\u5ea6\u5f52\u96f6<\/mark>\n\n        if batch % 100 == 0: # \u6bcf100\u4e2abatch\u6253\u5370\u4e00\u6b21\n            loss, current 
= loss.item(), (batch + 1) * len(X)\n            print(f\"loss: {loss:>7f}  &#91;{current:>5d}\/{size:>5d}]\")\n        # exit()\n# 6. \u6d4b\u8bd5\ndef test(dataloader, model, loss_fn):\n    size = len(dataloader.dataset)\n    num_batches = len(dataloader)\n    model.eval()\n    test_loss, correct = 0, 0\n    with torch.no_grad():\n        for X, y in dataloader:\n            X, y = X.to(device), y.to(device)\n            #print(y)\n            #\u4e0d\u9700\u8981\u72ec\u70ed\u9006\u7f16\u7801\n            #y=arc_one_hot(y)\n            #print(y)\n            pred = model(X)\n            #print(pred)\n            test_loss += loss_fn(pred, y).item()\n            #\u7edf\u8ba1\u4e2a\u6570\n            # >>> xx==zz\n            # tensor(&#91; True, False, False,  True, False,  True, False,  True, False, False,\n            #          True, False, False,  True, False, False,  True, False,  True, False,\n            #          True, False, False, False, False,  True,  True, False,  True, False,\n            #         False,  True, False, False, False, False, False, False,  True,  True,\n            #         False, False, False,  True, False, False,  True, False, False,  True,\n            #         False,  True, False,  True,  True, False, False, False, False,  True,\n            #         False,  True,  True,  True])\n            # >>> (xx==zz).type(torch.float)\n            # tensor(&#91;1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0.,\n            #         1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0.,\n            #         0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1.,\n            #         1., 0., 0., 0., 0., 1., 0., 1., 1., 1.])\n            # >>> (xx==zz).type(torch.float).sum()\n            # tensor(25.)\n            # >>> (xx==zz).type(torch.float).sum().item()\n            # 25.0\n            yy=arc_one_hot(y)\n            correct += (pred.argmax(1) == 
yy).type(torch.float).sum().item()\n    test_loss \/= num_batches\n    correct \/= size\n    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")\n\n#  7. \u8bad\u7ec3\u548c\u6d4b\u8bd5\ndef do_train():\n    epochs = 5\n    for t in range(epochs):\n        print(f\"Epoch {t+1}\\n-------------------------------\")\n        train(mydataloader, model, loss_fn, optimizer)\n        test(test_dataloader, model, loss_fn)\n    print(\"Done!\")\n    __save_model__()\n\n# for var_name in model.state_dict():\n#     print(var_name, \"\\t\", model.state_dict()&#91;var_name])\n\n# for var_name in optimizer.state_dict():\n#     print(var_name, \"\\t\", optimizer.state_dict()&#91;var_name])\n\n#  8. \u4fdd\u5b58\u6a21\u578b\ndef __save_model__():\n    path=\"model.pth\"\n    torch.save(model.state_dict(), path)\n    print(\"Saved PyTorch Model State to \"+path)\n#  9. \u52a0\u8f7d\u6a21\u578b\ndef load_model():\n    model = NeuralNetwork().to(device)\n    model.load_state_dict(torch.load(\"model.pth\", weights_only=True))\n    return model\n#  10. 
\u6d4b\u8bd5\u6a21\u578b\ndef  test_model():\n    model = load_model()\n    classes = &#91;\n        \"T-shirt\/top\",\n        \"Trouser\",\n        \"Pullover\",\n        \"Dress\",\n        \"Coat\",\n        \"Sandal\",\n        \"Shirt\",\n        \"Sneaker\",\n        \"Bag\",\n        \"Ankle boot\",\n    ]\n\n    model.eval()\n    x, y = test_data&#91;0]&#91;0], test_data&#91;0]&#91;1]\n    with torch.no_grad():\n        x = x.to(device)\n        pred = model(x)\n        # \u72ec\u70ed\u8f6c\u5316,\u540cdevice\u624d\u80fd\u8ba1\u7b97\n        y=arc_one_hot(y.to(device)).int()\n        print(y)\n\n        print(f\"Predicted: {pred&#91;0].argmax(0)}, Actual: {y}\")\n        predicted, actual = classes&#91;pred&#91;0].argmax(0)], classes&#91;y]\n        print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')\n\ndef  do_test_model():\n    load_model()\n    test_model()\ndef do_train_model():\n    showdata()\n    do_train()\n    test_model()\ndef main():\n    do_train_model()\n    #do_test_model()\n    \n\nif __name__ == '__main__':\n    
main()<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[246,14],"tags":[253],"class_list":["post-2483","post","type-post","status-publish","format-standard","hentry","category-ai","category-python","tag-torch"],"_links":{"self":[{"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/posts\/2483","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/thereisno.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=2483"}],"version-history":[{"count":3,"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/posts\/2483\/revisions"}],"predecessor-version":[{"id":2491,"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/posts\/2483\/revisions\/2491"}],"wp:attachment":[{"href":"https:\/\/thereisno.top\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=2483"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/thereisno.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=2483"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/thereisno.top\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=2483"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}