Skip to content

PyTorch Lightning

Use TraceML with PyTorch Lightning to find training bottlenecks without changing your training loop.

TraceMLCallback adds step-aware diagnosis so you can quickly see whether a run is input-bound, compute-bound, straggler-heavy, wait-heavy, or showing memory drift.

Start with the Quickstart if you have not used TraceML yet.


Install

Install TraceML with Lightning support:

pip install "traceml-ai[lightning]"

Basic usage

Add TraceMLCallback to your Lightning Trainer. Everything else stays the same.

import lightning as L
from traceml.integrations.lightning import TraceMLCallback

model = MyLightningModule()

trainer = L.Trainer(
    max_steps=500,
    accelerator="auto",
    devices=1,
    enable_progress_bar=False,
    callbacks=[TraceMLCallback()],
)

trainer.fit(model, train_dataloaders=loader)

Run with:

traceml run train.py

Or open the local UI:

traceml run train.py --mode=dashboard

What TraceML will show

In Lightning runs, TraceML helps you spot:

  • input-bound training
  • compute-bound steps
  • wait-heavy behavior
  • rank imbalance and stragglers
  • memory creep over time

You keep the normal Lightning workflow. TraceML adds diagnosis around the training step.


How it works

TraceMLCallback hooks into Lightning’s training lifecycle automatically.

That means you do not need to wrap your code with traceml.trace_step(...) manually in Lightning.


Use with Lightning loggers

TraceML works alongside Lightning loggers such as:

  • W&B
  • TensorBoard
  • CSVLogger

For the cleanest terminal experience during diagnosis runs, it helps to use:

enable_progress_bar=False

You do not need to replace your existing logger stack to use TraceML.


Optional: local UI

If you want a richer browser-based view, run:

traceml run train.py --mode=dashboard

The local UI is useful when you want:

  • a richer run review experience
  • easier local comparison
  • less terminal clutter

Trainer tips

These settings usually give the cleanest experience with TraceML:

Setting Recommended value Why
enable_progress_bar=False Yes Prevents Lightning progress output from fighting with the TraceML CLI
enable_model_summary=False Optional Keeps terminal output cleaner
logger=False Optional Useful for local diagnosis runs if you want minimal output

Optional: deeper layer-level signals

Use this only for short diagnostic runs when step-level diagnosis already told you where to dig.

Since TraceMLCallback handles step-level tracing, deeper layer-level hooks are added separately with traceml.trace_model_instance(...).

import traceml
import lightning as L
import torch.nn as nn


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyCoreModel()
        self.loss_fn = nn.CrossEntropyLoss()

        traceml.trace_model_instance(
            self,
            trace_layer_forward_memory=True,
            trace_layer_backward_memory=True,
            trace_layer_forward_time=True,
            trace_layer_backward_time=True,
        )

Use this when you want:

  • per-layer timing
  • per-layer memory detail
  • a short follow-up diagnostic run

Hooks add overhead, so keep them off for normal runs unless you need them.


Full example

Save as train_lightning.py:

import lightning as L
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from traceml.integrations.lightning import TraceMLCallback

SEED = 42
INPUT_DIM = 128
HIDDEN_DIM = 256
NUM_CLASSES = 10
NUM_SAMPLES = 4096
BATCH_SIZE = 64
MAX_STEPS = 200


class SyntheticClassificationDataset(Dataset):
    def __init__(self, num_samples: int):
        self.x = torch.randn(num_samples, INPUT_DIM)
        self.y = torch.randint(0, NUM_CLASSES, (num_samples,))

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class TinyLightningModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(INPUT_DIM, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, NUM_CLASSES),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        if batch_idx % 50 == 0:
            print(f"Step {batch_idx} | loss={loss.item():.4f}")

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


def main() -> None:
    torch.manual_seed(SEED)

    dataset = SyntheticClassificationDataset(NUM_SAMPLES)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
    )

    model = TinyLightningModel()

    trainer = L.Trainer(
        max_steps=MAX_STEPS,
        accelerator="auto",
        devices=1,
        enable_progress_bar=False,
        callbacks=[TraceMLCallback()],
        logger=False,
    )

    trainer.fit(model, train_dataloaders=loader)


if __name__ == "__main__":
    main()

Run with:

traceml run train_lightning.py

Gradient accumulation

TraceMLCallback supports gradient accumulation.

When Lightning uses accumulate_grad_batches=N, TraceML still preserves step alignment so the dashboard and summaries stay consistent.


Troubleshooting

Terminal output overlaps with TraceML

Set:

enable_progress_bar=False

This gives the TraceML CLI cleaner terminal control.

I still want W&B or TensorBoard

That is fine. TraceML is designed to work alongside them.

If terminal output gets noisy, use:

traceml run train.py --mode=dashboard

I want a baseline without TraceML

Run:

traceml run train_lightning.py --disable-traceml

This launches your script natively through torchrun without TraceML telemetry.


Next steps

  • Read the Quickstart for plain PyTorch loops
  • Read huggingface.md for Hugging Face Trainer integration
  • Open an issue if you hit a problem: https://github.com/traceopt-ai/traceml/issues