Skip to content

An fd.add_output with alias fails in iter_types_match #5431

@crcrpar

Description

@crcrpar

The full log is available at https://gist.github.com/crcrpar/e5ac4212e48e5b8846653d34e5c0857e. The rank 1 traceback is excerpted below.

[rank1]:[rank1]: Traceback (most recent call last):
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/./run_simple_fsdp.py", line 92, in <module>
[rank1]:[rank1]:     main()
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/./run_simple_fsdp.py", line 78, in main
[rank1]:[rank1]:     simple_fsdp_model(input_ids=input_ids, attention_mask=attention_mask)
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank1]:[rank1]:     return super().__call__(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]:     return forward_call(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 886, in compile_wrapper
[rank1]:[rank1]:     return fn(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]:     return forward_call(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 953, in wrapper
[rank1]:[rank1]:     @wraps(func)
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank1]:[rank1]:     return super().__call__(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]:     return forward_call(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1098, in _fn
[rank1]:[rank1]:     return fn(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 838, in call_wrapped
[rank1]:[rank1]:     return self._wrapped_call(self, *args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 413, in __call__
[rank1]:[rank1]:     raise e
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
[rank1]:[rank1]:     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]:     return forward_call(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "<eval_with_key>.7971", line 6, in forward
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]:     return forward_call(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/core/module.py", line 80, in forward
[rank1]:[rank1]:     res = self._forward_fn(*args, **kwargs)
[rank1]:[rank1]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 839, in wrapped
[rank1]:[rank1]:     return fn(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 889, in fn_
[rank1]:[rank1]:     result = cache_entry.computation_fn(*inps)
[rank1]:[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 794, in wrapped
[rank1]:[rank1]:     return fn(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 121, in decorate_context
[rank1]:[rank1]:     return func(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/executors/torchex.py", line 169, in no_autocast_fn
[rank1]:[rank1]:     return fn(*args, **kwargs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "thunder.computation_3", line 37, in computation
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 546, in __call__
[rank1]:[rank1]:     self.last_used = self.get_fd(self.to_descriptors(args))
[rank1]:[rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 624, in get_fd
[rank1]:[rank1]:     return create_fd(bsyms, input_descriptors, sorted_unique_inputs, sorted_unique_outputs)
[rank1]:[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 409, in create_fd
[rank1]:[rank1]:     definition(fd)
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 369, in definition
[rank1]:[rank1]:     translate_bound_symbol(bsym)
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 359, in translate_bound_symbol
[rank1]:[rank1]:     nvresults = translator(*bsym.args, **bsym.kwargs, fd=fd, lc_to_nv_map=lc_to_nv_map)
[rank1]:[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]:   File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 2246, in copy_
[rank1]:[rank1]:     fd.add_output(alias_output, alias_input=nvcopy_to)
[rank1]:[rank1]:   File "/opt/pytorch/nvfuser/python/nvfuser_direct/__init__.py", line 221, in add_output
[rank1]:[rank1]:     self._fusion.add_output(*args, **kwargs)
[rank1]:[rank1]: RuntimeError:  INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:280, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. 
[rank1]:[rank1]: Expected iter_types_match(id->getIterType(), new_id->getIterType()) . Axes bS4{1 ex 2} and iS8{2} do not match for self replay.
[rank1]:[rank1]: Exception raised from selfReplay at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:280 (most recent call first):
[rank1]:[rank1]: frame #0: nvfuser::nvfCheckFail(char const*, char const*, long, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x110 (0xfffc20c24db4 in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
[rank1]:[rank1]: frame #1: nvfuser::nvfErrorFail(char const*, char const*, long, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x88 (0xfffc20f1b668 in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
[rank1]:[rank1]: frame #2: <unknown function> + 0xf7ba50 (0xfffc2153ba50 in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
[rank1]:[rank1]: frame #3: nvfuser::Fusion::aliasOutputToInput(nvfuser::Val*, nvfuser::Val*, nvfuser::AllocationType) + 0x130 (0xfffc20f62bf4 in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
[rank1]:[rank1]: frame #4: <unknown function> + 0x118a98 (0xfffc236a8a98 in /opt/pytorch/nvfuser/python/nvfuser_direct/_C_DIRECT.cpython-312-aarch64-linux-gnu.so)
[rank1]:[rank1]: frame #5: <unknown function> + 0x3b5e4 (0xfffc235cb5e4 in /opt/pytorch/nvfuser/python/nvfuser_direct/_C_DIRECT.cpython-312-aarch64-linux-gnu.so)
[rank1]:[rank1]: frame #6: /usr/bin/python3() [0x503454]
[rank1]:[rank1]: frame #7: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #8: /usr/bin/python3() [0x4c7000]
[rank1]:[rank1]: frame #9: PyObject_Call + 0x118 (0x4c5278 in /usr/bin/python3)
[rank1]:[rank1]: frame #10: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #11: /usr/bin/python3() [0x5ef170]
[rank1]:[rank1]: frame #12: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #13: _PyEval_EvalFrameDefault + 0x8a0 (0x563824 in /usr/bin/python3)
[rank1]:[rank1]: frame #14: _PyObject_Call_Prepend + 0xc4 (0x4c47c4 in /usr/bin/python3)
[rank1]:[rank1]: frame #15: /usr/bin/python3() [0x528970]
[rank1]:[rank1]: frame #16: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #17: _PyEval_EvalFrameDefault + 0x8a0 (0x563824 in /usr/bin/python3)
[rank1]:[rank1]: frame #18: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #19: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #20: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #21: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #22: _PyObject_Call_Prepend + 0xc4 (0x4c47c4 in /usr/bin/python3)
[rank1]:[rank1]: frame #23: /usr/bin/python3() [0x528970]
[rank1]:[rank1]: frame #24: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #25: _PyEval_EvalFrameDefault + 0x8a0 (0x563824 in /usr/bin/python3)
[rank1]:[rank1]: frame #26: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #27: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #28: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #29: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #30: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #31: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #32: _PyObject_Call_Prepend + 0xc4 (0x4c47c4 in /usr/bin/python3)
[rank1]:[rank1]: frame #33: /usr/bin/python3() [0x528970]
[rank1]:[rank1]: frame #34: PyObject_Call + 0x6c (0x4c51cc in /usr/bin/python3)
[rank1]:[rank1]: frame #35: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #36: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #37: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #38: <unknown function> + 0xb21058 (0xffffbce21058 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #39: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #40: <unknown function> + 0xb21058 (0xffffbce21058 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #41: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #42: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #43: <unknown function> + 0xb21058 (0xffffbce21058 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #44: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #45: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #46: <unknown function> + 0xb20fbc (0xffffbce20fbc in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #47: _PyObject_Call_Prepend + 0xc4 (0x4c47c4 in /usr/bin/python3)
[rank1]:[rank1]: frame #48: /usr/bin/python3() [0x528970]
[rank1]:[rank1]: frame #49: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #50: _PyEval_EvalFrameDefault + 0x8a0 (0x563824 in /usr/bin/python3)
[rank1]:[rank1]: frame #51: dynamo_eval_custom_code + 0x21c (0xffffbce1f60c in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #52: <unknown function> + 0xb20dfc (0xffffbce20dfc in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #53: /usr/bin/python3() [0x4c6ee8]
[rank1]:[rank1]: frame #54: PyObject_Call + 0x118 (0x4c5278 in /usr/bin/python3)
[rank1]:[rank1]: frame #55: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #56: <unknown function> + 0xb20fbc (0xffffbce20fbc in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #57: /usr/bin/python3() [0x4c6ee8]
[rank1]:[rank1]: frame #58: PyObject_Call + 0x118 (0x4c5278 in /usr/bin/python3)
[rank1]:[rank1]: frame #59: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #60: <unknown function> + 0xb20fbc (0xffffbce20fbc in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #61: /usr/bin/python3() [0x4c6ee8]
[rank1]:[rank1]: frame #62: PyObject_Call + 0x118 (0x4c5278 in /usr/bin/python3)
[rank1]:[rank1]: frame #63: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
import argparse
import os

import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel
from transformers import AutoConfig
from transformers import AutoModel
from thunder.dynamo.compiler import ThunderCompiler


# Rank of this process and total number of processes, read from the
# environment. Both variables must be set: a missing one makes int() fail
# on None. (Presumably exported by the distributed launcher -- confirm.)
LOCAL_RANK = int(os.environ.get("LOCAL_RANK"))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE"))
# Hugging Face model identifier whose config is used to instantiate the model.
MODEL_ID = "Qwen/Qwen3-14B"

# Command-line options; defaults are shown in --help via the formatter class.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--warmup-iters", default=10, type=int)
parser.add_argument("--iters", default=20, type=int)
parser.add_argument("--batch-size", "-B", default=2, type=int)
parser.add_argument("--seq-len", default=1024, type=int)
parser.add_argument(
    "--execution-mode",
    "--mode",
    default="torch_compile",
    choices=("eager", "torch_compile", "thunderfx"),
    type=str,
)
args = parser.parse_args()

# Module-level shorthands for the parsed options.
WARMUPS: int = args.warmup_iters
N_ITERS: int = args.iters
BATCH_SIZE, SEQ_LEN = args.batch_size, args.seq_len
# Shape of the input_ids / attention_mask tensors: (batch, sequence).
INPUT_SHAPE = (BATCH_SIZE, SEQ_LEN)


def _init_weights(m: nn.Module) -> None:
    """Initialize the parameters owned directly by ``m``.

    Modules that define ``reset_parameters`` are initialized by calling it.
    For any other module, each direct (non-recursive) parameter is filled in
    place: Kaiming-uniform when it has more than one dimension, zeros
    otherwise. Parameters of child modules are not touched here.
    """
    if hasattr(m, "reset_parameters"):
        m.reset_parameters()
        return

    for weight in m.parameters(recurse=False):
        # Pick the in-place initializer from the parameter's rank.
        fill_ = nn.init.kaiming_uniform_ if weight.ndim > 1 else nn.init.zeros_
        fill_(weight)


def main():
    """Build the model, wrap it for data parallelism, and run forward passes.

    The model is instantiated from the ``MODEL_ID`` config on the meta
    device, wrapped with ``data_parallel(..., mode="fully_shard")``,
    materialized on ``device`` with ``to_empty`` and initialized through
    ``_init_weights``. Depending on ``args.execution_mode`` it is then left
    as is, compiled with ``torch.compile``, or compiled with the
    ``ThunderCompiler`` backend. Finally ``WARMUPS`` and then ``N_ITERS``
    forward passes are executed on freshly generated inputs; the outputs are
    discarded.
    """
    hf_config = AutoConfig.from_pretrained(MODEL_ID)
    with torch.device("meta"):
        base_model = AutoModel.from_config(hf_config)

    simple_fsdp_model = data_parallel(base_model, device_mesh, mode="fully_shard")
    simple_fsdp_model = simple_fsdp_model.to_empty(device=device)
    simple_fsdp_model.apply(_init_weights)

    mode = args.execution_mode
    if mode == "torch_compile":
        simple_fsdp_model = torch.compile(simple_fsdp_model, fullgraph=True)
    elif mode == "thunderfx":
        backend = ThunderCompiler(enable_nv_linear=True, enable_nv_matmul=True, enable_nv_sdpa=True)
        simple_fsdp_model = torch.compile(simple_fsdp_model, fullgraph=True, backend=backend)
    # Any other mode (i.e. "eager") runs the model without compilation.

    # Only local rank 0 displays a progress bar.
    hide_progress = LOCAL_RANK != 0

    # First the warm-up phase, then the main phase; each gets its own bar.
    for n_steps in (WARMUPS, N_ITERS):
        for _ in tqdm.tqdm(range(n_steps), disable=hide_progress):
            input_ids = torch.randint(0, 1024 * 64, INPUT_SHAPE, dtype=torch.int64, device=device)
            attention_mask = torch.ones(INPUT_SHAPE, device=device, dtype=torch.long)
            simple_fsdp_model(input_ids=input_ids, attention_mask=attention_mask)


if __name__ == "__main__":
    # One-dimensional CUDA device mesh with WORLD_SIZE entries.
    device_mesh = init_device_mesh("cuda", (WORLD_SIZE,))

    # Make this rank's GPU and bfloat16 the defaults for new tensors.
    device = torch.device("cuda", LOCAL_RANK)
    torch.set_default_device(device)
    dtype = torch.bfloat16
    torch.set_default_dtype(dtype)

    # Exceptions raised by main() propagate unchanged; the process groups of
    # the mesh are destroyed whether or not main() succeeds.
    try:
        main()
    finally:
        for group in device_mesh.get_all_groups():
            torch.distributed.destroy_process_group(group)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions