Skip to content

S01 run.py Bug Report: int8 used for init when vocab size is 256 gives an overflow error #7

@sshkhr

Description

@sshkhr

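S01/run.py line 167 has a bug in the model initialization: the dummy input passed to `model.init` is created with dtype `int8`, but the vocabulary size is 256. When the embedding table is indexed with these tokens, JAX normalizes the indices by converting the axis size (256) to the index dtype, and 256 does not fit in `int8` (range −128 to 127). This raises: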

---------------------------------------------------------------------------
OverflowError                             Traceback (most recent call last)
<ipython-input-7-263240bbee7e> in <cell line: 0>()
----> 1 main()

8 frames
<ipython-input-6-43b29c8cd544> in main()
      4     model = OurModel()
      5     tx = optax.adam(learning_rate=LEARNING_RATE)
----> 6     params = model.init(rngkey, jax.numpy.ones((BATCH_IN_SEQUENCES, SEQUENCE_LENGTH), dtype = jnp.int8))
      7 
      8     state = train_state.TrainState.create(

    [... skipping hidden 9 frame]

<ipython-input-4-5c82d6e6db30> in __call__(self, input_tokens)
     12     )
     13 
---> 14     x = jnp.asarray(embedding)[input_tokens] # BATCH, SEQUENCE, EMBED
     15 
     16     pos_embedding = self.param(

/usr/local/lib/python3.11/dist-packages/jax/_src/array.py in __getitem__(self, idx)
    420             out.aval, sharding, [out], committed=False, _skip_checks=True)
    421 
--> 422     return indexing.rewriting_take(self, idx)
    423 
    424   def __iter__(self):

/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value, out_sharding)
    626 
    627   treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape)
--> 628   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
    629                  unique_indices, mode, fill_value, out_sharding)
    630 

/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value, out_sharding)
    635             unique_indices, mode, fill_value, out_sharding):
    636   idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
--> 637   indexer = index_to_gather(np.shape(arr), idx)  # shared with _scatter_update
    638   y = arr
    639 

/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in index_to_gather(x_shape, idx, normalize_indices)
    790       advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
    791                         for e, i, j in advanced_pairs)
--> 792     advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
    793 
    794   x_axis = 0  # Current axis in x.

/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in <genexpr>(.0)
    788       if lax_numpy.isscalar(e) or isinstance(e, (Sequence, Array, np.ndarray)))
    789     if normalize_indices:
--> 790       advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
    791                         for e, i, j in advanced_pairs)
    792     advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)

/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/indexing.py in _normalize_index(index, axis_size)
    192     return index
    193   if core.is_constant_dim(axis_size):
--> 194     axis_size_val = lax_internal._const(index, axis_size)
    195   else:
    196     axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size),

/usr/local/lib/python3.11/dist-packages/jax/_src/lax/lax.py in _const(example, val)
   7790     val = dtypes.scalar_type_of(example)(val)
   7791     return val if dtype == _dtype(val) else np.array(val, dtype)
-> 7792   return np.array(val, dtype)
   7793 
   7794 _zeros: Callable = partial(full_like, fill_value=0)

OverflowError: Python integer 256 out of bounds for int8

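Instead, create the dummy input with an integer dtype that can represent the vocabulary size, e.g. `jnp.int32`. (`uint8` also avoids this particular error, because JAX skips index normalization for unsigned index dtypes. However, `uint8` cannot hold token ids for vocabularies larger than 256, so a wider type is the safer choice.)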

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions