-
Notifications
You must be signed in to change notification settings - Fork 69
Open
Description
The full log is available at https://gist.github.com/crcrpar/e5ac4212e48e5b8846653d34e5c0857e
[rank1]:[rank1]: Traceback (most recent call last):
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/./run_simple_fsdp.py", line 92, in <module>
[rank1]:[rank1]: main()
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/./run_simple_fsdp.py", line 78, in main
[rank1]:[rank1]: simple_fsdp_model(input_ids=input_ids, attention_mask=attention_mask)
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank1]:[rank1]: return super().__call__(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]: return self._call_impl(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]: return forward_call(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 886, in compile_wrapper
[rank1]:[rank1]: return fn(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]: return self._call_impl(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]: return forward_call(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 953, in wrapper
[rank1]:[rank1]: @wraps(func)
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank1]:[rank1]: return super().__call__(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]: return self._call_impl(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]: return forward_call(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1098, in _fn
[rank1]:[rank1]: return fn(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 838, in call_wrapped
[rank1]:[rank1]: return self._wrapped_call(self, *args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 413, in __call__
[rank1]:[rank1]: raise e
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
[rank1]:[rank1]: return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]: return self._call_impl(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]: return forward_call(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "<eval_with_key>.7971", line 6, in forward
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank1]:[rank1]: return self._call_impl(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank1]:[rank1]: return forward_call(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/core/module.py", line 80, in forward
[rank1]:[rank1]: res = self._forward_fn(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 839, in wrapped
[rank1]:[rank1]: return fn(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 889, in fn_
[rank1]:[rank1]: result = cache_entry.computation_fn(*inps)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 794, in wrapped
[rank1]:[rank1]: return fn(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 121, in decorate_context
[rank1]:[rank1]: return func(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/executors/torchex.py", line 169, in no_autocast_fn
[rank1]:[rank1]: return fn(*args, **kwargs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "thunder.computation_3", line 37, in computation
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 546, in __call__
[rank1]:[rank1]: self.last_used = self.get_fd(self.to_descriptors(args))
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 624, in get_fd
[rank1]:[rank1]: return create_fd(bsyms, input_descriptors, sorted_unique_inputs, sorted_unique_outputs)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 409, in create_fd
[rank1]:[rank1]: definition(fd)
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 369, in definition
[rank1]:[rank1]: translate_bound_symbol(bsym)
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 359, in translate_bound_symbol
[rank1]:[rank1]: nvresults = translator(*bsym.args, **bsym.kwargs, fd=fd, lc_to_nv_map=lc_to_nv_map)
[rank1]:[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[rank1]: File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 2246, in copy_
[rank1]:[rank1]: fd.add_output(alias_output, alias_input=nvcopy_to)
[rank1]:[rank1]: File "/opt/pytorch/nvfuser/python/nvfuser_direct/__init__.py", line 221, in add_output
[rank1]:[rank1]: self._fusion.add_output(*args, **kwargs)
[rank1]:[rank1]: RuntimeError: INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:280, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
[rank1]:[rank1]: Expected iter_types_match(id->getIterType(), new_id->getIterType()) . Axes bS4{1 ex 2} and iS8{2} do not match for self replay.
[rank1]:[rank1]: Exception raised from selfReplay at /opt/pytorch/nvfuser/csrc/transform_replay.cpp:280 (most recent call first):
[rank1]:[rank1]: frame #0: nvfuser::nvfCheckFail(char const*, char const*, long, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x110 (0xfffc20c24db4 in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
[rank1]:[rank1]: frame #1: nvfuser::nvfErrorFail(char const*, char const*, long, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x88 (0xfffc20f1b668 in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
[rank1]:[rank1]: frame #2: <unknown function> + 0xf7ba50 (0xfffc2153ba50 in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
[rank1]:[rank1]: frame #3: nvfuser::Fusion::aliasOutputToInput(nvfuser::Val*, nvfuser::Val*, nvfuser::AllocationType) + 0x130 (0xfffc20f62bf4 in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
[rank1]:[rank1]: frame #4: <unknown function> + 0x118a98 (0xfffc236a8a98 in /opt/pytorch/nvfuser/python/nvfuser_direct/_C_DIRECT.cpython-312-aarch64-linux-gnu.so)
[rank1]:[rank1]: frame #5: <unknown function> + 0x3b5e4 (0xfffc235cb5e4 in /opt/pytorch/nvfuser/python/nvfuser_direct/_C_DIRECT.cpython-312-aarch64-linux-gnu.so)
[rank1]:[rank1]: frame #6: /usr/bin/python3() [0x503454]
[rank1]:[rank1]: frame #7: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #8: /usr/bin/python3() [0x4c7000]
[rank1]:[rank1]: frame #9: PyObject_Call + 0x118 (0x4c5278 in /usr/bin/python3)
[rank1]:[rank1]: frame #10: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #11: /usr/bin/python3() [0x5ef170]
[rank1]:[rank1]: frame #12: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #13: _PyEval_EvalFrameDefault + 0x8a0 (0x563824 in /usr/bin/python3)
[rank1]:[rank1]: frame #14: _PyObject_Call_Prepend + 0xc4 (0x4c47c4 in /usr/bin/python3)
[rank1]:[rank1]: frame #15: /usr/bin/python3() [0x528970]
[rank1]:[rank1]: frame #16: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #17: _PyEval_EvalFrameDefault + 0x8a0 (0x563824 in /usr/bin/python3)
[rank1]:[rank1]: frame #18: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #19: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #20: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #21: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #22: _PyObject_Call_Prepend + 0xc4 (0x4c47c4 in /usr/bin/python3)
[rank1]:[rank1]: frame #23: /usr/bin/python3() [0x528970]
[rank1]:[rank1]: frame #24: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #25: _PyEval_EvalFrameDefault + 0x8a0 (0x563824 in /usr/bin/python3)
[rank1]:[rank1]: frame #26: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #27: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #28: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #29: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #30: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #31: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #32: _PyObject_Call_Prepend + 0xc4 (0x4c47c4 in /usr/bin/python3)
[rank1]:[rank1]: frame #33: /usr/bin/python3() [0x528970]
[rank1]:[rank1]: frame #34: PyObject_Call + 0x6c (0x4c51cc in /usr/bin/python3)
[rank1]:[rank1]: frame #35: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #36: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #37: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #38: <unknown function> + 0xb21058 (0xffffbce21058 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #39: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #40: <unknown function> + 0xb21058 (0xffffbce21058 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #41: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #42: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #43: <unknown function> + 0xb21058 (0xffffbce21058 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #44: /usr/bin/python3() [0x4c6e0c]
[rank1]:[rank1]: frame #45: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #46: <unknown function> + 0xb20fbc (0xffffbce20fbc in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #47: _PyObject_Call_Prepend + 0xc4 (0x4c47c4 in /usr/bin/python3)
[rank1]:[rank1]: frame #48: /usr/bin/python3() [0x528970]
[rank1]:[rank1]: frame #49: _PyObject_MakeTpCall + 0x7c (0x4c2d1c in /usr/bin/python3)
[rank1]:[rank1]: frame #50: _PyEval_EvalFrameDefault + 0x8a0 (0x563824 in /usr/bin/python3)
[rank1]:[rank1]: frame #51: dynamo_eval_custom_code + 0x21c (0xffffbce1f60c in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #52: <unknown function> + 0xb20dfc (0xffffbce20dfc in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #53: /usr/bin/python3() [0x4c6ee8]
[rank1]:[rank1]: frame #54: PyObject_Call + 0x118 (0x4c5278 in /usr/bin/python3)
[rank1]:[rank1]: frame #55: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #56: <unknown function> + 0xb20fbc (0xffffbce20fbc in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #57: /usr/bin/python3() [0x4c6ee8]
[rank1]:[rank1]: frame #58: PyObject_Call + 0x118 (0x4c5278 in /usr/bin/python3)
[rank1]:[rank1]: frame #59: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
[rank1]:[rank1]: frame #60: <unknown function> + 0xb20fbc (0xffffbce20fbc in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
[rank1]:[rank1]: frame #61: /usr/bin/python3() [0x4c6ee8]
[rank1]:[rank1]: frame #62: PyObject_Call + 0x118 (0x4c5278 in /usr/bin/python3)
[rank1]:[rank1]: frame #63: _PyEval_EvalFrameDefault + 0x3e50 (0x566dd4 in /usr/bin/python3)
- pjnl-20251025
- torchtitan: https://github.com/pytorch/torchtitan/tree/81a36c55424272b5d1358473ef2447d9c20a1c03
- command:
torchrun --standalone --role rank --tee 3 --nproc-per-node 4 ./run_simple_fsdp.py --warmup-iters 0 --iters 5 --mode thunderfx
import argparse
import os
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel
from transformers import AutoConfig
from transformers import AutoModel
from thunder.dynamo.compiler import ThunderCompiler
# Rank of this process and total number of processes, read from the
# environment (expected to be set by the launcher, e.g. torchrun).
LOCAL_RANK = int(os.getenv("LOCAL_RANK"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE"))

# Hugging Face model identifier whose config is used to build the model.
MODEL_ID = "Qwen/Qwen3-14B"

# Command-line interface; defaults are shown in --help via the formatter.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--warmup-iters", default=10, type=int)
parser.add_argument("--iters", default=20, type=int)
parser.add_argument("--batch-size", "-B", default=2, type=int)
parser.add_argument("--seq-len", default=1024, type=int)
parser.add_argument(
    "--execution-mode",
    "--mode",
    default="torch_compile",
    choices=("eager", "torch_compile", "thunderfx"),
    type=str,
)
args = parser.parse_args()

# Iteration counts and input geometry derived from the parsed arguments.
WARMUPS: int = args.warmup_iters
N_ITERS: int = args.iters
BATCH_SIZE = args.batch_size
SEQ_LEN = args.seq_len
# Shape shared by `input_ids` and `attention_mask`: (batch, sequence).
INPUT_SHAPE = (BATCH_SIZE, SEQ_LEN)
def _init_weights(m: nn.Module) -> None:
    """Initialise the parameters that ``m`` owns directly.

    Modules that define ``reset_parameters`` are initialised through it.
    For any other module, each direct parameter with more than one
    dimension gets Kaiming-uniform initialisation and every remaining
    parameter is filled with zeros.

    Only direct parameters are touched (``recurse=False``); the script
    passes this function to ``Module.apply``, which calls it on every
    submodule in turn.

    Args:
        m: The module whose own parameters are initialised in place.
    """
    if hasattr(m, "reset_parameters"):
        m.reset_parameters()
    else:
        for param in m.parameters(recurse=False):
            if param.ndim > 1:
                nn.init.kaiming_uniform_(param)
            else:
                nn.init.zeros_(param)
def main():
config = AutoConfig.from_pretrained(MODEL_ID)
with torch.device("meta"):
model = AutoModel.from_config(config)
simple_fsdp_model = data_parallel(model, device_mesh, mode="fully_shard")
simple_fsdp_model = simple_fsdp_model.to_empty(device=device)
simple_fsdp_model.apply(_init_weights)
match args.execution_mode:
case "torch_compile":
simple_fsdp_model = torch.compile(simple_fsdp_model, fullgraph=True)
case "thunderfx":
thunder_config = {"enable_nv_linear": True, "enable_nv_matmul": True, "enable_nv_sdpa": True}
backend = ThunderCompiler(**thunder_config)
simple_fsdp_model = torch.compile(simple_fsdp_model, fullgraph=True, backend=backend)
case _:
pass
for _ in tqdm.tqdm(range(WARMUPS), disable=LOCAL_RANK != 0):
input_ids = torch.randint(0, 1024 * 64, INPUT_SHAPE, dtype=torch.int64, device=device)
attention_mask = torch.ones(INPUT_SHAPE, device=device, dtype=torch.long)
simple_fsdp_model(input_ids=input_ids, attention_mask=attention_mask)
for _ in tqdm.tqdm(range(N_ITERS), disable=LOCAL_RANK != 0):
input_ids = torch.randint(0, 1024 * 64, INPUT_SHAPE, dtype=torch.int64, device=device)
attention_mask = torch.ones(INPUT_SHAPE, device=device, dtype=torch.long)
simple_fsdp_model(input_ids=input_ids, attention_mask=attention_mask)
if __name__ == "__main__":
    # One-dimensional device mesh spanning every process in the job.
    device_mesh = init_device_mesh(
        "cuda",
        (WORLD_SIZE,),
    )
    # `device` and `device_mesh` are read as globals inside main().
    device = torch.device("cuda", LOCAL_RANK)
    torch.set_default_device(device)
    dtype = torch.bfloat16
    torch.set_default_dtype(dtype)
    try:
        main()
    finally:
        # Tear down the process groups whether or not main() succeeded;
        # any exception from main() still propagates after this cleanup.
        for pg in device_mesh.get_all_groups():
            torch.distributed.destroy_process_group(pg)
Metadata
Assignees
Labels
No labels