API Reference

Understand and debug torch.compile

Warning

It is recommended to read A Walk Through Example of torch.compile to gain a basic understanding of how torch.compile works before using the following functions.

depyf.prepare_debug(dump_src_dir, clean_wild_fx_code=True, log_bytecode=False)

A context manager to dump debugging information for torch.compile. It should wrap the code that actually triggers the compilation, rather than the code that applies torch.compile.

Example:

import torch

@torch.compile
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def main():
    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
    # main()
    # surround the code you want to run inside `with depyf.prepare_debug`
    import depyf
    with depyf.prepare_debug("./dump_src_dir"):
        main()

After running the code, you will find the dumped information in the directory dump_src_dir. The details are organized as follows:

  • full_code_for_xxx.py for each function using torch.compile

  • __transformed_code_for_xxx.py for Python code associated with each graph.

  • __transformed_code_for_xxx.py.xxx_bytecode for the Python bytecode: the dumped code object, which can be loaded via dill.load(open("/path/to/file", "rb")). Note that loading might import some modules, such as transformers; make sure these modules are installed.

  • __compiled_fn_xxx.py for each computation graph and its optimization:
    • Captured Graph: a plain forward computation graph

    • Joint Graph: joint forward-backward graph from AOTAutograd

    • Forward Graph: forward graph from AOTAutograd

    • Backward Graph: backward graph from AOTAutograd

    • kernel xxx: compiled CPU/GPU kernel wrapper from Inductor.
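Because the files follow the naming scheme listed above, a dump directory can be sorted mechanically. The following is a minimal sketch, not part of depyf: summarize_dump_dir is a hypothetical helper, and it assumes that the xxx placeholders stand for arbitrary suffixes and that all files sit directly inside dump_src_dir.

```python
from collections import defaultdict
from pathlib import Path


def summarize_dump_dir(dump_src_dir):
    """Hypothetical helper: group the files of a depyf dump directory by kind.

    Assumes the prefixes described in the list above; subdirectories are ignored.
    """
    groups = defaultdict(list)
    for path in sorted(Path(dump_src_dir).iterdir()):
        if not path.is_file():
            continue
        name = path.name
        if name.startswith("full_code_for_"):
            groups["full_code"].append(name)
        # Bytecode dumps share the __transformed_code_ prefix, so they must be
        # checked before the plain transformed-code files.
        elif name.startswith("__transformed_code_") and name.endswith("_bytecode"):
            groups["bytecode"].append(name)
        elif name.startswith("__transformed_code_"):
            groups["transformed_code"].append(name)
        elif name.startswith("__compiled_fn_"):
            groups["compiled_graph"].append(name)
        else:
            groups["other"].append(name)
    return dict(groups)
```

For instance, calling summarize_dump_dir("./dump_src_dir") after running the example above gives a quick overview of how many graphs and transformed functions were produced.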

Arguments:

  • dump_src_dir: the directory to dump the source code.

  • clean_wild_fx_code: whether to clean up wild FX code, i.e., FX code that is not recognized as part of any compiled function. Such code is usually used internally by PyTorch.

  • log_bytecode: whether to log bytecode (original bytecode, transformed bytecode from Dynamo, and decompiled_and_compiled_back_code).

depyf.debug()

A context manager to debug the compiled code. Essentially, it sets a breakpoint to pause the program, which lets you inspect the full source code in the files prefixed with full_code_for_ inside the dump_src_dir passed to depyf.prepare_debug(), and set breakpoints in the corresponding __transformed_code_ files according to the function name. Then continue debugging.
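While the program is paused, the file to open for a given function can be located by name. This is a hypothetical helper, not part of depyf, and it assumes the transformed-code files are named __transformed_code_<anything>for_<function name>.py, consistent with the __transformed_code_for_xxx.py pattern described for depyf.prepare_debug(). The bytecode dumps, which end in _bytecode rather than .py, are skipped.

```python
from pathlib import Path


def transformed_code_files(dump_src_dir, func_name):
    """Hypothetical helper: list the __transformed_code_ files for one function.

    Assumes names of the form __transformed_code_<anything>for_<func_name>.py.
    """
    # The "*" also matches an empty string, so both
    # __transformed_code_for_f.py and __transformed_code_0_for_f.py match.
    pattern = f"__transformed_code_*for_{func_name}.py"
    return sorted(p.name for p in Path(dump_src_dir).glob(pattern))
```

For the toy example above, transformed_code_files("./dump_src_dir", "toy_example") would list the candidate files in which to set breakpoints.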

Decompile general Python Bytecode/Function

depyf.decompile(code: CodeType | Callable) -> str

Decompile any callable or code object into Python source code. It is especially useful for dynamically generated code, such as code generated by torch.compile or by dataclasses.

Example usage:

from dataclasses import dataclass
@dataclass
class Data:
    x: int
    y: float

import depyf
print(depyf.decompile(Data.__init__))
print(depyf.decompile(Data.__eq__))

Output:

def __init__(self, x, y):
    self.x = x
    self.y = y
    return None

def __eq__(self, other):
    if other.__class__ is self.__class__:
        return (self.x, self.y) == (other.x, other.y)
    return NotImplemented

The output source code is semantically equivalent to the function, but not syntactically identical. It makes explicit many details that are implicit in the Python source. For example, the output code of __init__ above explicitly returns None, which is usually omitted.
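The implicit return None is not an invention of the decompiler: it is present in the bytecode that the decompiler works from. This standard-library sketch (plain_init and return_opnames are illustrative names) uses dis to find the return instruction that CPython emits for a function without a return statement. The exact opcode varies by Python version, for example RETURN_VALUE or, in some newer versions, RETURN_CONST.

```python
import dis


def plain_init(self, x, y):
    # No return statement, like the source of a dataclass-style __init__.
    self.x = x
    self.y = y


def return_opnames(func):
    """Return the names of the instructions in func that return a value."""
    return [
        ins.opname
        for ins in dis.get_instructions(func)
        if ins.opname.startswith("RETURN")
    ]
```

Running dis.dis(plain_init) prints the full listing, including the return instruction at the end.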

Another detail is that the output code of __eq__ returns NotImplemented instead of raising a NotImplementedError exception when the types are different. At first glance, this may look like a bug, but it is actually the correct behavior: __eq__ should return NotImplemented when the types are different, so that the other object can try to perform the comparison. See the Python documentation for more details.
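The fallback can be observed without depyf. In this sketch, Meters and Centimeters are illustrative classes: Meters.__eq__ follows the same pattern as the decompiled dataclass method, and when it returns NotImplemented for a Centimeters operand, Python retries the comparison with Centimeters.__eq__. Had it raised an exception instead, the comparison would fail outright.

```python
class Meters:
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if other.__class__ is self.__class__:
            return self.value == other.value
        # Signal "I don't know how to compare" so Python can try the other side.
        return NotImplemented


class Centimeters:
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if isinstance(other, Meters):
            return self.value == other.value * 100
        return NotImplemented


# Meters.__eq__ returns NotImplemented, then Centimeters.__eq__ answers True.
print(Meters(1) == Centimeters(100))
# Both sides return NotImplemented, so Python falls back to identity: False.
print(Meters(1) == "1")
```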

Enhance PyTorch Logging

depyf.install()

Install the bytecode hook for PyTorch, integrating it into PyTorch’s logging system.

Example:

import torch
import depyf
depyf.install()
# anything with torch.compile
@torch.compile
def f(a, b):
    return a + b
f(torch.tensor(1), torch.tensor(2))

Turn on bytecode logging with export TORCH_LOGS="+bytecode" and execute the script. The decompiled source code then appears in the log:

ORIGINAL BYTECODE f test.py line 5 
7           0 LOAD_FAST                0 (a)
            2 LOAD_FAST                1 (b)
            4 BINARY_ADD
            6 RETURN_VALUE


MODIFIED BYTECODE f test.py line 5 
5           0 LOAD_GLOBAL              0 (__compiled_fn_1)
            2 LOAD_FAST                0 (a)
            4 LOAD_FAST                1 (b)
            6 CALL_FUNCTION            2
            8 UNPACK_SEQUENCE          1
           10 RETURN_VALUE


possible source code:
def f(a, b):
    __temp_2, = __compiled_fn_1(a, b)
    return __temp_2
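The two halves of the log can be approximated with plain Python. The sketch below disassembles the original f with the standard dis module and mirrors the "possible source code" with a stand-in for the compiled graph. __compiled_fn_1 here is a hypothetical placeholder that only adds its inputs; the real one is generated by Dynamo and runs the captured graph. As the UNPACK_SEQUENCE 1 in the log suggests, the stand-in returns its single output inside a tuple. Note that the opcode names in the log (BINARY_ADD, CALL_FUNCTION) come from Python 3.10 or earlier; Python 3.11 and later use BINARY_OP and CALL, so the disassembly may differ.

```python
import dis


def f(a, b):
    return a + b


# Corresponds to the ORIGINAL BYTECODE part of the log (opcode names vary by version).
dis.dis(f)


def __compiled_fn_1(a, b):
    # Hypothetical stand-in for the Dynamo-compiled graph: outputs come back as a sequence.
    return (a + b,)


def f_transformed(a, b):
    # Mirrors the decompiled MODIFIED BYTECODE: call the graph, unpack one output.
    __temp_2, = __compiled_fn_1(a, b)
    return __temp_2
```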

If you find that the decompiled code is wrong, please submit an issue at https://github.com/thuml/depyf/issues.

To uninstall the hook, use depyf.uninstall().

depyf.uninstall()

Uninstall the bytecode hook for PyTorch. It should be called after depyf.install().
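Because every depyf.install() call should be matched by depyf.uninstall(), the pair can be wrapped in a context manager so that the hook is removed even when the body raises. hook_installed is a hypothetical helper built on contextlib, not a depyf API; it takes the two functions as arguments so that it can be exercised without PyTorch.

```python
from contextlib import contextmanager


@contextmanager
def hook_installed(install, uninstall):
    """Hypothetical helper: run install() on entry and uninstall() on exit.

    Intended for depyf.install / depyf.uninstall. If install() itself fails,
    uninstall() is not called, since there is nothing to remove.
    """
    install()
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises.
        uninstall()
```

With depyf available, the earlier example would read with hook_installed(depyf.install, depyf.uninstall): f(torch.tensor(1), torch.tensor(2)).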