{"id":2846,"date":"2026-05-06T14:13:46","date_gmt":"2026-05-06T06:13:46","guid":{"rendered":"https:\/\/thereisno.top\/?p=2846"},"modified":"2026-05-06T14:13:46","modified_gmt":"2026-05-06T06:13:46","slug":"%e4%ba%ba%e5%b7%a5%e6%99%ba%e8%83%bd%e5%9c%a8%e8%84%91%e5%8d%92%e4%b8%ad%e7%9b%91%e6%8e%a7%e4%b8%ad%e7%9a%84%e5%ba%94%e7%94%a8","status":"publish","type":"post","link":"https:\/\/thereisno.top\/?p=2846","title":{"rendered":"\u4eba\u5de5\u667a\u80fd\u5728\u8111\u5352\u4e2d\u76d1\u63a7\u4e2d\u7684\u5e94\u7528"},"content":{"rendered":"\n<h2 class=\"wp-block-heading\">1. \u80cc\u666f\u4ecb\u7ecd<\/h2>\n\n\n\n<p>\u8111\u5352\u4e2d\u4e00\u822c\u53d1\u75c5\u6025\uff0c\u65e0\u5f81\u5146\u6216\u8005\u5f81\u5146\u4e0d\u660e\u663e\uff0c\u5927\u591a\u6570\u8111\u5352\u4e2d\u60a3\u8005\u4f34\u6709\u5634\u6b6a\u773c\u659c\u7684\u75c7\u72b6\u3002\u5f53\u524d\u7684\u667a\u80fd\u6444\u50cf\u5934\u6709\u4eba\u8138\u8bc6\u522b\u3001\u6454\u5012\u68c0\u6d4b\u3001\u5a74\u513f\u557c\u54ed\u68c0\u6d4b\u7b49\u529f\u80fd\uff0c\u4f46\u662f\u5bf9\u4e8e\u8111\u5352\u4e2d\u7684\u76d1\u63a7\u76ee\u524d\u5c1a\u65e0\u76f8\u5173\u7684\u4e0a\u5e02\u4ea7\u54c1\u3002\u4e8e\u662f\uff0c\u4f5c\u8005\u57fa\u4e8e\u4eba\u5de5\u667a\u80fd\u6280\u672f\uff0c\u5f00\u53d1\u4e86\u4e00\u6b3e\u57fa\u4e8e\u6444\u50cf\u5934\u7684\u8111\u5352\u4e2d\u76d1\u63a7\u7cfb\u7edf\uff0c\u5145\u5206\u5229\u7528\u6444\u50cf\u5934\u4eba\u7269\u8ddf\u968f\u529f\u80fd\u6765\u5b9e\u65f6\u76d1\u63a7\u6613\u611f\u4eba\u7fa4\uff0c\u82e5\u53d1\u73b0\u5634\u6b6a\u773c\u659c\u7b49\u8111\u5352\u4e2d\u75c7\u72b6\uff0c\u5219\u7acb\u5373\u53d1\u51fa\u8b66\u544a\uff0c\u901a\u77e5\u5bb6\u4eba\u53ca\u65f6\u5c31\u533b\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">2. 
\u6838\u5fc3\u7b97\u6cd5\u6a21\u578b<\/h2>\n\n\n\n<p>\u672c\u7cfb\u7edf\u91c7\u75282\u5c42\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\uff08CNN\uff09\u8fdb\u884c\u56fe\u50cf\u8bc6\u522b\u3002CNN\u901a\u8fc7\u5377\u79ef\u5c42\u548c\u6c60\u5316\u5c42\u8fdb\u884c\u7279\u5f81\u63d0\u53d6\uff0c\u7136\u540e\u901a\u8fc7\u5168\u8fde\u63a5\u5c42\u8fdb\u884c\u5206\u7c7b\u3002 \u6a21\u578b\u6587\u4ef6\uff1amymodel.py<\/p>\n\n\n\n<!--more-->\n\n\n\n<pre id=\"codecell0\" class=\"wp-block-preformatted\">import torch\nfrom torch.utils.data import Dataset\nfrom torchvision.transforms import  Lambda\nfrom torch import nn\nimport os\nimport pandas as pd\nfrom torchvision.io import decode_image\nfrom torchvision import transforms\n\ndevice = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else \"cpu\"\n#  \u72ec\u70ed\u6570\u636e\u9006\u8f6c\u5316\ndef arc_one_hot(x,list=torch.tensor([0,1],dtype=torch.float).to(device)):\n    return x@list\n#  1.\u521b\u5efa\u81ea\u5b9a\u4e49\u6570\u636e\u96c6\nclass CustomImageDataset(Dataset):\n    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):\n        self.img_labels = pd.read_csv(annotations_file, header=None) # \u6ce8\u610f\u9996\u884c\u9ed8\u8ba4\u4f1a\u88ab\u4f5c\u4e3a\u6807\u9898\u884c\u5ffd\u7565\uff0c\u6216\u8005\u8bbe\u7f6eheader=None \n        self.img_dir = img_dir\n        self.transform = transform\n        self.target_transform = target_transform\n\n    def __len__(self):\n        return len(self.img_labels)\n\n    def __getitem__(self, idx):\n        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])\n        #print(img_path)\n        image = decode_image(img_path).float().div(255) #\u9700\u8981\u8f6c\u6210float\u7c7b\u578b\uff0c\u5426\u5219\u65e0\u6cd5\u8bad\u7ec3\n        #print(image.shape)\n        label = self.img_labels.iloc[idx, 1]\n        filename = self.img_labels.iloc[idx, 0]\n        #print(label)\n        if self.transform:\n            
image = self.transform(image)\n        if self.target_transform:\n            label = self.target_transform(label)\n        #\u72ec\u70ed\u5316\n        # print(label)\n        new_transform = Lambda(lambda y: torch.zeros(2, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))\n        label = new_transform(label)\n        return image, label, filename\n    \nclass NeuralNetwork(nn.Module):\n    def __init__(self):\n        super().__init__()\n        \n        # --- \u5377\u79ef\u5757 ---\n        # \u7b2c\u4e00\u5c42\u5377\u79ef: \u8f93\u51651\u901a\u9053, \u8f93\u51fa32\u901a\u9053, 3x3\u5377\u79ef\u6838, padding=1\u4fdd\u6301\u5c3a\u5bf8\n        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)\n        self.bn1=nn.BatchNorm2d(32) # \u6dfb\u52a0 BN\n        # \u975e\u7ebf\u6027\u6fc0\u6d3b\u51fd\u6570\uff0c\u6b63\u6570\u4fdd\u7559\uff0c\u8d1f\u6570\u7f6e\u96f6\n        self.relu1 = nn.ReLU()\n        # \u6700\u5927\u6c60\u5316: 2x2\u7a97\u53e3, \u6b65\u957f2, \u5c3a\u5bf8\u51cf\u534a (100-&gt;50)\n        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n        \n        # # \u7b2c\u4e8c\u5c42\u5377\u79ef (\u53ef\u9009, \u63a8\u8350\u52a0\u6df1\u7f51\u7edc): \u8f93\u516532, \u8f93\u51fa64\n        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)\n        #\u5728\u5377\u79ef\u5c42\u548c\u6fc0\u6d3b\u51fd\u6570\u4e4b\u95f4\u52a0\u5165 Batch Normalization (BN) \u5c42\u3002BN \u53ef\u4ee5\u89e3\u51b3\u5185\u90e8\u534f\u53d8\u91cf\u504f\u79fb\u95ee\u9898\uff0c\u5141\u8bb8\u4f7f\u7528\u66f4\u5927\u7684\u5b66\u4e60\u7387\uff0c\u5e76\u663e\u8457\u52a0\u901f\u8bad\u7ec3\u3002\n        self.bn2=nn.BatchNorm2d(64) # \u6dfb\u52a0 BN\n        self.relu2 = nn.ReLU()\n        # \u518d\u6b21\u6c60\u5316: \u5c3a\u5bf8\u51cf\u534a (50-&gt;25)\n        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n\n        # --- \u5168\u8fde\u63a5\u5757 ---\n        self.flatten = nn.Flatten()\n        \n    
    # \u8ba1\u7b97\u5c55\u5e73\u540e\u7684\u7ef4\u5ea6: \n        # \u521d\u59cb 100x100 -&gt; Pool1 -&gt; 50x50 -&gt; Pool2 -&gt; 25x25 -&gt; Pool3 -&gt; 12x12\n        # \u901a\u9053\u6570: 64\n        # \u603b\u7279\u5f81\u6570: 64 * 25 * 25 = 40000\n        fc_input_dim = 64 * 25 * 25\n        \n        self.linear_relu_stack = nn.Sequential(\n            nn.Linear(fc_input_dim, 512),\n            nn.ReLU(),\n            nn.Dropout(0.5), # \u6dfb\u52a0Dropout\u9632\u6b62\u8fc7\u62df\u5408\n            nn.Linear(512, 2), # \u8f93\u51fa2\u4e2a\u7c7b\u522b\n        )\n\n    def forward(self, x):\n        # \u5377\u79ef\u90e8\u5206\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu1(x)\n        x = self.pool1(x)\n        \n        x = self.conv2(x)\n        x = self.bn2(x)\n        x = self.relu2(x)\n        x = self.pool2(x)\n\n        \n        # \u5c55\u5e73\n        x = self.flatten(x)\n        # \u5168\u8fde\u63a5\u5206\u7c7b\n        logits = self.linear_relu_stack(x)\n        return logits\n    \n# \u5b9a\u4e49\u53d8\u6362\uff1a\u8c03\u6574\u4e3a 28x28 \u5e76\u8f6c\u4e3a\u7070\u5ea6\u56fe (1\u901a\u9053)\n# \u6ce8\u610f\uff1adecode_image \u8fd4\u56de\u7684\u662f Tensor\uff0ctransforms \u652f\u6301 Tensor \u8f93\u5165\ndata_transforms = transforms.Compose([\n    transforms.Resize((100, 100)),       # \u786e\u4fdd\u5c3a\u5bf8\u4e00\u81f4\n    transforms.Grayscale(num_output_channels=1), # \u5173\u952e\uff1a\u8f6c\u4e3a1\u901a\u9053\u7070\u5ea6\u56fe\n    transforms.Normalize(mean=[0.5], std=[0.5]) # \u5c06\u6570\u636e\u6807\u51c6\u5316\u5230 [-1, 1] \u533a\u95f4\uff0c\u6709\u52a9\u4e8e\u52a0\u901f\u6536\u655b\u3002\n])\n<\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">3. 
\u8bad\u7ec3\u6a21\u578b<\/h2>\n\n\n\n<p>\u8bad\u7ec3\u6a21\u578b\u4ee3\u7801\uff1amytrain.py<\/p>\n\n\n\n<pre id=\"codecell1\" class=\"wp-block-preformatted\">import torch\nfrom torch.utils.data import Dataset\nfrom torchvision import datasets\nfrom torchvision.transforms import ToTensor, Lambda\nimport matplotlib.pyplot as plt\nfrom torch.utils.data import DataLoader\nfrom torch import nn\nimport os\nimport pandas as pd\nfrom torchvision.io import decode_image\nfrom torchvision import transforms\nfrom mymodule.mymodel import NeuralNetwork, CustomImageDataset,arc_one_hot,device,data_transforms\nimport torchvision.utils as vutils\nfrom mymodule.model_visualization import VisualizeModel\n\nfrom torch.utils.tensorboard import SummaryWriter\n\nimport argparse\nparser = argparse.ArgumentParser(description='location the model')\nparser.add_argument(\"-m\",\"--model\", type=str, default='model.pth', help='location of the model')\nparser.add_argument(\"-cl\",\"--checkpoint-latest\", type=str, default='checkpoint_latest.pth', help='location of the latest checkpoint')\nparser.add_argument(\"-cb\",\"--checkpoint-best\", type=str, default='checkpoint_best.pth', help='location of the best checkpoint')\n\nargs = parser.parse_args()\nmodel_path=args.model\ncheckpoint_latest_path=args.checkpoint_latest\ncheckpoint_best_path=args.checkpoint_best\n\n# csv\u6ce8\u610f\u662f\u5426\u6709\u6807\u9898\u884c\ncsv_path='cuzhong\/train.csv'\nimg_dir='cuzhong\/train\/'\nbatch_size = 64\n\n# \u521b\u5efa\u81ea\u5b9a\u4e49\u6570\u636e\u96c6\u5b9e\u4f8b\nmydataset = CustomImageDataset(annotations_file=csv_path, img_dir=img_dir, transform=data_transforms, target_transform=None)\n\n# \u4f7f\u7528 DataLoader \u52a0\u8f7d\u6570\u636e\nmydataloader = DataLoader(mydataset, batch_size, shuffle=True, num_workers=0) #, num_workers=4 macos\u62a5\u9519\nprint(len(mydataloader))\nprint(len(mydataloader.dataset))\n\ncsv_path_test='cuzhong\/test.csv'\nimg_dir_test='cuzhong\/test\/'\n# 
\u521b\u5efa\u81ea\u5b9a\u4e49\u6570\u636e\u96c6\u5b9e\u4f8b\ntest_data = CustomImageDataset(annotations_file=csv_path_test, img_dir=img_dir_test, transform=data_transforms, target_transform=None)\n\n\ntest_dataloader = DataLoader(test_data, batch_size=batch_size)\n\n\nfor X, y,z in test_dataloader:\n    print(f\"Shape of X [N, C, H, W]: {X.shape}\")\n    print(f\"Shape of y: {y.shape} {y.dtype}\")\n    break\nprint(len(mydataloader))\n# exit()\n\n\n#  2. \u53ef\u89c6\u5316\u6570\u636e\ndef showdata():\n    labels_map = {\n        0: \"OK\",\n        1: \"Error\",\n    }\n    figure = plt.figure(figsize=(8, 8))\n    cols, rows = 2, 1\n    xxx=''\n    for i in range(1, cols * rows + 1):\n        sample_idx = torch.randint(len(mydataset), size=(1,)).item()\n        img, label, filename = mydataset[sample_idx]\n        figure.add_subplot(rows, cols, i)\n        # \u72ec\u70ed\u9006\u8f6c\u5316\n        label=arc_one_hot(label.to(device)).item()\n        plt.title(labels_map[label]+\"  @\"+filename)\n        plt.axis(\"off\")\n        xxx=img\n        print(img.shape)\n        plt.imshow(img.squeeze(), cmap=\"gray\")\n    plt.show()\n    print(xxx.shape)\n    print('------')\n    print(xxx.squeeze().shape)\n\n    # Display image and label.\n    train_features, train_labels, train_filenames = next(iter(mydataloader))\n    print(f\"Feature batch shape: {train_features.size()}\")\n    print(f\"Labels batch shape: {train_labels.size()}\")\n    img = train_features[0].squeeze()\n    label = train_labels[0]\n    # \u72ec\u70ed\u9006\u8f6c\u5316\n    label=arc_one_hot(label.to(device)).item()\n    plt.title(labels_map[label]+\"  @\"+train_filenames[0])\n    plt.axis(\"off\")\n    plt.imshow(img, cmap=\"gray\")\n    plt.show()\n    print(f\"Label: {label}\")\n    # exit()\n\n\n# 3.\u5b9a\u4e49\u6a21\u578b\n\nprint(f\"Using {device} device\")\n\n    \nmodel = NeuralNetwork().to(device)\nprint(model)\n\n#\u521d\u59cb\u5316TensorBoard writer\nwriter = 
SummaryWriter('runs\/logs\/mymodule_visualize_model_runs')\n\n# 4. \u5b9a\u4e49\u635f\u5931\u51fd\u6570\u548c\u4f18\u5316\u5668\nloss_fn = nn.CrossEntropyLoss() #\u4ea4\u53c9\u71b5\n#\u968f\u673a\u68af\u5ea6\u4e0b\u964d\uff0cSGD\uff08\u968f\u673a\u68af\u5ea6\u4e0b\u964d\uff09\u867d\u7136\u7ecf\u5178\uff0c\u4f46\u5f80\u5f80\u9700\u8981\u7cbe\u7ec6\u8c03\u6574\u5b66\u4e60\u7387\u4e14\u6536\u655b\u8f83\u6162\u3002Adam \u4f18\u5316\u5668\u80fd\u81ea\u9002\u5e94\u8c03\u6574\u6bcf\u4e2a\u53c2\u6570\u7684\u5b66\u4e60\u7387\uff0c\u901a\u5e38\u80fd\u663e\u8457\u52a0\u5feb\u6536\u655b\u901f\u5ea6\uff0c\u5c24\u5176\u662f\u5728 CNN \u4efb\u52a1\u4e2d\u3002\n#optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n# \u4fee\u6539\u4e3aAdam\u4f18\u5316\u5668\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)\n# 5. \u8bad\u7ec3\ndef train(dataloader, model, loss_fn, optimizer):\n    size = len(dataloader.dataset)\n    print(\"size=\"+str(size))\n    model.train() # \u542f\u7528 Batch Normalization \u548c Dropout\uff0c\u5f52\u4e00\u5316\uff0c\u968f\u673a\u4e22\u5f03\u795e\u7ecf\u5143\u9632\u6b62\u8fc7\u62df\u5408\uff0c\u6d4b\u8bd5\u65f6\u4e0d\u4e22\u5f03\n    for batch, (X, y,z) in enumerate(dataloader):\n        X, y = X.to(device), y.to(device)\n        # Compute prediction error\n        pred = model(X)\n        loss = loss_fn(pred, y)\n\n        # Backpropagation\n        loss.backward() # \u8ba1\u7b97\u68af\u5ea6\n        optimizer.step() # \u6839\u636e\u68af\u5ea6\u4f18\u5316\u53c2\u6570\n        optimizer.zero_grad() # \u68af\u5ea6\u5f52\u96f6\n\n        if batch % 1 == 0: # \u6bcf100\u4e2abatch\u6253\u5370\u4e00\u6b21\n            loss, current = loss.item(), (batch + 1) * len(X)\n            print(f\"loss: {loss:&gt;7f}  [{current:&gt;5d}\/{size:&gt;5d}]\")\n        # exit()\n    writer.close()\n# 6. 
\u6d4b\u8bd5\ndef test(dataloader, model, loss_fn):\n    size = len(dataloader.dataset)\n    num_batches = len(dataloader)\n    model.eval()\n    test_loss, correct = 0, 0\n    with torch.no_grad():\n        for X, y,z in dataloader:\n            X, y = X.to(device), y.to(device)\n            \n            pred = model(X)\n            #print(pred)\n            test_loss += loss_fn(pred, y).item()\n            #\u7edf\u8ba1\u4e2a\u6570\n            yy=arc_one_hot(y)\n            # \u5206\u522b\u6bd4\u8f83\uff0c\u5f97\u5230\u771f\u5047\u503c\uff0c\u7136\u540e\u8f6c\u6210\u6d6e\u70b9\u6570\uff0c\u6c42\u548c\uff0c\u5f97\u5230\u4e2a\u6570\n            correct += (pred.argmax(1) == yy).type(torch.float).sum().item()\n    test_loss \/= num_batches\n    correct \/= size\n    print(f\"Test Error: \\n Accuracy: {(100*correct):&gt;0.1f}%, Avg loss: {test_loss:&gt;8f}\\n\")\n    return test_loss,correct\n\n#  7. \u8bad\u7ec3\u548c\u6d4b\u8bd5\ndef do_train():\n    epochs = 50\n    # \u521d\u59cb\u5316\u6a21\u578b\u548c\u4f18\u5316\u5668 (resume=False \u8868\u793a\u4ece\u5934\u5f00\u59cb\uff0c\u5982\u679c\u60f3\u65ad\u70b9\u7eed\u8bad\u6539\u4e3a True)\n    model, optimizer, scheduler, start_epoch, best_acc = load_model(resume=True, checkpoint_best_path=checkpoint_best_path, checkpoint_latest_path=checkpoint_latest_path, path=model_path)\n    # \u91cd\u65b0\u7ed1\u5b9a\u5168\u5c40\u53d8\u91cf\u4e2d\u7684 model, optimizer, scheduler (\u5982\u679c\u5b83\u4eec\u5728\u5916\u90e8\u5b9a\u4e49)\n    # \u6ce8\u610f\uff1a\u6700\u597d\u5c06 model, optimizer, scheduler \u4f5c\u4e3a\u53c2\u6570\u4f20\u9012\uff0c\u6216\u8005\u5728 global \u4f5c\u7528\u57df\u66f4\u65b0\u5b83\u4eec\n    globals()['model'] = model\n    globals()['optimizer'] = optimizer\n    globals()['scheduler'] = scheduler\n    if scheduler is  None:\n        # \u6dfb\u52a0\u8c03\u5ea6\u5668\uff1a\u5f53\u9a8c\u8bc1\u96c6 Loss \u4e0d\u518d\u4e0b\u964d\u65f6\uff0c\u5c06\u5b66\u4e60\u7387\u4e58\u4ee5 0.5\n        scheduler = 
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)\n   \n    #\u6700\u4f73\u6b63\u786e\u7387\n    best_acc=0.0\n    for t in range(epochs):\n        print(f\"Epoch {t+1}\\n-------------------------------\")\n        train(mydataloader, model, loss_fn, optimizer)\n        # \u3010\u5173\u952e\u3011\u6d4b\u8bd5\u5e76\u83b7\u53d6 Loss\uff0c\u7528\u4e8e\u8c03\u5ea6\u5668\u5224\u65ad\n        # \u6ce8\u610f\uff1a\u4f60\u9700\u8981\u4fee\u6539 test \u51fd\u6570\u8fd4\u56de test_loss\uff0c\u6216\u8005\u5728\u8fd9\u91cc\u8ba1\u7b97\n        test_loss, correct = test(test_dataloader, model, loss_fn) # \u5047\u8bbe test \u8fd4\u56de\u8fd9\u4e24\u4e2a\u503c\n        \n        # \u5224\u65ad\u662f\u5426\u4e3a\u6700\u4f73\u6a21\u578b\n        is_best = correct &gt; best_acc\n        if is_best:\n            best_acc = correct\n            # \u4fdd\u5b58 Checkpoint\n            if correct &gt;= 0.75:\n                __save_model0__(f\"best\/best_model_{t}_{correct*100:.2f}.pth\")\n        # \u4fdd\u5b58 Checkpoint\n        __save_model__(t + 1, model, optimizer, scheduler, best_acc, is_best)\n\n        # if correct &gt;= 90:\n        #     break\n        scheduler.step(test_loss) # \u5c06 test_loss \u4f20\u9012\u7ed9\u8c03\u5ea6\u5668\n    print(\"Done!\")\n    __save_model0__()\n\n\n#  8. 
\u4fdd\u5b58\u6a21\u578b\ndef __save_model0__(path=\"model.pth\"):\n    torch.save(model.state_dict(), path)\n    print(\"Saved PyTorch Model State to \"+path)\n\ndef __save_model__(epoch, model, optimizer, scheduler, best_acc, is_best=False):\n    \"\"\"\n    \u4fdd\u5b58 Checkpoint\n    :param is_best: \u5982\u679c\u4e3a True\uff0c\u5219\u989d\u5916\u4fdd\u5b58\u4e3a best_model.pth\n    \"\"\"\n    checkpoint = {\n        'epoch': epoch,\n        'model_state_dict': model.state_dict(),      # \u6743\u91cd\n        'optimizer_state_dict': optimizer.state_dict(), # \u8fd9\u91cc\u8bb0\u5f55\u4e86 Adam \u7684 exp_avg, exp_avg_sq, step\n        'scheduler_state_dict': scheduler.state_dict(), # \u8fd9\u91cc\u8bb0\u5f55\u4e86\u8c03\u5ea6\u5668\u7684\u6b65\u6570\u548c\u5f53\u524d LR\n        'best_acc': best_acc,\n    }\n    \n    # \u4fdd\u5b58\u6700\u65b0\u7684 checkpoint\n    torch.save(checkpoint, 'checkpoint_latest.pth')\n    \n    # \u5982\u679c\u662f\u6700\u4f73\u6a21\u578b\uff0c\u989d\u5916\u4fdd\u5b58\u4e00\u4efd\n    if is_best:\n        torch.save(checkpoint, 'checkpoint_best.pth')\n        print(f\"Saved Best Model at Epoch {epoch} with Acc: {best_acc:.4f}\")\n    else:\n        print(f\"Saved Latest Checkpoint at Epoch {epoch}\")\n\n#  9. 
\u52a0\u8f7d\u6a21\u578b\ndef load_model0(path=\"model.pth\"):\n    model = NeuralNetwork().to(device)\n    model.load_state_dict(torch.load(path, weights_only=True,map_location=device))\n    return model\n\ndef load_model(resume=False, path=\"model.pth\", checkpoint_latest_path=\"checkpoint_latest.pth\", checkpoint_best_path=\"checkpoint_best.pth\"):\n    \"\"\"\n    \u52a0\u8f7d\u6a21\u578b\n    :param resume: \u5982\u679c\u4e3a True\uff0c\u5c1d\u8bd5\u4ece latest checkpoint \u6062\u590d\u8bad\u7ec3\u72b6\u6001\n    \"\"\"\n    model = NeuralNetwork().to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)\n    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)\n    \n    start_epoch = 0\n    best_acc = 0.0\n    \n    if resume:\n        if os.path.exists(checkpoint_latest_path):\n            print(\"=&gt; Loading checkpoint 'checkpoint_latest.pth'\")\n            checkpoint = torch.load(checkpoint_latest_path, map_location=device)\n            start_epoch = checkpoint['epoch']\n            model.load_state_dict(checkpoint['model_state_dict'])\n            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n            best_acc = checkpoint.get('best_acc', 0.0)\n            print(f\"Resumed from Epoch {start_epoch}, Best Acc: {best_acc:.4f}\")\n        else:\n            print(\"=&gt; No checkpoint found, starting from scratch.\")\n    else:\n        # \u5982\u679c\u53ea\u662f\u6d4b\u8bd5\uff0c\u52a0\u8f7d\u6700\u4f73\u6a21\u578b\n        if os.path.exists(checkpoint_best_path):\n            print(\"=&gt; Loading best model 'checkpoint_best.pth' for testing\")\n            checkpoint = torch.load(checkpoint_best_path, map_location=device)\n            model.load_state_dict(checkpoint['model_state_dict'])\n        elif os.path.exists(path):\n            # \u517c\u5bb9\u65e7\u4ee3\u7801\n      
      model.load_state_dict(torch.load(path, weights_only=True, map_location=device))\n\n    return model, optimizer, scheduler, start_epoch, best_acc\n#  10. \u6d4b\u8bd5\u6a21\u578b\ndef  test_model():\n    # model = load_model()\n    # \u52a0\u8f7d\u6700\u4f73\u6a21\u578b\u7528\u4e8e\u6d4b\u8bd5\n    model, _, _, _, _ = load_model(resume=False, checkpoint_best_path=checkpoint_best_path, checkpoint_latest_path=checkpoint_latest_path, path=model_path)\n    classes = [\n        \"OK\",\n        \"Error\",\n    ]\n\n    model.eval()\n\n    visualize_model = VisualizeModel(model)\n    # \u63d0\u53d6\u7b2c\u4e00\u4e2a\u5377\u79ef\u5c42\u7684\u5377\u79ef\u6838\n    conv1_weights = model.conv1.weight.data\n    conv2_weights = model.conv2.weight.data\n    print(\"conv1_weights\")\n    print(conv1_weights.shape)\n    visualize_model.visualize_filters(conv1_weights,\"conv1\")\n    print(\"conv2_weights\")\n    print(conv2_weights.shape)\n    visualize_model.visualize_filters(conv2_weights,\"conv2\")\n\n    \n\n    for data in test_data:\n        x, y,z = data[0], data[1], data[2]\n        visualize_model = VisualizeModel(model)\n        handle1 = visualize_model.register_hook(model.conv1)\n        handle2 = visualize_model.register_hook(model.conv2)\n        with torch.no_grad():\n            x = x.to(device)\n            # \u3010\u5173\u952e\u4fee\u590d\u3011\u589e\u52a0\u6279\u6b21\u7ef4\u5ea6: (C, H, W) -&gt; (1, C, H, W)\n            if x.dim() == 3:\n                x = x.unsqueeze(0)\n            print(x.shape)\n            pred = model(x)\n            # \u72ec\u70ed\u8f6c\u5316,\u540cdevice\u624d\u80fd\u8ba1\u7b97\n            y=arc_one_hot(y.to(device)).int()\n            print(y)\n\n            print(f\"Predicted: {pred[0].argmax(0)}, Actual: {y}\")\n            predicted, actual = classes[pred[0].argmax(0)], classes[y]\n            print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')\n            if predicted != actual:\n                
print(\"----------error------------- File:\"+z+\"\\n\")\n\n        # \u63d0\u53d6\u5b8c\u6570\u636e\u540e\uff0c\u79fb\u9664\u94a9\u5b50\uff08\u597d\u4e60\u60ef\uff0c\u9632\u6b62\u5185\u5b58\u6cc4\u6f0f\uff09\n        handle1.remove()\n        handle2.remove()\n        visualize_model.visualize()\n\n        visualize_model.grad_cam_visualize(x, model,classes, target_class=pred[0].argmax(0),actual_class=y.item(),filename=z,device=device)\n        #saliency_map\u53ef\u89c6\u5316\uff0c\u663e\u793a\u8f93\u5165\u56fe\u50cf\u7684\u68af\u5ea6\u70ed\u56fe\uff0c\u53cd\u6620\u6a21\u578b\u5bf9\u8f93\u5165\u7684\u654f\u611f\u7a0b\u5ea6\n        visualize_model.saliency_map_visualize(x, model,classes, target_class=pred[0].argmax(0),actual_class=y.item(),filename=z,device=device)\n        # break  # \u53ea\u770b\u7b2c\u4e00\u5f20\u56fe\u7684\u7279\u5f81\u56fe\uff0c\u53bb\u6389\u8fd9\u4e2a break \u5c31\u4f1a\u4e00\u76f4\u663e\u793a\u6bcf\u5f20\u56fe\u7684\u7279\u5f81\u56fe\n\ndef  do_test_model():\n    # load_model()\n    test_model()\ndef do_train_model():\n    # showdata()\n    do_train()\n    test_model()\ndef main():\n    do_train_model()\n    # do_test_model()\n    \n\nif __name__ == '__main__':\n    main()\n\n<\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">4. 
\u6a21\u578b\u5e94\u7528<\/h2>\n\n\n\n<p>\u5bf9\u4e8e\u76d1\u63a7\u7167\u7247\u7684\u83b7\u53d6\uff0c\u4e5f\u8d39\u4e86\u4e00\u4e9b\u5468\u6298\u3002\u6700\u521d\u60f3\u7ed9\u6811\u8393\u6d3e\u914d\u4e2a\u6444\u50cf\u5934\uff0c\u53ef\u4ee5\u5b9e\u73b0\u81ea\u5b9a\u4e49\u62cd\u6444\u3002\u7136\u800c\u6811\u8393\u6d3e\u6444\u50cf\u5934\u62cd\u6444\u6548\u679c\u8ddf\u667a\u80fd\u6444\u50cf\u5934\u8fd8\u6709\u8f83\u5927\u5dee\u8ddd\uff0c\u6bd4\u5982\u591c\u89c6\u3001\u77eb\u7578\u3001\u79fb\u52a8\u8ddf\u968f\u7b49\u7b49\u529f\u80fd\u7f3a\u5931\uff0c\u8fd8\u8981\u5355\u72ec\u9020\u8f6e\u5b50\u3002\u90a3\u4e3a\u4ec0\u4e48\u4e0d\u76f4\u63a5\u4ece\u6444\u50cf\u5934\u83b7\u53d6\u6570\u636e\u5462\uff1f\u7ecf\u8fc7\u7814\u7a76\u53d1\u73b0\uff0c\u4ece\u652f\u6301<code>rtsp<\/code>\u534f\u8bae\u7684\u6444\u50cf\u5934\u53ef\u4ee5\u5f88\u65b9\u4fbf\u5730\u83b7\u53d6\u6570\u636e\u3002\u53ef\u60dc\u7684\u662f\u624b\u5934\u7684\u5c0f\u7c73\u6444\u50cf\u5934\u5e76\u4e0d\u652f\u6301<code>rtsp<\/code>\u534f\u8bae\uff0c\u6700\u7ec8\u627e\u5230\u4e86<code>\u8424\u77f3<\/code>\u6444\u50cf\u5934\u3002<\/p>\n\n\n\n<h3 class=\"wp-block-heading\">4.1 \u7cfb\u7edf\u6784\u6210<\/h3>\n\n\n\n<ol class=\"wp-block-list\">\n<li><code>\u8424\u77f3C7<\/code>\u6444\u50cf\u5934\uff0c\u7528\u4e8e\u56fe\u50cf\u91c7\u96c6\u3002<\/li>\n\n\n\n<li><code>\u6811\u8393\u6d3e<\/code>\u4f5c\u4e3a\u670d\u52a1\u5668\uff0c\u5b9a\u65f6\u91c7\u96c6\u6444\u50cf\u5934\u6570\u636e\uff0c\u5e76\u8fd0\u884c\u6a21\u578b\u8fdb\u884c\u9884\u6d4b\u3002<\/li>\n<\/ol>\n\n\n\n<h3 class=\"wp-block-heading\">4.2 \u6a21\u578b\u9884\u6d4b<\/h3>\n\n\n\n<p>\u6a21\u578b\u9884\u6d4b\u7684\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre id=\"codecell2\" class=\"wp-block-preformatted\">import torch\nfrom torch.utils.data import DataLoader\nfrom torch import nn\nfrom torchvision import transforms\nfrom mymodule.mymodel import NeuralNetwork, CustomImageDataset,arc_one_hot,device,data_transforms\nimport time,os\n\nimport argparse\nparser = 
argparse.ArgumentParser(description='location the model')\nparser.add_argument(\"-m\",\"--model\", type=str, default='model.pth', help='location of the model')\nargs = parser.parse_args()\nmodel_path=args.model\nbatch_size = 16\n\n\ncsv_path_test='cuzhong\/test.csv'\nimg_dir_test='cuzhong\/test\/'\n\n# 3.\u5b9a\u4e49\u6a21\u578b\n\nprint(f\"Using {device} device\")\n    \n#  9. \u52a0\u8f7d\u6a21\u578b\ndef load_model(path=\"model.pth\"):\n    model = NeuralNetwork().to(device)\n    model.load_state_dict(torch.load(path, weights_only=True,map_location=device))\n    return model\n#  10. \u6d4b\u8bd5\u6a21\u578b\ndef  test_model(path=model_path):\n    model = load_model(path)\n    classes = [\n        \"OK\",\n        \"Error\",\n    ]\n\n    model.eval()\n    \n    # \u521b\u5efa\u81ea\u5b9a\u4e49\u6570\u636e\u96c6\u5b9e\u4f8b\n    test_data = CustomImageDataset(annotations_file=csv_path_test, img_dir=img_dir_test, transform=data_transforms, target_transform=None)\n    for data in test_data:\n        x, y,z = data[0], data[1], data[2]\n        with torch.no_grad():\n            x = x.to(device)\n            # \u3010\u5173\u952e\u4fee\u590d\u3011\u589e\u52a0\u6279\u6b21\u7ef4\u5ea6: (C, H, W) -&gt; (1, C, H, W)\n            if x.dim() == 3:\n                x = x.unsqueeze(0)\n            pred = model(x)\n            # \u72ec\u70ed\u8f6c\u5316,\u540cdevice\u624d\u80fd\u8ba1\u7b97\n            y=arc_one_hot(y.to(device)).int()\n            predicted, actual = classes[pred[0].argmax(0)], classes[y]\n            if predicted != actual:\n                print(\"---Error: \"+z+\"---\\t\"+f'Predicted: \"{predicted}\", Actual: \"{actual}\"')\n                cmd=\"cvlc --play-and-exit  700hzbeep.mp3\"\n                os.system(cmd)\n\ndef  do_test_model():\n    #load_model()\n    test_model()\n\ndef main():\n    do_test_model()\n    \n\nif __name__ == '__main__':\n    main()<\/pre>\n","protected":false},"excerpt":{"rendered":"<p>1. 
\u80cc\u666f\u4ecb\u7ecd \u8111\u5352\u4e2d\u4e00\u822c\u53d1\u75c5\u6025\uff0c\u65e0\u5f81\u5146\u6216\u8005\u5f81\u5146\u4e0d\u660e\u663e\uff0c\u5927\u591a\u6570\u8111\u5352\u4e2d\u60a3\u8005\u4f34\u6709\u5634\u6b6a\u773c\u659c\u7684\u75c7\u72b6\u3002\u5f53\u524d\u7684\u667a\u80fd\u6444\u50cf\u5934\u6709 &hellip; <\/p>\n<p class=\"link-more\"><a href=\"https:\/\/thereisno.top\/?p=2846\" class=\"more-link\">\u7ee7\u7eed\u9605\u8bfb<span class=\"screen-reader-text\">\u201c\u4eba\u5de5\u667a\u80fd\u5728\u8111\u5352\u4e2d\u76d1\u63a7\u4e2d\u7684\u5e94\u7528\u201d<\/span><\/a><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"closed","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[246,14],"tags":[170,298,247],"class_list":["post-2846","post","type-post","status-publish","format-standard","hentry","category-ai","category-python","tag-ai","tag-298","tag-247"],"_links":{"self":[{"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/posts\/2846","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/thereisno.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=2846"}],"version-history":[{"count":1,"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/posts\/2846\/revisions"}],"predecessor-version":[{"id":2847,"href":"https:\/\/thereisno.top\/index.php?rest_route=\/wp\/v2\/posts\/2846\/revisions\/2847"}],"wp:attachment":[{"href":"https:\/\/thereisno.top\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=2846"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/thereisno.top\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=2846"},{"taxonomy":"post_t
ag","embeddable":true,"href":"https:\/\/thereisno.top\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=2846"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}