Public API

The stable surface that user code imports and calls. Everything on this page is covered by TraceML's compatibility contract across v0.x minor releases.

Core API

import traceml_ai as traceml

traceml.init(mode="auto")

with traceml.trace_step(model):
    ...

The old import traceml path remains available for compatibility, but emits a FutureWarning and may be removed in a future release. Do not import from decorator compatibility paths.
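To catch lingering uses of the deprecated import path in CI, one option is to promote FutureWarning to an error with the standard warnings module. The sketch below is illustrative only: legacy_import() is a hypothetical stand-in for the old import traceml path, and its warning message is invented for the example.

```python
import warnings


def legacy_import():
    """Hypothetical stand-in for the deprecated `import traceml` path."""
    # The message text is an assumption, not TraceML's actual wording.
    warnings.warn(
        "legacy import path is deprecated; use traceml_ai instead",
        FutureWarning,
        stacklevel=2,
    )


def uses_deprecated_path(fn):
    """Return True if calling fn() emits a FutureWarning."""
    with warnings.catch_warnings():
        # Turn FutureWarning into an exception inside this block only.
        warnings.simplefilter("error", FutureWarning)
        try:
            fn()
        except FutureWarning:
            return True
    return False


flagged = uses_deprecated_path(legacy_import)
clean = uses_deprecated_path(lambda: None)
print(flagged, clean)
```

The same effect is available without code changes by running Python with -W error::FutureWarning.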

Hugging Face integration

traceml_ai.integrations.huggingface.init

init()

Initialize TraceML for Hugging Face Trainer runs.

Call once before constructing the Trainer, then register TraceMLTrainerCallback. init() makes TraceML's process-wide instrumentation explicit: PyTorch DataLoader fetch timing, the H2D Tensor.to patch, and the forward/backward/optimizer auto-timers that trace_step arms inside each bracketed step.

The callback is a per-step bracket and cannot install these process-wide patches on its own; the auto-timers it arms are no-ops unless the matching patch is installed. init() is the recommended entry point so the DataLoader fetch patch in particular is installed deterministically rather than relying on import order. This mirrors the PyTorch Lightning integration's init(); HF uses mode="auto" because trace_step drives forward/backward timing through the patch-gated auto-timers, whereas Lightning's callback owns that timing directly.
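The ordering described above (call init() first, then construct the Trainer with the callback) might look like the following sketch. The build_trainer helper and its arguments are hypothetical; the imports are kept inside the function so the module loads even when transformers or traceml_ai is not installed.

```python
def build_trainer(model, training_args, train_dataset):
    """Return a Hugging Face Trainer instrumented with TraceML.

    A sketch under stated assumptions: `model`, `training_args`, and
    `train_dataset` are whatever you would normally pass to Trainer.
    """
    # Local imports keep the optional dependencies out of module import time.
    from transformers import Trainer
    from traceml_ai.integrations.huggingface import (
        TraceMLTrainerCallback,
        init,
    )

    # Install the process-wide patches once, before the Trainer exists.
    init()

    return Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        callbacks=[TraceMLTrainerCallback()],
    )
```

Calling build_trainer(...).train() would then run training with each optimizer step bracketed by the callback.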

Source code in src/traceml_ai/integrations/huggingface.py
def init():
    """
    Initialize TraceML for Hugging Face ``Trainer`` runs.

    Call once before constructing the ``Trainer``, then register
    ``TraceMLTrainerCallback``. ``init()`` makes TraceML's process-wide
    instrumentation explicit: PyTorch ``DataLoader`` fetch timing, the H2D
    ``Tensor.to`` patch, and the forward/backward/optimizer auto-timers that
    ``trace_step`` arms inside each bracketed step.

    The callback is a per-step bracket and cannot install these process-wide
    patches on its own; the auto-timers it arms are no-ops unless the matching
    patch is installed. ``init()`` is the recommended entry point so the
    DataLoader fetch patch in particular is installed deterministically rather
    than relying on import order. This mirrors the PyTorch Lightning
    integration's ``init()``; HF uses ``mode="auto"`` because ``trace_step``
    drives forward/backward timing through the patch-gated auto-timers, whereas
    Lightning's callback owns that timing directly.
    """
    import traceml_ai as traceml

    return traceml.init(mode="auto")

traceml_ai.integrations.huggingface.TraceMLTrainerCallback

TraceMLTrainerCallback()

Bases: TrainerCallback if HAS_TRANSFORMERS else object

Preferred Hugging Face integration for TraceML.

Register with Trainer(..., callbacks=[TraceMLTrainerCallback()]).

The callback is a pure bracket around TraceML's trace_step context manager: it opens trace_step in on_step_begin and closes it in on_step_end. trace_step owns the step memory tracker, the step counter advance, the auto-timers for forward/backward/h2d, and the per-step flush. Nothing is duplicated here.

One TraceML step equals one optimizer step. With gradient_accumulation_steps > 1, forward and backward events from all accumulated micro-batches fold into a single TraceML step. See the HF integration docs for the full list of limitations vs. TraceMLTrainer.
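The bracket pattern can be illustrated without Transformers or TraceML. In this sketch, fake_trace_step is a stand-in for trace_step and BracketCallback is a simplified stand-in for the real callback; it only shows how a context manager opened in on_step_begin and closed in on_step_end ends up spanning every micro-batch of one optimizer step.

```python
from contextlib import contextmanager

events = []


@contextmanager
def fake_trace_step():
    """Stand-in for trace_step: records when a step opens and closes."""
    events.append("step_open")
    try:
        yield
    finally:
        events.append("step_close")


class BracketCallback:
    """Opens the context manager in on_step_begin, closes it in on_step_end."""

    def __init__(self):
        self._step_cm = None

    def on_step_begin(self):
        self._step_cm = fake_trace_step()
        self._step_cm.__enter__()

    def on_step_end(self):
        if self._step_cm is not None:
            self._step_cm.__exit__(None, None, None)
            self._step_cm = None


def run_optimizer_step(callback, micro_batches):
    """Simulate one optimizer step made of several accumulated micro-batches."""
    callback.on_step_begin()
    for i in range(micro_batches):
        events.append(f"forward_backward_{i}")
    callback.on_step_end()


run_optimizer_step(BracketCallback(), micro_batches=2)
print(events)
```

Both micro-batches land between a single step_open and step_close pair, which is why accumulated forward and backward events fold into one TraceML step.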

Source code in src/traceml_ai/integrations/huggingface.py
def __init__(self) -> None:
    if not HAS_TRANSFORMERS:
        raise ImportError(
            "TraceMLTrainerCallback requires 'transformers' to be "
            "installed. Please run `pip install transformers`."
        )
    super().__init__()
    self._step_cm = None

traceml_ai.integrations.huggingface.TraceMLTrainer

TraceMLTrainer(*args, traceml_enabled: bool = True, **kwargs)

Bases: Trainer if HAS_TRANSFORMERS else object

Thin wrapper around transformers.Trainer that auto-installs TraceMLTrainerCallback.

Kept for backward compatibility with users on the original TraceML HF integration API. New code should prefer Trainer(..., callbacks=[TraceMLTrainerCallback()]) directly.

Source code in src/traceml_ai/integrations/huggingface.py
def __init__(
    self,
    *args,
    traceml_enabled: bool = True,
    **kwargs,
):
    if not HAS_TRANSFORMERS:
        raise ImportError(
            "TraceMLTrainer requires 'transformers' to be installed. "
            "Please run `pip install transformers`."
        )

    super().__init__(*args, **kwargs)
    self.traceml_enabled = traceml_enabled

    if not traceml_enabled or _traceml_disabled():
        return

    # Dedup guard: a user passing callbacks=[TraceMLTrainerCallback()] to
    # TraceMLTrainer would otherwise double-instrument every step.
    existing = getattr(self.callback_handler, "callbacks", [])
    if any(isinstance(cb, TraceMLTrainerCallback) for cb in existing):
        return

    self.add_callback(TraceMLTrainerCallback())

PyTorch Lightning integration

traceml_ai.integrations.lightning.init

init()

Initialize TraceML for PyTorch Lightning runs.

Lightning owns the training loop, so TraceMLCallback owns step boundaries, flushing, and framework hook integration. The integration init enables DataLoader fetch timing plus the H2D Tensor.to patch. The callback turns H2D timing on only around Lightning's batch transfer hooks and wraps LightningModule.forward directly for model-forward timing.

Source code in src/traceml_ai/integrations/lightning.py
def init():
    """
    Initialize TraceML for PyTorch Lightning runs.

    Lightning owns the training loop, so TraceMLCallback owns step boundaries,
    flushing, and framework hook integration. The integration init enables
    DataLoader fetch timing plus the H2D Tensor.to patch. The callback turns H2D
    timing on only around Lightning's batch transfer hooks and wraps
    LightningModule.forward directly for model-forward timing.
    """
    import traceml_ai as traceml

    return traceml.init(
        mode="selective",
        patch_dataloader=True,
        patch_h2d=True,
    )

traceml_ai.integrations.lightning.TraceMLCallback

TraceMLCallback()

Bases: Callback

Official TraceML Callback for PyTorch Lightning.

Captures full step time (forward + backward + optimizer) as well as individual phase timings. Handles gradient accumulation by treating each micro-batch as a step and emitting zero-duration optimizer events on accumulating steps, so that dashboard steps stay aligned.
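As a rough illustration of that alignment, the sketch below lists the optimizer duration each micro-batch step would report. It is not the callback's implementation: optimizer_event_durations is a hypothetical helper, and it assumes the optimizer runs on every accumulate-th micro-batch and ignores end-of-epoch edge cases.

```python
def optimizer_event_durations(num_micro_batches, accumulate, real_duration_ms):
    """Return one optimizer duration per micro-batch step.

    Accumulating steps get 0.0; the step that actually runs the optimizer
    gets the measured duration.
    """
    durations = []
    for i in range(num_micro_batches):
        optimizer_ran = (i + 1) % accumulate == 0
        durations.append(real_duration_ms if optimizer_ran else 0.0)
    return durations


durations = optimizer_event_durations(
    num_micro_batches=4, accumulate=2, real_duration_ms=3.5
)
print(durations)  # [0.0, 3.5, 0.0, 3.5]
```

Every step carries an optimizer event, so per-step series have the same length regardless of the accumulation factor. This differs from the Hugging Face callback, where one TraceML step equals one optimizer step.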

Source code in src/traceml_ai/integrations/lightning.py
def __init__(self):
    if not IS_LIGHTNING_AVAILABLE:
        raise ImportError(
            "Install traceml[lightning] to use Lightning integration"
        )
    super().__init__()
    self._traceml_step_ctx = None
    self._backward_ctx = None
    self._optimizer_ctx = None
    self._batch_to_device_strategy = None
    self._original_batch_to_device = None
    self._forward_module = None
    self._original_forward = None
    self._original_forward_attr = _MISSING
    self._wrapped_forward = None

    self._mem_tracker = None
    self._opt_step_occurred = False

Ray Train integration

traceml_ai.integrations.ray.TraceMLTorchTrainer

TraceMLTorchTrainer(train_loop_per_worker: TrainLoop, *, train_loop_config: Optional[Dict[str, Any]] = None, traceml_config: Optional[TraceMLRayConfig] = None, **torch_trainer_kwargs: Any)

TraceML wrapper for ray.train.torch.TorchTrainer.

The wrapper deliberately uses composition instead of subclassing. Ray's TorchTrainer still owns training orchestration; TraceML only adds:

  • one aggregator actor before fit()
  • one runtime wrapper inside each worker
  • best-effort aggregator shutdown after fit() completes or fails

Source code in src/traceml_ai/integrations/ray.py
def __init__(
    self,
    train_loop_per_worker: TrainLoop,
    *,
    train_loop_config: Optional[Dict[str, Any]] = None,
    traceml_config: Optional[TraceMLRayConfig] = None,
    **torch_trainer_kwargs: Any,
) -> None:
    self._train_loop_per_worker = train_loop_per_worker
    self._train_loop_config = dict(train_loop_config or {})
    self._traceml_config = traceml_config or TraceMLRayConfig()
    self._torch_trainer_kwargs = dict(torch_trainer_kwargs)
    self._last_endpoint: Optional[AggregatorEndpoint] = None

last_endpoint property

last_endpoint: Optional[AggregatorEndpoint]

Aggregator endpoint used by the most recent fit() call.

fit

fit() -> Any

Run Ray Train with TraceML telemetry enabled.

Source code in src/traceml_ai/integrations/ray.py
def fit(self) -> Any:
    """Run Ray Train with TraceML telemetry enabled."""
    ray, TorchTrainer = _require_ray()
    config = _normalize_config(self._traceml_config)

    Actor = ray.remote(num_cpus=1)(_TraceMLAggregatorActor)
    actor = Actor.remote(config)

    try:
        endpoint = _endpoint_from_mapping(ray.get(actor.endpoint.remote()))
        self._last_endpoint = endpoint
        wrapped_loop = _TraceMLWorkerLoop(
            train_loop_per_worker=self._train_loop_per_worker,
            endpoint=endpoint,
            config=config,
        )
        trainer = TorchTrainer(
            wrapped_loop,
            train_loop_config=dict(self._train_loop_config),
            **self._torch_trainer_kwargs,
        )
        return trainer.fit()
    finally:
        _stop_actor_best_effort(ray, actor)
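
Based on the constructor signature above, wiring the wrapper into a Ray Train job might look like this sketch. The build_traceml_trainer helper, the num_workers default, and the "lr" entry in train_loop_config are hypothetical; scaling_config is passed through **torch_trainer_kwargs to Ray's TorchTrainer. Imports are local so the module loads without Ray installed.

```python
def build_traceml_trainer(train_loop_per_worker, num_workers=2):
    """Return a TraceMLTorchTrainer ready for .fit().

    A sketch under stated assumptions: `train_loop_per_worker` is your usual
    Ray Train loop function.
    """
    # Local imports keep the optional dependencies out of module import time.
    from ray.train import ScalingConfig
    from traceml_ai.integrations.ray import (
        TraceMLRayConfig,
        TraceMLTorchTrainer,
    )

    # These values match the documented defaults; spelled out for clarity.
    config = TraceMLRayConfig(mode="summary", logs_dir="./logs", port=0)

    return TraceMLTorchTrainer(
        train_loop_per_worker,
        train_loop_config={"lr": 1e-3},  # hypothetical loop config
        traceml_config=config,
        # Forwarded unchanged to ray.train.torch.TorchTrainer.
        scaling_config=ScalingConfig(num_workers=num_workers),
    )
```

After build_traceml_trainer(loop).fit() returns, the last_endpoint property reports the aggregator endpoint that run used.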

traceml_ai.integrations.ray.TraceMLRayConfig dataclass

TraceMLRayConfig(mode: str = 'summary', profile: str = 'run', init_mode: str = 'auto', patch_dataloader: Optional[bool] = None, patch_forward: Optional[bool] = None, patch_backward: Optional[bool] = None, patch_h2d: Optional[bool] = None, logs_dir: str = './logs', session_id: str = '', sampler_interval_sec: float = 1.0, summary_window_rows: int = DEFAULT_SUMMARY_WINDOW_ROWS, bind_host: str = '0.0.0.0', port: int = 0, stop_timeout_sec: float = 5.0)

TraceML settings used by the Ray Train integration.

Parameters:

  • mode (str, default 'summary'): TraceML display/reporting mode. "summary" is the default for Ray because distributed worker logs are often noisy.
  • profile (str, default 'run'): TraceML sampler profile. Public Ray integration uses the normal "run" profile.
  • init_mode (str, default 'auto'): Instrumentation mode passed to traceml.init() inside each worker.
  • patch_dataloader (Optional[bool], default None): Selective-mode-only override for DataLoader fetch patching.
  • patch_forward (Optional[bool], default None): Selective-mode-only override for forward-pass patching.
  • patch_backward (Optional[bool], default None): Selective-mode-only override for backward-pass patching.
  • patch_h2d (Optional[bool], default None): Selective-mode-only override for host-to-device transfer patching.
  • logs_dir (str, default './logs'): Directory where TraceML writes session logs and summary artifacts.
  • session_id (str, default ''): Optional explicit TraceML session id. If omitted, a unique Ray session id is generated for each fit() call.
  • sampler_interval_sec (float, default 1.0): Background sampler cadence in seconds.
  • summary_window_rows (int, default DEFAULT_SUMMARY_WINDOW_ROWS): Number of recent history rows used by final summary generation.
  • bind_host (str, default '0.0.0.0'): Host interface used by the aggregator actor. Use "0.0.0.0" for multi-node Ray clusters so workers on other nodes can connect.
  • port (int, default 0): Aggregator TCP port. 0 asks the OS to pick a free port.
  • stop_timeout_sec (float, default 5.0): Best-effort timeout for aggregator shutdown.
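
The port default of 0 relies on standard socket behavior rather than anything TraceML-specific: binding to port 0 lets the operating system assign an unused port, which can then be read back from the socket. A minimal standard-library sketch (pick_free_port is a hypothetical helper, not part of TraceML):

```python
import socket


def pick_free_port(host="127.0.0.1"):
    """Bind to port 0 and return the port the OS assigned."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))  # 0 means "any free port"
        return sock.getsockname()[1]


port = pick_free_port()
print(port)
```

Because the assigned port is not known in advance, the endpoint actually used has to be read back after startup, which is what last_endpoint exposes after fit().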

CLI

TraceML ships with a CLI entry point installed as traceml.

traceml run <script>                   # default: final summary JSON/text
traceml run <script> --mode=cli        # live terminal view
traceml run <script> --mode=dashboard  # live browser view
traceml watch <script>                 # zero-code system/process view

Live cli and dashboard modes are intended for single-node runs. For multi-node runs, use the default summary mode. Dashboard mode requires the optional dashboard extra: pip install "traceml-ai[dashboard]".

TraceML no longer ships layer-level/deep profiling. Use PyTorch Profiler, Nsight, or another operator-level profiler when you need that detail.

See traceml --help for the full set of options.

Summary APIs

traceml.summary()

Returns a compact flat dict for experiment trackers such as W&B, MLflow, or internal dashboards. Call it near the end of training; it reuses the canonical final_summary.json if one already exists.

import traceml_ai as traceml
import wandb

summary = traceml.summary(print_text=True)
if summary is not None:  # summary() may return None
    wandb.log(summary)
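
Some trackers accept only numeric metric values, so it can help to filter and namespace the flat dict before logging it. The sketch below is one way to do that; to_metrics, the "traceml/" prefix, and the keys in the example dict are all hypothetical and do not reflect the actual fields that traceml.summary() returns.

```python
from numbers import Real


def to_metrics(summary, prefix="traceml/"):
    """Keep numeric entries of a flat summary dict and prefix their keys."""
    if summary is None:
        return {}
    return {
        f"{prefix}{key}": float(value)
        for key, value in summary.items()
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(value, Real) and not isinstance(value, bool)
    }


# Hypothetical summary contents, for illustration only.
example = {"step_time_ms": 12.5, "steps": 100, "status": "ok", "gpu": True}
metrics = to_metrics(example)
print(metrics)
```

The resulting dict can be passed to whatever logging call your tracker provides.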

traceml.final_summary()

Returns the full final_summary.json payload. Use this when you need the complete structured report or want to store the artifact for traceml compare. TraceML generates this canonical artifact once per run and reuses it on later calls.