pytorch-lightning tutorial
⚡ lightning-tutorial
Installation of the partner package
pip install lightning-tutorial
Table of contents
- PyTorch Datasets and DataLoaders
  - Key module: torch.utils.data.Dataset
  - Key module: torch.utils.data.DataLoader
  - Other essential functions
PyTorch Datasets and DataLoaders
Key module: torch.utils.data.Dataset
The Dataset class is designed to be subclassed. You can customize it freely as long as you implement the following three methods:
__init__
__len__
__getitem__
torch calls these methods by name under the hood when loading data to pass through a model.
import torch
from torch.utils.data import Dataset


class TurtleData(Dataset):
    def __init__(self, n_samples=20_000, n_features=50):
        """
        Pass whatever arguments are needed to enable __len__() and
        __getitem__(). Random placeholder data stands in for real data here.
        """
        self.X = torch.randn(n_samples, n_features)

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        e.g., a 20,000 cell dataset would return `20_000`.
        """
        return self.X.shape[0]

    def __getitem__(self, idx):
        """
        Return the sample(s) at `idx`. With the default DataLoader sampler,
        `idx` is a single integer and batches are assembled afterwards.
        Valid values satisfy 0 <= idx < len(self).
        """
        return self.X[idx]
- Try it for yourself! Colab Dataset tutorial notebook
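Many datasets pair each sample with a label. The following is a minimal sketch of that pattern, assuming the features and labels are already in memory as tensors. The class name `PairedTurtleData` and the toy values are hypothetical and only serve as illustration.

```python
import torch
from torch.utils.data import Dataset


class PairedTurtleData(Dataset):
    """Hypothetical dataset that returns (features, label) pairs."""

    def __init__(self, X, y):
        # Features and labels must describe the same samples.
        assert len(X) == len(y), "X and y must have the same length"
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Return one sample and its label as a tuple.
        return self.X[idx], self.y[idx]


# Toy data: 6 samples with 2 features each.
X = torch.arange(12, dtype=torch.float32).reshape(6, 2)
y = torch.tensor([0, 1, 0, 1, 0, 1])

paired = PairedTurtleData(X, y)
x0, y0 = paired[0]  # square-bracket indexing calls __getitem__
```

Using `len(paired)` and `paired[0]` rather than calling the dunder methods directly is the idiomatic way to exercise `__len__` and `__getitem__`.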
Key module: torch.utils.data.DataLoader
Similar to AnnData, the Dataset class provides a base unit for distributing and handling data. We can then take advantage of several torch built-ins, such as DataLoader, for data processing that is not only more organized but also faster.
from torch.utils.data import DataLoader

dataset = TurtleData()
data_size = len(dataset)  # calls dataset.__len__()
print(data_size)
# 20000
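To make the batching behavior concrete, here is a small sketch of iterating over a DataLoader. It uses `TensorDataset` with random toy data so that it runs on its own; the sizes (100 samples, 8 features, batches of 32) are arbitrary assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 100 samples with 8 features each.
toy_dataset = TensorDataset(torch.randn(100, 8))

# The DataLoader draws indices, calls __getitem__ for each one,
# and stacks the results into batches.
loader = DataLoader(toy_dataset, batch_size=32, shuffle=True)

# TensorDataset returns a tuple per sample, so each batch is a
# sequence holding one stacked tensor.
batch_shapes = [batch[0].shape for batch in loader]
```

Because 100 is not a multiple of 32, the final batch is smaller than the rest. Passing `drop_last=True` to the DataLoader would discard it instead.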
Other essential functions
from torch.utils.data import random_split
train_dataset, val_dataset = random_split(dataset, [18_000, 2_000])
# this can then be fed to a DataLoader, as above
train_loader = DataLoader(train_dataset)
val_loader = DataLoader(val_dataset)
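A split is only useful for comparison across runs if it is reproducible. One way to achieve that, sketched here on toy data, is to pass a seeded `torch.Generator` to `random_split`. The seed and the 90/10 sizes are arbitrary assumptions.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset of 100 samples.
full = TensorDataset(torch.arange(100))

# A seeded generator makes the split deterministic.
generator = torch.Generator().manual_seed(0)
train_ds, val_ds = random_split(full, [90, 10], generator=generator)

# Repeating the split with the same seed yields the same indices.
train_again, _ = random_split(
    full, [90, 10], generator=torch.Generator().manual_seed(0)
)
same_split = train_ds.indices == train_again.indices
```

Each returned object is a `Subset` whose `indices` attribute lists the positions it draws from the original dataset.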
Useful tutorials and documentation
- Parent module: torch.utils.data
- Datasets and DataLoaders tutorial
Single-cell data structures meet pytorch: torch-adata
Create pytorch Datasets from AnnData
Installation
- Note: This is already done for you if you've installed this tutorial's associated package.
pip install torch-adata
Example use of the base class
The base class, AnnDataset, is a subclass of the widely used torch.utils.data.Dataset.
import anndata as a
import torch_adata
adata = a.read_h5ad("/path/to/data.h5ad")
dataset = torch_adata.AnnDataset(adata)
Returns sampled data X_batch as a torch.Tensor.
import numpy as np

# create a dummy index
idx = np.random.choice(range(len(dataset)), 5)
X_batch = dataset[idx]
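For readers without an .h5ad file at hand, the same indexing pattern can be tried on a plain matrix. The sketch below is a hypothetical stand-in and not torch-adata's implementation: it assumes the data matrix is available as a dense NumPy array, and the class name `MatrixDataset` is made up for illustration.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class MatrixDataset(Dataset):
    """Hypothetical stand-in for a matrix-backed dataset."""

    def __init__(self, X):
        # Convert the dense matrix to a float tensor once, up front.
        self.X = torch.as_tensor(np.asarray(X), dtype=torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Accepts a single integer or an array of indices.
        return self.X[torch.as_tensor(idx)]


# Toy matrix: 20 "cells" with 2 "genes" each.
X = np.arange(40).reshape(20, 2)
matrix_dataset = MatrixDataset(X)

# create a dummy index of 5 randomly chosen rows
idx = np.random.choice(len(matrix_dataset), 5)
X_batch = matrix_dataset[idx]
```

Indexing with an array of five positions returns a tensor with five rows, which mirrors the `X_batch` returned in the example above.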
TimeResolvedAnnDataset
Specialized class for time-resolved datasets; a subclass of AnnDataset.
import anndata as a
import torch_adata as ta

adata = a.read_h5ad("/path/to/data.h5ad")
dataset = ta.TimeResolvedAnnDataset(adata, time_key="Time point")
Lightning basics and the LightningModule
import torch
from pytorch_lightning import LightningModule

# Placeholder loss; swap in whichever loss fits your task.
LossFunc = torch.nn.MSELoss()


class YourSOTAModel(LightningModule):
    def __init__(self, net, optimizer_kwargs=None, scheduler_kwargs=None):
        super().__init__()
        self.net = net
        # Avoid mutable default arguments. StepLR requires `step_size`;
        # the value below is only a placeholder.
        self.optimizer_kwargs = optimizer_kwargs or {"lr": 1e-3}
        self.scheduler_kwargs = scheduler_kwargs or {"step_size": 10}

    def forward(self, batch):
        x, y = batch
        y_hat = self.net(x)
        loss = LossFunc(y_hat, y)
        return y_hat, loss

    def training_step(self, batch, batch_idx):
        y_hat, loss = self.forward(batch)
        return loss.sum()

    def validation_step(self, batch, batch_idx):
        y_hat, loss = self.forward(batch)
        return loss.sum()

    def test_step(self, batch, batch_idx):
        y_hat, loss = self.forward(batch)
        return loss.sum()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), **self.optimizer_kwargs)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **self.scheduler_kwargs)
        # Lightning accepts a list of optimizers and a list of schedulers.
        return [optimizer], [scheduler]
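The optimizer and scheduler returned by configure_optimizers interact in a simple way that can be checked outside Lightning. The sketch below, in plain torch, records the learning rate over a few epochs under StepLR. The values `step_size=2` and `gamma=0.1` are arbitrary assumptions, and the loop only approximates how the scheduler is stepped once per epoch.

```python
import torch

net = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# Multiply the learning rate by `gamma` every `step_size` epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

lrs = []
for epoch in range(4):
    # One dummy optimization step per "epoch".
    optimizer.zero_grad()
    net(torch.randn(8, 4)).sum().backward()
    optimizer.step()

    # Record the learning rate used this epoch, then advance the schedule.
    lrs.append(scheduler.get_last_lr()[0])
    scheduler.step()
```

With these settings the learning rate stays at its initial value for the first two epochs and drops by a factor of ten for the next two.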
Additional useful documentation and standalone tutorials
LightningDataModule
Purpose: Make your model independent of any given dataset while also making your dataset reproducible and, perhaps just as important, easily shareable.
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader


class YourDataModule(LightningDataModule):
    def __init__(self, batch_size=32, num_workers=0):
        super().__init__()
        # define any setup computations (default values are placeholders)
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        # download data if applicable
        pass

    def setup(self, stage=None):
        # assign data to `Dataset`(s), e.g. self.train_dataset,
        # self.val_dataset and self.test_dataset
        pass

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
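The division of labor between the hooks may be easier to see with the placeholders filled in. The sketch below mirrors the hook names on a plain Python class so that it runs without Lightning installed; in practice the class would subclass LightningDataModule. The class name `ToyDataModule`, the random data and the 80/10/10 split are all assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


class ToyDataModule:
    """Plain-Python stand-in that mirrors LightningDataModule hook names."""

    def __init__(self, n_samples=100, batch_size=10, seed=0):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.seed = seed

    def setup(self, stage=None):
        # Build the full dataset, then split it 80/10/10 reproducibly.
        full = TensorDataset(torch.randn(self.n_samples, 4))
        n_val = self.n_samples // 10
        n_test = self.n_samples // 10
        n_train = self.n_samples - n_val - n_test
        generator = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            full, [n_train, n_val, n_test], generator=generator
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


toy_data = ToyDataModule()
toy_data.setup()
n_train_batches = len(toy_data.train_dataloader())
n_val_batches = len(toy_data.val_dataloader())
```

Keeping the split logic in `setup` and the batching logic in the `*_dataloader` methods means the model code never needs to know how the data was partitioned.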
When it comes to actually using one of these, it looks something like the following:
from pytorch_lightning import Trainer

# Init the LightningDataModule as well as the LightningModule
data = YourDataModule()
model = YourSOTAModel(net)  # the LightningModule defined above; `net` is your torch.nn.Module

# Define trainer
trainer = Trainer(accelerator="auto", devices=1)

# Ultimately, both model and data are passed as args to trainer.fit
trainer.fit(model, data)
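It can help to see roughly what trainer.fit automates. The following plain-torch loop is a simplified sketch of that idea, not Lightning's actual implementation: it leaves out validation, logging, checkpointing and device placement. The toy regression data, learning rate and epoch count are arbitrary assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy regression problem: the target is the sum of the three features.
X = torch.randn(64, 3)
y = X.sum(dim=1, keepdim=True)
train_loader = DataLoader(TensorDataset(X, y), batch_size=16)

net = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-2)

epoch_losses = []
for epoch in range(10):
    total = 0.0
    for x_batch, y_batch in train_loader:
        # Equivalent of training_step: compute the loss for one batch.
        loss = torch.nn.functional.mse_loss(net(x_batch), y_batch)

        # The bookkeeping Lightning handles for you.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    epoch_losses.append(total)
```

With a LightningModule and a LightningDataModule, the entire loop collapses into the single `trainer.fit(model, data)` call shown above.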
Here's an example of a LightningDataModule implemented in practice, using the LARRY single-cell dataset: link. Initial downloading and formatting occurs only once but takes several minutes, so we will leave it outside the scope of this tutorial.