
ModuleNotFoundError: No module named 'jax.experimental.maps' after main-branch update (#662) #673

Open
MikeMpapa opened this issue Jul 25, 2024 · 11 comments

MikeMpapa commented Jul 25, 2024

Hi - after yesterday's code update I am getting the following error. Any advice? I am using the JAX-Levanter Docker image.

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 119, in main
    Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 597, in round_axis_for_partitioning
    size = physical_axis_size(axis, mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 566, in physical_axis_size
    mesh = _get_mesh()
  File "/opt/haliax/src/haliax/partitioning.py", line 606, in _get_mesh
    from jax.experimental.maps import thread_resources
ModuleNotFoundError: No module named 'jax.experimental.maps'

EDIT: Actually, I now see this error even with the previous repo HEAD, which is weird because yesterday the container was working fine. Do you think something has changed in the NVIDIA image?

dlwh (Member) commented Jul 25, 2024

Yeah, looks like they moved/removed jax.experimental.maps. The container sticks close to JAX head, which we don't.
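
For reference, the failing call in haliax's _get_mesh only needs thread_resources. Here is a minimal sketch of the kind of import fallback involved, assuming the object now lives under jax._src.mesh in newer JAX (a private path, and not necessarily the exact fix haliax shipped):

```python
# Sketch only: fall back to a newer-JAX location for thread_resources.
# jax._src.mesh is a private module and an assumption here; the actual
# haliax fix may use a different path.
try:
    from jax.experimental.maps import thread_resources  # older JAX
except ModuleNotFoundError:
    from jax._src.mesh import thread_resources  # newer JAX (private API)

def _get_mesh():
    # Same access pattern as haliax.partitioning._get_mesh: return the
    # mesh installed by the surrounding `with mesh:` context manager.
    return thread_resources.env.physical_mesh
```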

dlwh (Member) commented Jul 25, 2024

Once I merge stanford-crfm/haliax#102 and allow a suitable interval for the package to propagate, you should be able to update to the latest dev version of haliax (probably 308 or 309).

dlwh (Member) commented Jul 25, 2024

Try pip install haliax==1.4.dev310 and see if it fixes the issue.
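
To confirm the container actually picked up the dev build, here's a quick standard-library check (just illustrative):

```python
# Verify the installed haliax version using only the standard library.
import importlib.metadata

print(importlib.metadata.version("haliax"))  # expect 1.4.dev310 or later
```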

MikeMpapa (comment marked as outdated)

MikeMpapa (Author) commented Jul 25, 2024

Please ignore and let me retest - I wasn't on Levanter HEAD, so that might be it. Will follow up.

MikeMpapa (Author):

Yeah, same output unfortunately. Is there a way I can access older JAX containers?

INFO:levanter.distributed:Not initializing jax.distributed because no distributed config was provided, and no cluster was detected.
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:levanter.trainer:Setting run id to oz48brcz
2024-07-25T21:34:46 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /levanter
2024-07-25T21:34:46 - 0 - wandb.sdk.lib.gitlib - gitlib.py:92 - ERROR :: git root error: Cmd('git') failed due to: exit code(128)
  cmdline: git rev-parse --show-toplevel
  stderr: 'fatal: detected dubious ownership in repository at '/levanter'
To add an exception for this directory, call:

        git config --global --add safe.directory /levanter'
2024-07-25T21:34:46 - 0 - wandb.sdk.lib.gitlib - gitlib.py:92 - ERROR :: git root error: Cmd('git') failed due to: exit code(128)
  cmdline: git rev-parse --show-toplevel
  stderr: 'fatal: detected dubious ownership in repository at '/levanter'
To add an exception for this directory, call:

        git config --global --add safe.directory /levanter'
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: WARNING `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id oz48brcz.
wandb: Tracking run with wandb version 0.17.5
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
2024-07-25T21:34:49 - 0 - levanter.distributed - distributed.py:215 - INFO :: No auto-discovered ray address found. Using ray.init('local').
2024-07-25T21:34:49 - 0 - levanter.distributed - distributed.py:267 - INFO :: ray.init(address='local', namespace='levanter', **{})
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
2024-07-25 21:34:51,416 INFO worker.py:1779 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
2024-07-25T21:34:52 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /levanter
2024-07-25T21:34:52 - 0 - levanter.tracker.wandb - wandb.py:200 - INFO :: Setting wandb code_dir to /levanter
train:   0%|          | 0/50 [00:00<?, ?it/s]
2024-07-25T21:34:53 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/validation
2024-07-25T21:34:53 - 0 - levanter.data.text - text.py:692 - INFO :: Building cache for validation...
2024-07-25T21:34:53 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/validation
(ChunkCacheBuilder pid=505) 2024-07-25 21:34:59,609 - levanter.data.shard_cache.builder::cache/validation - INFO - Starting cache build for 1 shards
2024-07-25T21:35:06 - 0 - levanter.data.text - text.py:256 - INFO :: Cache /cache/validation is complete.
2024-07-25T21:35:06 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/train
2024-07-25T21:35:06 - 0 - levanter.data.text - text.py:692 - INFO :: Building cache for train...
2024-07-25T21:35:06 - 0 - levanter.data.shard_cache - shard_cache.py:1266 - INFO :: Loading cache from /cache/train
2024-07-25T21:35:06 - 0 - preprocessing..validation - metrics_monitor.py:143 - INFO ::  done: Shards: 0 | Chunks: 1 | Docs: 28
2024-07-25T21:35:06 - 0 - preprocessing..validation - metrics_monitor.py:143 - INFO ::  done: Shards: 1 | Chunks: 1 | Docs: 28
2024-07-25T21:35:06 - 0 - preprocessing..validation - metrics_monitor.py:150 - INFO :: Cache creation finished
(ChunkCacheBroker pid=460) 2024-07-25 21:35:06,723 - levanter.data.shard_cache - INFO - Finalizing cache /cache/validation...
(ChunkCacheBuilder pid=505) 2024-07-25 21:35:06,718 - levanter.data.shard_cache - INFO - Shard valid_txt finished
2024-07-25T21:35:10 - 0 - levanter.data.text - text.py:258 - INFO :: Cache /cache/train is incomplete. This will block until at least one chunk per process is complete.
(ChunkCacheBuilder pid=699) 2024-07-25 21:35:13,905 - levanter.data.shard_cache.builder::cache/train - INFO - Starting cache build for 1 shards
2024-07-25T21:35:14 - 0 - __main__ - train_lm.py:129 - INFO :: No training checkpoint found. Initializing model from HF checkpoint 'stanford-crfm/music-medium-800k'
config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.96k/1.96k [00:00<00:00, 8.35MB/s]
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:143 - INFO ::  done: Shards: 0 | Chunks: 1 | Docs: 279
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:143 - INFO ::  done: Shards: 1 | Chunks: 1 | Docs: 279
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:143 - INFO ::  done: Shards: 1 | Chunks: 1 | Docs: 279
2024-07-25T21:35:21 - 0 - preprocessing..train - metrics_monitor.py:150 - INFO :: Cache creation finished
(ChunkCacheBroker pid=664) 2024-07-25 21:35:21,188 - levanter.data.shard_cache - INFO - Finalizing cache /cache/train...
(ChunkCacheBuilder pid=699) 2024-07-25 21:35:21,183 - levanter.data.shard_cache - INFO - Shard train_txt finished
pytorch_model.bin: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.44G/1.44G [00:38<00:00, 37.6MB/s]
Loading weights: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 293/293 [00:01<00:00, 189.18it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 215, in main
    trainer.train(state, train_loader)
  File "/levanter/src/levanter/trainer.py", line 403, in train
    for info in self.training_steps(state, train_loader, run_hooks=run_hooks):
  File "/levanter/src/levanter/trainer.py", line 386, in training_steps
    info = self.train_step(state, example)
  File "/levanter/src/levanter/trainer.py", line 370, in train_step
    loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 261, in __call__
    return self._call(False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 315, in _call
    output_shape = _cached_filter_eval_shape(self._fn, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/partitioning.py", line 546, in _cached_filter_eval_shape
    _eval_shape_cache[static] = eqx.filter_eval_shape(fun, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/equinox/_eval_shape.py", line 38, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
  File "/levanter/src/levanter/trainer.py", line 498, in _train_step
    loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
  File "/levanter/src/levanter/trainer.py", line 515, in _compute_gradients_microbatched
    return grad_fn(model, *batch, **batch_kwargs)
  File "/levanter/src/levanter/grad_accum.py", line 92, in wrapped_fn
    r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
  File "/levanter/src/levanter/trainer.py", line 191, in fn
    return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs))
  File "/levanter/src/levanter/types.py", line 75, in __call__
    return model.compute_loss(*inputs, reduction=reduction, reduction_axis=reduction_axis, **kwargs)
  File "/levanter/src/levanter/models/lm_model.py", line 129, in compute_loss
    logits = self(example.tokens, example.attn_mask, key=key)
  File "/levanter/src/levanter/models/gpt2.py", line 399, in __call__
    x = self.transformer(x, attn_mask, key=k_transformer)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 298, in __call__
    x = self.blocks.fold(x, attn_mask, hax.arange(self.config.Layers), key=keys)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 221, in fold
    return haliax.fold(do_block, self.Block)(init, self.stacked, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 202, in scanned_f
    return scan_preconfig(init, *args, **kwargs)[0]
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 134, in scanned_f
    carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 127, in wrapped_fn
    carry, y = f(carry, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/hof.py", line 197, in scan_compatible_fn
    return fn(carry, *args, **kwargs), None
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 75, in wrapper
    dynamic_out, static_out = checkpointed_fun(static, dynamic)
  File "/usr/local/lib/python3.10/dist-packages/haliax/jax_utils.py", line 66, in _fn
    _out = fun(*_args, **_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haliax/nn/scan.py", line 225, in _do_block
    return block(carry, *extra_args, **extra_kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 268, in __call__
    attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/gpt2.py", line 200, in __call__
    attn_output = dot_product_attention(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/levanter/src/levanter/models/attention.py", line 119, in dot_product_attention
    attention_out = _try_te_attention(
  File "/levanter/src/levanter/models/attention.py", line 242, in _try_te_attention
    return _te_flash_attention(
  File "/levanter/src/levanter/models/attention.py", line 313, in _te_flash_attention
    from transformer_engine.jax.fused_attn import fused_attn  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/__init__.py", line 10, in <module>
    import transformer_engine.common
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 107, in <module>
    _TE_LIB_CTYPES = _load_library()
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/common/__init__.py", line 78, in _load_library
    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
  File "/usr/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /usr/local/lib/python3.10/dist-packages/transformer_engine/libtransformer_engine.so: undefined symbol: cudnnGetLastErrorString
2024-07-25 21:36:00,899 WARNING worker.py:1450 -- SIGTERM handler is not set because current thread is not the main thread.
/usr/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
wandb: 
wandb: Run history:
wandb:              preprocessing//train/chunks ▁
wandb:            preprocessing//train/finished ▁
wandb:           preprocessing//train/input_ids ▁
wandb:                preprocessing//train/rows ▁
wandb:              preprocessing//train/shards ▁
wandb:      preprocessing//train/token_type_ids ▁
wandb:         preprocessing//validation/chunks ▁
wandb:       preprocessing//validation/finished ▁
wandb:      preprocessing//validation/input_ids ▁
wandb:           preprocessing//validation/rows ▁
wandb:         preprocessing//validation/shards ▁
wandb: preprocessing//validation/token_type_ids ▁
wandb: 
wandb: Run summary:
wandb:                                  backend gpu
wandb:                              num_devices 1
wandb:                                num_hosts 1
wandb:                          parameter_count 359708672
wandb:              preprocessing//train/chunks 1
wandb:            preprocessing//train/finished 1
wandb:           preprocessing//train/input_ids 285696
wandb:                preprocessing//train/rows 279
wandb:              preprocessing//train/shards 1
wandb:      preprocessing//train/token_type_ids 285696
wandb:         preprocessing//validation/chunks 1
wandb:       preprocessing//validation/finished 1
wandb:      preprocessing//validation/input_ids 28672
wandb:           preprocessing//validation/rows 28
wandb:         preprocessing//validation/shards 1
wandb: preprocessing//validation/token_type_ids 28672
wandb:                   throughput/device_kind NVIDIA A10G
wandb:             throughput/flops_per_example 2514493636608.0
wandb:             throughput/theoretical_flops 125000000000000.0
wandb:  throughput/theoretical_flops_per_device 125000000000000.0
wandb: 
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /levanter/wandb/offline-run-20240725_213448-oz48brcz
wandb: Find logs at: ./wandb/offline-run-20240725_213448-oz48brcz/logs
wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.
2024-07-25 21:36:03,171 INFO worker.py:1779 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
2024-07-25T21:36:04 - 0 - ShardCache.cache/train - shard_cache.py:1418 - ERROR :: Error while reading from shard cache.
Traceback (most recent call last):
  File "/levanter/src/levanter/data/shard_cache.py", line 1406, in iter_batches_from_chunks
    chunk = self._get_chunk_unmapped(i)
  File "/levanter/src/levanter/data/shard_cache.py", line 1336, in _get_chunk_unmapped
    chunk = ray.get(self._broker.get_chunk.remote(mapped_index), timeout=current_timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 202, in remote
    return self._remote(args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/util/tracing/tracing_helper.py", line 426, in _start_span
    return method(self, args, kwargs, *_args, **_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 330, in _remote
    return invocation(args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 311, in invocation
    return actor._actor_method_call(
  File "/usr/local/lib/python3.10/dist-packages/ray/actor.py", line 1460, in _actor_method_call
    object_refs = worker.core_worker.submit_actor_task(
  File "python/ray/_raylet.pyx", line 4258, in ray._raylet.CoreWorker.submit_actor_task
  File "python/ray/_raylet.pyx", line 4313, in ray._raylet.CoreWorker.submit_actor_task
Exception: Failed to submit task to actor ActorID(0e2b26ec3dd9161148784ee601000000) due to b"Can't find actor 0e2b26ec3dd9161148784ee601000000. It might be dead or it's from a different cluster"
Exception in thread ray_print_logs:
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 893, in print_logs
    subscriber.subscribe()
  File "python/ray/_raylet.pyx", line 3111, in ray._raylet._GcsSubscriber.subscribe
  File "python/ray/_raylet.pyx", line 586, in ray._raylet.check_status
ray.exceptions.RpcError: recvmsg:Connection reset by peer

dlwh (Member) commented Jul 25, 2024

This seems like a problem with the CUDA stuff. @DwarKapex Any thoughts?
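
For anyone blocked in the meantime: the traceback shows the OSError escaping from the Transformer Engine import probe in levanter/models/attention.py. A hedged sketch of a local guard (an illustration, not the shipped fix) would treat a broken TE install the same as a missing one:

```python
# Sketch only: probe Transformer Engine and treat dlopen failures
# (e.g. the undefined cudnnGetLastErrorString symbol above) the same
# as the package being absent, so attention can fall back to the
# plain JAX path instead of crashing.
def te_is_usable() -> bool:
    try:
        from transformer_engine.jax.fused_attn import fused_attn  # noqa: F401
        return True
    except (ImportError, OSError):
        return False
```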

dlwh (Member) commented Jul 25, 2024

(You can probably use https://github.com/orgs/nvidia/packages/container/jax/248105500?tag=levanter-2024-07-24 as jax:levanter-2024-07-24.)

MikeMpapa (Author):

That did the trick! Thanks so much!

MikeMpapa (Author):

If it is of any help to you here is a slightly more detailed description of what I am trying to do.

  • I am trying to fine-tune this model, which was built using Levanter
  • The author of that work previously suggested this Levanter code edit, plus installing torch, to run the implementation (it's essentially a GPT-2 model)
  • With jax:levanter-2024-07-24 this trick seems to work
  • With the current JAX image and without this Levanter code edit, I am getting this error
  • With the current JAX image and with this Levanter code edit, I am getting:
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 82, in main
    levanter.initialize(config)
  File "/levanter/src/levanter/trainer.py", line 796, in initialize
    trainer_config.initialize()
  File "/levanter/src/levanter/trainer.py", line 627, in initialize
    _initialize_global_tracker(self.tracker, id)
  File "/levanter/src/levanter/trainer.py", line 522, in _initialize_global_tracker
    tracker = config.init(run_id)
  File "/levanter/src/levanter/tracker/wandb.py", line 131, in init
    git_settings = self._git_settings()
  File "/levanter/src/levanter/tracker/wandb.py", line 204, in _git_settings
    sha = self._get_git_sha(code_dir)
  File "/levanter/src/levanter/tracker/wandb.py", line 216, in _get_git_sha
    git_sha = repo.head.commit.hexsha
  File "/usr/local/lib/python3.10/dist-packages/git/refs/symbolic.py", line 297, in _get_commit
    obj = self._get_object()
  File "/usr/local/lib/python3.10/dist-packages/git/refs/symbolic.py", line 288, in _get_object
    return Object.new_from_sha(self.repo, hex_to_bin(self.dereference_recursive(self.repo, self.path)))
  File "/usr/local/lib/python3.10/dist-packages/git/objects/base.py", line 149, in new_from_sha
    oinfo = repo.odb.info(sha1)
  File "/usr/local/lib/python3.10/dist-packages/git/db.py", line 41, in info
    hexsha, typename, size = self._git.get_object_header(bin_to_hex(binsha))
  File "/usr/local/lib/python3.10/dist-packages/git/cmd.py", line 1678, in get_object_header
    return self.__get_object_header(cmd, ref)
  File "/usr/local/lib/python3.10/dist-packages/git/cmd.py", line 1661, in __get_object_header
    cmd.stdin.flush()
BrokenPipeError: [Errno 32] Broken pipe

Sharing this in case it helps you with debugging. Feel free to follow up if you have any questions.

dlwh (Member) commented Jul 26, 2024

Weird. That's easy to work around, but I don't know why it's happening.
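
For context, the kind of workaround meant here would be swallowing git errors when wandb reads the commit SHA, so tracking still initializes. A hedged sketch (the names mirror the traceback, but this is not the actual Levanter patch):

```python
# Sketch only: fall back to None if GitPython fails (BrokenPipeError,
# dubious-ownership errors, missing repo, etc.) instead of crashing
# run initialization.
def _get_git_sha(code_dir):
    try:
        import git  # GitPython
        return git.Repo(code_dir).head.commit.hexsha
    except Exception:
        return None
```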
