Ray Train¶
TraceML's Ray integration is intentionally thin: Ray launches and manages the training workers, while TraceML observes those workers over its normal TCP telemetry path.
driver process
-> starts one TraceML aggregator actor
-> runs Ray TorchTrainer
-> Ray starts training workers
-> each worker starts one TraceML runtime
-> each worker sends telemetry to the aggregator actor
Ray still owns scheduling, worker placement, ranks, process groups, and DDP/NCCL/Gloo communication. TraceML does not replace Ray's launcher or reach into Ray Train internals.
Install¶
pip install "traceml-ai[ray]"
Minimal Usage¶
import ray
from ray.train import ScalingConfig
from traceml_ai.integrations.ray import TraceMLRayConfig, TraceMLTorchTrainer
def train_loop_per_worker(config):
import torch
import torch.nn as nn
import torch.optim as optim
from ray import train
import traceml_ai as traceml
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 4))
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
train_ds = train.get_dataset_shard("train")
train_loader = train_ds.iter_torch_batches(
batch_size=64,
prefetch_batches=1,
)
train_loader = traceml.wrap_dataloader_fetch(train_loader)
for step, batch in enumerate(train_loader):
if step >= config["steps"]:
break
with traceml.trace_step(model):
x = batch["x"]
y = batch["y"].long()
optimizer.zero_grad(set_to_none=True)
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
ray.init()
train_dataset = ray.data.from_items(
[{"x": [0.0] * 32, "y": 0} for _ in range(4096)]
)
trainer = TraceMLTorchTrainer(
train_loop_per_worker,
train_loop_config={"steps": 10},
scaling_config=ScalingConfig(num_workers=4, use_gpu=False),
datasets={"train": train_dataset},
traceml_config=TraceMLRayConfig(mode="summary"),
)
trainer.fit()
Use the same train_loop_per_worker shape you would pass to Ray's
TorchTrainer. The wrapper starts TraceML before your loop runs and stops it
after the loop exits.
When using Ray Data, wrap the iter_torch_batches(...) iterator with
traceml.wrap_dataloader_fetch(...). Ray Data is not a PyTorch
DataLoader, so the PyTorch DataLoader patch cannot see those fetches.
Example scripts¶
Use --ray-address=auto when you already have a Ray cluster running, or omit
it for a local Ray run.
Minimal Ray Train example:
python examples/ray/torchtrainer_minimal.py \
--ray-address=auto \
--num-workers=2 \
--steps=100 \
--use-gpu
To make input timing visible in the minimal example:
python examples/ray/torchtrainer_minimal.py \
--ray-address=auto \
--num-workers=2 \
--steps=100 \
--use-gpu \
--input-delay-ms=100
Ray + Lightning¶
When combining Ray Train and PyTorch Lightning, add TraceMLCallback() to the
Lightning Trainer and keep wrapping Ray Data iterators with
traceml.wrap_dataloader_fetch(...). To capture Lightning H2D timing inside
Ray workers, initialize the worker patches selectively:
TraceMLRayConfig(
mode="summary",
init_mode="selective",
patch_dataloader=True,
patch_h2d=True,
)
The examples/ray/lightning_text_classifier.py demo also includes
--input-delay-ms / --input-delay-rank for input-straggler demos,
--delay-ms / --delay-rank for compute-straggler demos, and
--transfer-dim to make Lightning H2D timing visible.
--transfer-dim creates a reusable per-batch CPU tensor; it does not add a
full dataset-sized tensor.
Baseline Ray + Lightning run:
python examples/ray/lightning_text_classifier.py \
--ray-address=auto \
--num-workers=2 \
--max-steps=100 \
--use-gpu
Input-straggler demo:
python examples/ray/lightning_text_classifier.py \
--ray-address=auto \
--num-workers=2 \
--max-steps=100 \
--use-gpu \
--input-delay-rank=0 \
--input-delay-ms=200
Compute-straggler demo:
python examples/ray/lightning_text_classifier.py \
--ray-address=auto \
--num-workers=2 \
--max-steps=100 \
--use-gpu \
--delay-rank=0 \
--delay-ms=200
For CPU-only runs, remove --use-gpu.
Network Model¶
The aggregator runs as a normal Ray actor and binds a TCP server. By default it
binds 0.0.0.0 on port 0:
0.0.0.0lets workers on other Ray nodes connect to the actor node.- port
0lets the operating system choose a free port. - workers receive the actor's reachable node IP and chosen port through the wrapped trainer.
If your cluster requires a fixed open port, set it explicitly:
TraceMLRayConfig(port=29765)
Configuration¶
TraceMLRayConfig(
mode="summary",
profile="run",
init_mode="auto",
patch_dataloader=None,
patch_forward=None,
patch_backward=None,
patch_h2d=None,
logs_dir="./logs",
session_id="",
sampler_interval_sec=1.0,
bind_host="0.0.0.0",
port=0,
)
The default mode="summary" is recommended for Ray because distributed worker
logs are noisy. Use mode="cli" only when you specifically want live terminal
rendering from the aggregator actor.
init_mode is passed to traceml.init(mode="auto") inside each Ray
worker. The Ray Data wrap_dataloader_fetch(...) pattern above works with
the default auto mode because Ray Data iterators are separate from PyTorch
DataLoader. Use init_mode="manual" only if your training loop wraps
dataloader, forward, backward, and optimizer timing explicitly. Use
init_mode="selective" with the patch_* options when you only want some
automatic patches.
Lifecycle¶
TraceMLTorchTrainer.fit() starts the aggregator actor, runs Ray Train, and
then stops the actor in a finally block. Each worker also stops its local
TraceML runtime in a finally block. Normal exceptions and keyboard
interrupts should therefore release TraceML resources. A hard SIGKILL cannot
run Python cleanup code in any framework.