-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample.py
More file actions
99 lines (83 loc) · 2.86 KB
/
sample.py
File metadata and controls
99 lines (83 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from model import Model
import torch
import os
import regex as re
# Select compute device and autocast precision: bf16 on GPUs that support it,
# fp16 everywhere else.
cuda_ok = torch.cuda.is_available()
device = 'cuda' if cuda_ok else 'cpu'
dtype = torch.bfloat16 if cuda_ok and torch.cuda.is_bf16_supported() else torch.float16
# Load the raw token text the script works from.  NOTE(review): the file is
# assumed to sit in the current working directory — confirm launch location.
with open('notes.txt', encoding='utf-8') as src:
    text = src.read()
# Vocabulary: 86 "t" tokens and 86 "d" tokens on a 24-step grid (t0, t24, ...,
# t2040), 128 "n" tokens (n0..n127), plus "=" and newline.
# NOTE(review): token semantics (time/duration/note?) inferred from prefixes —
# confirm against the training data format.
chars = (
    [f"t{24 * i}" for i in range(86)]
    + [f"d{24 * i}" for i in range(86)]
    + [f"n{i}" for i in range(128)]
    + ["=", "\n"]
)
vocab_size = len(chars)
stoi = {tok: idx for idx, tok in enumerate(chars)}  # token -> id
itos = dict(enumerate(chars))                       # id -> token
# ---- encode/decode ----
def encode(s: str) -> list[int]:
    """
    Tokenize `s` into a list of vocabulary IDs.

    Recognized token shapes are t<digits>, d<digits>, n<digits>, v<digits>,
    "=", and newline; characters that match none of these are silently
    skipped by the regex scan.

    Raises:
        ValueError: if a matched token is not present in the vocabulary
            (e.g. any v<digits> token, which the regex accepts but the
            vocabulary does not contain).
    """
    def _lookup(tok: str) -> int:
        # Explicit membership check so the error names the offending token.
        if tok not in stoi:
            raise ValueError(f"Unknown token: {tok!r}")
        return stoi[tok]

    return [_lookup(tok) for tok in re.findall(r"t\d+|d\d+|n\d+|v\d+|=|\n", s)]
def decode(ids, join=True):
    """Map a sequence of vocabulary IDs back to their string tokens.

    Args:
        ids: iterable of integer IDs; a torch.Tensor is accepted and
            converted to a plain Python list first.
        join: when True (default) return the tokens concatenated into a
            single string, otherwise return the list of token strings.
    """
    if isinstance(ids, torch.Tensor):
        ids = ids.tolist()
    tokens = [itos[i] for i in ids]
    return "".join(tokens) if join else tokens
seed = None  # set to an int for reproducible sampling; None -> random every run
mode = 'print'  # 'write' to write samples to write_file_path, 'print' to print text
write_file_path = 'generation.txt'
if seed is not None:
    # Bug fix: previously seeded with the hard-coded literal 42, so any
    # user-supplied seed value was silently ignored.
    torch.manual_seed(seed)
save_path = 'modelsave.pt'
# Generation requires a trained checkpoint; bail out early if none exists.
if not os.path.exists(save_path):
    exit("!!<<Model Checkpoint not found!>>!!")
checkpoint = torch.load(save_path)
args = checkpoint['args']
model = Model(args)
model.to(device)
model.load_state_dict(checkpoint['model_params'])
print(f"Resuming from iter {checkpoint['iter']:,}\nParameters: {model.count_params():,}")
# ---- sampling configuration ----
num_samples = 1
max_new_tokens = 1024          # tokens generated per sample
temperature = 1                # softmax temperature; 1 leaves logits unscaled
p = 1                          # top-p threshold; presumably 1 disables nucleus truncation — confirm in Model.generate
view_probabilites = False      # forwarded to Model.generate as view_probabilities
context = "="                  # prompt string fed to the model
# Encode the prompt and add a leading batch dimension for the model.
x = torch.tensor(encode(context), dtype=torch.long, device=device).unsqueeze(0)
if mode == 'print':
    # Stream each sample straight to stdout; the model prints as it generates.
    with torch.no_grad(), torch.amp.autocast(device_type=device, dtype=dtype):
        for _ in range(num_samples):
            print("\n", context, end="")
            model.generate(x, chars, max_new_tokens, temperature=temperature, top_p=p, view_probabilities=view_probabilites)
            print('\n---------------\n')
elif mode == 'write':
    # Collect generated IDs, decode them, and append each sample to the file.
    with open(write_file_path, 'w', encoding='utf-8') as out:
        with torch.no_grad(), torch.amp.autocast(device_type=device, dtype=dtype):
            for _ in range(num_samples):
                y = model.generate(x, chars, max_new_tokens, mode='write', temperature=temperature, top_p=p, view_probabilities=view_probabilites)
                out.write(decode(y[0].tolist()))
                out.write('\n---------------\n')