-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
48 lines (35 loc) · 1.44 KB
/
dataset.py
File metadata and controls
48 lines (35 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from datasets import load_dataset
from tqdm import tqdm
# --- Dataset preparation ---------------------------------------------------
# Pull a fixed-size slice of OpenWebText and carve out a small held-out
# validation split.
max_examples = 800_000 # 1 example = 1,000 tokens (author's estimate — TODO confirm)
# Fraction of examples reserved for validation.
test_size = 0.01
# Only the first `max_examples` rows of the train split are fetched;
# downloads are cached locally under ./cache so reruns are cheap.
dataset = load_dataset("openwebtext", split=f'train[0:{max_examples}]', cache_dir='./cache')
# Fixed seed + shuffle => the split is deterministic across runs.
split_dataset = dataset.train_test_split(test_size=test_size, seed=42, shuffle=True)
# Rename the 'test' split to 'val' — it is used for validation, not testing.
split_dataset['val'] = split_dataset.pop('test')
# Load the pre-trained tokenizer serialized next to this script.
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file('tokenizer.json')
def Encode(text):
    """Tokenize *text* with the module-level tokenizer and return its token ids.

    The parameter was renamed from ``str`` to ``text`` — the original name
    shadowed the ``str`` builtin. The function name keeps its original
    (non-PEP8) capitalization because later top-level code calls it.
    """
    return tokenizer.encode(text).ids
# Token id(s) for the end-of-text marker, appended after every document.
# NOTE(review): assumes '<|endoftext|>' encodes to a single special token —
# verify tokenizer.json registers it as such, otherwise each document gets
# several marker ids appended instead of one.
eot_id = Encode('<|endoftext|>')
def process(example):
    """Tokenize one dataset row and append the end-of-text marker.

    Returns a dict holding the token ids plus their count, so the writer
    below can pre-size its output buffer from the 'len' column.
    """
    ids = Encode(example['text'])
    ids += eot_id
    return {'tokens': ids, 'len': len(ids)}
# Apply `process` to every row of both splits; the raw 'text' column is
# dropped so only token ids and their lengths remain.
tokenized = split_dataset.map(
    process,
    remove_columns=['text'],
    desc='Tokenizing the splits',
)
import numpy as np
# --- Serialization -----------------------------------------------------------
# Flatten each split into one contiguous binary file of uint16 token ids
# ({split}.bin). Fixes vs. original: the memmap local no longer shadows the
# builtin `map`, and the tqdm description shows the actual filename (the
# original `f'Writing (unknown)'` was an f-string with no placeholder —
# evidently a mangled `f'Writing {filename}'`).
for split, dset in tokenized.items():
    # Total token count across the split determines the memmap size.
    length = np.sum(dset['len'], dtype=np.int64)
    filename = f"{split}.bin"
    # uint16 holds ids only while the vocabulary has < 65,536 entries —
    # TODO confirm against tokenizer.json.
    arr = np.memmap(filename, dtype=np.uint16, mode='w+', shape=(length,))
    # Write in batches to bound peak memory while concatenating token lists.
    # NOTE(review): the batch count is derived from the *validation* fraction
    # for both splits; the 512 cap makes it work in practice, but
    # min(len(dset), 512) would be the principled choice.
    total_batches = int(min(max_examples*test_size, 512))
    if total_batches < 1:
        raise SystemExit("!!<<Number of batches too small>>!!")
    start_idx = 0
    for ix in tqdm(range(total_batches), desc=f'Writing {filename}'):
        # contiguous=True keeps each shard a contiguous slice of the split,
        # so concatenated shards reproduce the split's original row order.
        batch = dset.shard(num_shards=total_batches, index=ix, contiguous=True).with_format('numpy')
        batch_tokens = np.concatenate(batch['tokens'])
        arr[start_idx: start_idx + len(batch_tokens)] = batch_tokens
        start_idx += len(batch_tokens)
    # Push all dirty memmap pages to disk before starting the next split.
    arr.flush()