Quickstart
This section runs through the API for common tasks in machine learning. Refer to the links in each section to dive deeper.
Working with data
PyTorch has two primitives to work with data: torch.utils.data.DataLoader and torch.utils.data.Dataset. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset.
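As a minimal sketch of how the two primitives fit together, the hypothetical SquaresDataset below implements the two methods a map-style Dataset needs (__len__ and __getitem__), and DataLoader then groups its samples into batches. The dataset and its contents are invented purely for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the float i, and its label is i**2."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples; DataLoader uses this to decide which indices to fetch.
        return self.n

    def __getitem__(self, idx):
        # Return one (sample, label) pair for the given index.
        x = torch.tensor([float(idx)])
        y = torch.tensor(float(idx) ** 2)
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# DataLoader stacks the individual samples and labels into batch tensors.
for X, y in loader:
    print(X.shape, y.shape)
```

With 10 samples and batch_size=4, the loader yields batches of 4, 4 and 2 samples, because the final partial batch is kept by default.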
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. For this tutorial, we will be using a TorchVision dataset.
The torchvision.datasets module contains Dataset objects for many real-world vision datasets, such as CIFAR and COCO. In this tutorial, we use the FashionMNIST dataset. Every TorchVision Dataset includes two arguments, transform and target_transform, to modify the samples and labels, respectively.
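TorchVision datasets typically apply transform to each sample and target_transform to each label at the moment an item is fetched. The hypothetical ToyImageDataset below is a minimal sketch of that convention, assuming small in-memory uint8 images so that nothing has to be downloaded. The helpers scale_to_unit and one_hot are illustrative only, not TorchVision APIs; the real ToTensor additionally handles PIL images and NumPy arrays.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Hypothetical dataset following the transform/target_transform convention."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images  # uint8 tensor of shape [N, 28, 28]
        self.labels = labels  # list of integer class indices
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)  # modifies the sample
        if self.target_transform is not None:
            label = self.target_transform(label)  # modifies the label
        return image, label


def scale_to_unit(img):
    # Add a channel dimension and map uint8 values 0..255 to floats in [0, 1],
    # similar in spirit to what ToTensor does.
    return img.unsqueeze(0).float() / 255.0


def one_hot(label):
    # Turn an integer class index into a one-hot vector over 10 classes.
    return F.one_hot(torch.tensor(label), num_classes=10).float()


images = torch.randint(0, 256, (3, 28, 28), dtype=torch.uint8)
ds = ToyImageDataset(images, [0, 7, 9], transform=scale_to_unit, target_transform=one_hot)
x, y = ds[1]
print(x.shape, x.dtype, y)
```

Keeping the two transforms as separate arguments lets the same dataset class serve models that expect either integer labels (pass no target_transform) or one-hot targets.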
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
print(len(training_data))
60000
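A natural next step is to wrap the datasets in a DataLoader. So that this sketch runs offline, the hypothetical FakeFashionMNIST below stands in for training_data: it reports the same length printed above (60000) and returns samples with the [1, 28, 28] shape that ToTensor produces for a 28x28 grayscale image, but every image is zeros. The stand-in and the value batch_size=64 are assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakeFashionMNIST(Dataset):
    """Hypothetical offline stand-in: same length and sample shape as the
    ToTensor-transformed FashionMNIST training set, but every image is zeros."""

    def __len__(self):
        return 60000

    def __getitem__(self, idx):
        # One grayscale 28x28 image as [C, H, W], plus an integer label in 0..9.
        return torch.zeros(1, 28, 28), idx % 10


batch_size = 64  # assumed value for illustration
train_dataloader = DataLoader(FakeFashionMNIST(), batch_size=batch_size, shuffle=True)

# 60000 / 64 = 937.5, so keeping the final partial batch gives 938 batches per epoch.
print(len(train_dataloader))

X, y = next(iter(train_dataloader))
print(X.shape)  # [N, C, H, W]
print(y.shape, y.dtype)
```

With the real data, pass training_data and test_data in place of the stand-in; each batch then has the same [64, 1, 28, 28] layout. Setting drop_last=True would discard the final partial batch and give 937 batches instead.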