PyTorch Lightning¶
Use TraceML with PyTorch Lightning to find training bottlenecks without changing your training loop.
TraceMLCallback adds step-aware diagnosis so you can quickly see whether a
run is input-bound, compute-bound, straggler-heavy, wait-heavy, or showing
memory drift.
1. Install¶
pip install "traceml-ai[lightning]"
2. Add TraceMLCallback¶
Initialize the Lightning integration once, then add TraceMLCallback to your Lightning Trainer. Everything else stays the same.
import lightning as L
from traceml_ai.integrations import lightning as traceml_lightning

# Initialize the integration once, before building the Trainer.
traceml_lightning.init()

model = MyLightningModule()  # your existing LightningModule
trainer = L.Trainer(
    max_steps=500,
    accelerator="auto",
    devices=1,
    enable_progress_bar=False,
    callbacks=[traceml_lightning.TraceMLCallback()],
)
trainer.fit(model, train_dataloaders=loader)  # loader: your existing DataLoader
You do not need to add traceml.trace_step(...) manually. Lightning still owns
the training loop.
3. Launch the run¶
Single GPU:
traceml run train.py
Single-node multi-GPU DDP:
traceml run train.py --nproc-per-node=4
For multi-node DDP launch commands, see Distributed Training.
For browser dashboard mode on single-node runs (this needs the traceml-ai[dashboard] extra described under Optional: local UI):
traceml run train.py --mode=dashboard
What TraceML will show¶
In Lightning runs, TraceML helps you spot:
- input-bound training
- compute-bound steps
- wait-heavy behavior
- rank imbalance and stragglers
- memory creep over time
You keep the normal Lightning workflow. TraceML adds diagnosis around the training step.
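To make "rank imbalance and stragglers" concrete, here is a minimal sketch of one way a slow rank can be flagged from per-rank step times. It is an illustrative heuristic only, not TraceML's diagnosis logic; the helper find_stragglers, the 1.2 tolerance, and the sample timings are all hypothetical.

```python
from statistics import median
from typing import Dict, List


def find_stragglers(step_ms_by_rank: Dict[int, float], tolerance: float = 1.2) -> List[int]:
    """Return ranks whose mean step time exceeds `tolerance` x the median.

    Hypothetical helper for illustration; not part of TraceML.
    """
    typical = median(step_ms_by_rank.values())
    return sorted(
        rank for rank, ms in step_ms_by_rank.items() if ms > tolerance * typical
    )


# Made-up mean step times (ms) for a 4-rank run where rank 3 lags behind.
step_ms = {0: 101.0, 1: 99.0, 2: 100.0, 3: 180.0}
stragglers = find_stragglers(step_ms)
print(stragglers)  # [3]
```

In synchronous DDP the other ranks wait for the slowest one at each gradient all-reduce, so a single straggler tends to show up as wait time on every other rank.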
How it works¶
traceml_lightning.init() enables PyTorch DataLoader fetch timing and
installs the H2D .to(...) patch. TraceMLCallback records step, forward,
backward, optimizer, and memory timing. It also scopes H2D timing around
Lightning's internal strategy.batch_to_device(...) path.
Normal PyTorch DataLoader input timing is automatic after
traceml_lightning.init(). If you pass Lightning a custom iterator or
non-PyTorch loader, wrap it with traceml.wrap_dataloader_fetch(...) before
passing it to trainer.fit(...). For Ray Data with Lightning, see
Ray Train.
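To illustrate what timing a fetch means for a custom iterator, here is a minimal, self-contained sketch. It does not call traceml.wrap_dataloader_fetch(...) and is not TraceML's implementation; timed_fetch and slow_batches are hypothetical names used only to show the idea of measuring how long each next() call takes.

```python
import time
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def timed_fetch(source: Iterable[T], fetch_ms: List[float]) -> Iterator[T]:
    """Yield items from `source`, appending each fetch duration (ms) to `fetch_ms`.

    Hypothetical helper for illustration; not part of TraceML.
    """
    iterator = iter(source)
    while True:
        start = time.perf_counter()
        try:
            item = next(iterator)
        except StopIteration:
            return
        fetch_ms.append((time.perf_counter() - start) * 1000.0)
        yield item


def slow_batches(num_batches: int, delay_s: float) -> Iterator[int]:
    """Stand-in for a custom loader that needs `delay_s` to produce each batch."""
    for index in range(num_batches):
        time.sleep(delay_s)
        yield index


fetch_ms: List[float] = []
batches = list(timed_fetch(slow_batches(3, 0.01), fetch_ms))
print(batches, [round(ms) for ms in fetch_ms])
```

Only the time spent inside next() is recorded, so work done by the consumer between fetches is not counted as input time.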
Small batches may show H2D 0.0ms because the transfer is below display
precision. The full example below uses a wider CPU tensor so H2D timing is
visible.
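A quick size estimate shows why the wider tensor helps. The sketch below reuses the constants from the full example and assumes float32 inputs (the default dtype of torch.randn), so 4 bytes per element; batch_bytes is a hypothetical helper, and the small label tensor is ignored.

```python
BYTES_PER_FLOAT32 = 4  # assumption: inputs are float32, the torch.randn default
BATCH_SIZE = 64
MODEL_INPUT_DIM = 128
TRANSFER_INPUT_DIM = 131072


def batch_bytes(batch_size: int, feature_dim: int,
                bytes_per_element: int = BYTES_PER_FLOAT32) -> int:
    """Size in bytes of one dense input batch (labels not included)."""
    return batch_size * feature_dim * bytes_per_element


narrow = batch_bytes(BATCH_SIZE, MODEL_INPUT_DIM)
wide = batch_bytes(BATCH_SIZE, TRANSFER_INPUT_DIM)
print(f"narrow batch: {narrow // 1024} KiB")       # narrow batch: 32 KiB
print(f"wide batch:   {wide // 1024 ** 2} MiB")    # wide batch:   32 MiB
```

A 32 KiB copy finishes far too quickly to register at millisecond display precision, while a 32 MiB copy per step is large enough to show up.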
Use with Lightning loggers¶
TraceML works alongside Lightning loggers such as:
- W&B
- TensorBoard
- CSVLogger
For the cleanest terminal experience during diagnosis runs, it helps to use:
enable_progress_bar=False
You do not need to replace your existing logger stack to use TraceML.
Optional: local UI¶
If you want a richer browser-based view, run:
pip install "traceml-ai[dashboard]"
traceml run train.py --mode=dashboard
Dashboard mode is intended for single-node runs.
The local UI is useful when you want:
- a richer run review experience
- easier local comparison
- less terminal clutter
Trainer tips¶
These settings usually give the cleanest experience with TraceML:
| Setting | Recommended value | Why |
|---|---|---|
| enable_progress_bar=False | Yes | Prevents Lightning progress output from fighting with the TraceML CLI |
| enable_model_summary=False | Optional | Keeps terminal output cleaner |
| logger=False | Optional | Useful for local diagnosis runs if you want minimal output |
Full example¶
Save as train_lightning.py:
import lightning as L
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from traceml_ai.integrations import lightning as traceml_lightning

SEED = 42
MODEL_INPUT_DIM = 128
TRANSFER_INPUT_DIM = 131072
HIDDEN_DIM = 256
NUM_CLASSES = 10
NUM_SAMPLES = 512
BATCH_SIZE = 64
MAX_STEPS = 200


class SyntheticClassificationDataset(Dataset):
    def __init__(self, num_samples: int):
        # Transfer a wider CPU batch so Lightning H2D timing is visible, while
        # the model below only consumes MODEL_INPUT_DIM features for compute.
        self.x = torch.randn(num_samples, TRANSFER_INPUT_DIM)
        self.y = torch.randint(0, NUM_CLASSES, (num_samples,))

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class TinyLightningModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(MODEL_INPUT_DIM, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, NUM_CLASSES),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = x[..., :MODEL_INPUT_DIM].contiguous()
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        if self.global_step % 50 == 0:
            print(f"Step {self.global_step} | loss={loss.item():.4f}")
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def main() -> None:
    torch.manual_seed(SEED)
    traceml_lightning.init()
    dataset = SyntheticClassificationDataset(NUM_SAMPLES)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )
    model = TinyLightningModel()
    trainer = L.Trainer(
        max_steps=MAX_STEPS,
        accelerator="auto",
        devices=1,
        enable_progress_bar=False,
        callbacks=[traceml_lightning.TraceMLCallback()],
        logger=False,
    )
    trainer.fit(model, train_dataloaders=loader)


if __name__ == "__main__":
    main()
Run with:
traceml run train_lightning.py
The checked-in examples/integrations/lightning_minimal.py also accepts small demo flags:
--devices, --num-nodes, --max-steps, --delay-rank, and --delay-ms.
Use the delay flags only when you want to create a deliberate straggler.
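One way a deliberate straggler could be produced is to sleep on a single rank during each step. The sketch below is an assumption about what such delay flags might do, not the code in the checked-in example; maybe_delay is a hypothetical helper, and it falls back to the RANK environment variable that torchrun sets for each process.

```python
import os
import time
from typing import Optional


def maybe_delay(delay_rank: int, delay_ms: float, rank: Optional[int] = None) -> bool:
    """Sleep for `delay_ms` on the chosen rank; return True if this rank slept.

    Hypothetical helper for illustration; not taken from the TraceML examples.
    """
    if rank is None:
        # torchrun exports RANK per process; default to 0 for single-process runs.
        rank = int(os.environ.get("RANK", "0"))
    if rank != delay_rank or delay_ms <= 0:
        return False
    time.sleep(delay_ms / 1000.0)
    return True


# Rank 1 is the chosen straggler; rank 0 is left alone.
slept_on_rank0 = maybe_delay(delay_rank=1, delay_ms=5.0, rank=0)
slept_on_rank1 = maybe_delay(delay_rank=1, delay_ms=5.0, rank=1)
print(slept_on_rank0, slept_on_rank1)  # False True
```

Calling a helper like this at the top of training_step would slow only the selected rank, which is the kind of imbalance a straggler diagnosis is meant to surface.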
Gradient accumulation¶
TraceMLCallback supports gradient accumulation.
When Lightning uses accumulate_grad_batches=N, TraceML still preserves step alignment so the dashboard and summaries stay consistent.
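For orientation, the arithmetic behind accumulation can be sketched as follows. This assumes Lightning's usual behavior, where max_steps counts optimizer steps and an optimizer step fires every N batches and on the last batch of an epoch; how TraceML numbers its own steps is not specified here. The helper names are hypothetical, and the sample values come from the full example on this page.

```python
import math


def optimizer_steps_per_epoch(batches_per_epoch: int, accumulate_grad_batches: int) -> int:
    """Optimizer steps in one epoch, counting a final partial window as a step."""
    return math.ceil(batches_per_epoch / accumulate_grad_batches)


def batches_for_max_steps(max_steps: int, accumulate_grad_batches: int) -> int:
    """Batches consumed to reach `max_steps` when every window is full."""
    return max_steps * accumulate_grad_batches


NUM_SAMPLES = 512
BATCH_SIZE = 64
batches_per_epoch = NUM_SAMPLES // BATCH_SIZE  # 8 batches per epoch

steps_per_epoch = optimizer_steps_per_epoch(batches_per_epoch, 4)
total_batches = batches_for_max_steps(200, 4)
print(steps_per_epoch)  # 2
print(total_batches)    # 800
```

With accumulate_grad_batches=4, the same max_steps=200 run processes four times as many batches, so per-batch and per-optimizer-step timings should not be compared directly.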
Troubleshooting¶
Terminal output overlaps with TraceML¶
Set:
enable_progress_bar=False
This gives the TraceML CLI cleaner terminal control.
I still want W&B or TensorBoard¶
That is fine. TraceML is designed to work alongside them.
If terminal output gets noisy, use:
pip install "traceml-ai[dashboard]"
traceml run train.py --mode=dashboard
Dashboard mode is intended for single-node runs. For multi-node runs, use the default final summary path.
I want a baseline without TraceML¶
Run:
traceml run train_lightning.py --disable-traceml
This launches your script natively through torchrun without TraceML telemetry.
Next steps¶
- Read the Quickstart for plain PyTorch loops
- Read huggingface.md for Hugging Face Trainer integration
- Open an issue if you hit a problem: https://github.com/traceopt-ai/traceml/issues