-
Notifications
You must be signed in to change notification settings - Fork 57
Open
Description
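S01/run.py L167 has a bug in the model initialization: the dummy input passed to `model.init` is created with dtype `int8`, but the vocabulary size is 256, which `int8` (maximum value 127) cannot represent. Indexing the embedding table with this input throws an error: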
---------------------------------------------------------------------------
OverflowError Traceback (most recent call last)
<ipython-input-7-263240bbee7e> in <cell line: 0>()
----> 1 main()
8 frames
<ipython-input-6-43b29c8cd544> in main()
4 model = OurModel()
5 tx = optax.adam(learning_rate=LEARNING_RATE)
----> 6 params = model.init(rngkey, jax.numpy.ones((BATCH_IN_SEQUENCES, SEQUENCE_LENGTH), dtype = jnp.int8))
7
8 state = train_state.TrainState.create(
[... skipping hidden 9 frame]
<ipython-input-4-5c82d6e6db30> in __call__(self, input_tokens)
12 )
13
---> 14 x = jnp.asarray(embedding)[input_tokens] # BATCH, SEQUENCE, EMBED
15
16 pos_embedding = self.param(
/usr/local/lib/python3.11/dist-packages/jax/_src/array.py in __getitem__(self, idx)
420 out.aval, sharding, [out], committed=False, _skip_checks=True)
421
--> 422 return indexing.rewriting_take(self, idx)
423
424 def __iter__(self):
/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value, out_sharding)
626
627 treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape)
--> 628 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
629 unique_indices, mode, fill_value, out_sharding)
630
/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value, out_sharding)
635 unique_indices, mode, fill_value, out_sharding):
636 idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
--> 637 indexer = index_to_gather(np.shape(arr), idx) # shared with _scatter_update
638 y = arr
639
/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in index_to_gather(x_shape, idx, normalize_indices)
790 advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
791 for e, i, j in advanced_pairs)
--> 792 advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
793
794 x_axis = 0 # Current axis in x.
/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in <genexpr>(.0)
788 if lax_numpy.isscalar(e) or isinstance(e, (Sequence, Array, np.ndarray)))
789 if normalize_indices:
--> 790 advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
791 for e, i, j in advanced_pairs)
792 advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in _normalize_index(index, axis_size)
192 return index
193 if core.is_constant_dim(axis_size):
--> 194 axis_size_val = lax_internal._const(index, axis_size)
195 else:
196 axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size),
/usr/local/lib/python3.11/dist-packages/jax/_src/lax/lax.py in _const(example, val)
7790 val = dtypes.scalar_type_of(example)(val)
7791 return val if dtype == _dtype(val) else np.array(val, dtype)
-> 7792 return np.array(val, dtype)
7793
7794 _zeros: Callable = partial(full_like, fill_value=0)
OverflowError: Python integer 256 out of bounds for int8

Instead, `uint8` or some other integer dtype that can index the whole vocabulary should be used.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels