-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
150 lines (126 loc) · 5.09 KB
/
train.py
File metadata and controls
150 lines (126 loc) · 5.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import torch
import torch.nn as nn
from torch.nn import functional as F
from model import Model, ModelArgs
import time
import math
import numpy as np
from contextlib import nullcontext
# --- Training configuration --------------------------------------------------
# Decrease emb_dim, n_heads, n_layers and batch_size if running out of memory /
# too slow but will decrease performance
# Please follow the Cuda tutorial in README.md to GREATLY speed up training + sampling
import os

load = True  # True to load model checkpoint, False to not load. WARNING: If False, it can override checkpoints!
batch_size = 8
gradient_accumulation_steps = 40  # effective batch = batch_size * gradient_accumulation_steps
grad_clip = 1.0                   # max global gradient norm before each optimizer step
window_size = 512                 # context length in tokens
emb_dim = 512
n_heads = 8
n_layers = 8
max_iters = 10_000
eval_interval = 100
save_interval = 100
warmup_iters = 300                # linear LR warmup steps
lr_decay_iters = max_iters        # cosine decay horizon (then hold at min_lr)
learning_rate = 6e-4
min_lr = 6e-5
dropout = 0.0
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Prefer bf16 (no loss scaling needed) when the GPU supports it; otherwise fp16.
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
fused = torch.cuda.is_available()  # fused AdamW kernels are CUDA-only
eval_iters = 200
modelsave_path = 'modelsave.pt'
tokenizer_path = 'tokenizer.json'
torch.manual_seed(42)
# Loss scaling is only required for fp16; bf16 disables the scaler.
# Modern torch.amp.GradScaler API (matches the torch.amp.autocast call below)
# replaces the deprecated torch.cuda.amp.GradScaler; passing `device` also
# keeps the scaler valid on CPU-only machines.
scaler = torch.amp.GradScaler(device, enabled=(dtype == torch.float16))
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=dtype)
# Token ids are stored as uint16 on disk; memmap avoids loading the whole file into RAM.
train_data = np.memmap('train.bin', dtype=np.uint16, mode='r')
val_data = np.memmap('val.bin', dtype=np.uint16, mode='r')
print(f"Amount of tokens in training dataset: {train_data.shape[0]:,}")
if os.path.exists(tokenizer_path):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
else:
    exit("!!<<No tokenizer found>>!!")
def get_lr(it):
    """Learning rate for step `it`: linear warmup, then cosine decay.

    Ramps linearly from 0 to `learning_rate` over the first `warmup_iters`
    steps, decays along a cosine curve down to `min_lr` at `lr_decay_iters`,
    and stays at `min_lr` for any step beyond that.
    """
    # Warmup phase: scale the peak LR by the fraction of warmup completed.
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # Past the decay horizon: hold at the floor.
    if it > lr_decay_iters:
        return min_lr
    # Cosine interpolation between learning_rate and min_lr.
    progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= progress <= 1
    cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1 -> 0
    return min_lr + cosine_weight * (learning_rate - min_lr)
def get_batch(split):
    """Sample a random batch of (input, target) windows from `split`.

    Returns tensors x, y of shape (batch_size, window_size) on `device`;
    y is x shifted right by one position (next-token targets).
    """
    source = train_data if split == 'train' else val_data
    starts = torch.randint(len(source) - window_size, (batch_size,))
    # uint16 on disk -> int64 for embedding lookup / cross-entropy targets.
    x = torch.stack([torch.from_numpy(source[s:s + window_size].astype(np.int64)) for s in starts])
    y = torch.stack([torch.from_numpy(source[s + 1:s + 1 + window_size].astype(np.int64)) for s in starts])
    if device == 'cuda':
        # Pinned host memory allows asynchronous host-to-GPU copies.
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        x = x.to(device)
        y = y.to(device)
    return x, y
@torch.no_grad()
def estimate_loss():
    """Estimate the mean train/val loss over `eval_iters` random batches.

    Returns a dict {'train': scalar tensor, 'val': scalar tensor}.
    Runs the model in eval mode under the module-level autocast context
    `ctx` so evaluation uses the same mixed precision as training — the
    original left `ctx` unused, so eval ran in full precision and its
    losses were not directly comparable to training losses.
    """
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xv, yv = get_batch(split)
            with ctx:  # match training-time autocast precision
                logits, loss = model(xv, yv)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()  # restore training mode (re-enables dropout etc.)
    return out
# Build the model and optimizer from the configuration above.
args = ModelArgs(emb_dim=emb_dim, n_heads=n_heads,
                 window_size=window_size, n_layers=n_layers,
                 dropout=dropout, batch_size=batch_size,
                 vocab_size=vocab_size, device=device)
model = Model(args).to(device)
print(f"Number of parameters: {model.count_params():,}")
# Use the `fused` flag computed from CUDA availability instead of a hardcoded
# fused=True: fused AdamW kernels are CUDA-only, and hardcoding True crashes
# on CPU-only machines.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-1, betas=(0.9, 0.95), fused=fused)
if os.path.exists(modelsave_path) and load:
    # map_location lets a checkpoint saved on GPU resume on CPU (and vice versa).
    checkpoint = torch.load(modelsave_path, map_location=device)
    model.load_state_dict(checkpoint['model_params'])
    optimizer.load_state_dict(checkpoint['optimizer_params'])
    # NOTE(review): `iter` shadows the builtin; kept because the training loop
    # below depends on this name.
    iter = checkpoint['iter']
    print(f"Loaded checkpoint. Starting from iter {checkpoint['iter']:,}")
else:
    iter = 0
    print("?<No model checkpoint, training from scratch>?")
# Main training loop with gradient accumulation and AMP loss scaling.
xb, yb = get_batch('train')  # prefetch the first batch before entering the loop
start_time = time.time()
while iter < max_iters:
    # Apply this step's learning rate from the warmup/cosine schedule.
    for param_group in optimizer.param_groups:
        param_group['lr'] = get_lr(iter)
    # Periodically checkpoint model + optimizer state so training can resume.
    if iter % save_interval == 0 and iter != 0:
        checkpoint = {
            'model_params': model.state_dict(),
            'optimizer_params': optimizer.state_dict(),
            'args': args,
            'iter': iter
        }
        torch.save(checkpoint, modelsave_path)
    # Periodically report averaged train/val losses and wall-clock time.
    if iter % eval_interval == 0:
        losses = estimate_loss()
        end_time = time.time()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time: {end_time - start_time:.4f}s")
        start_time = time.time()
    # Gradient accumulation: run several micro-batches, summing (scaled)
    # gradients, before a single optimizer step.
    for g in range(gradient_accumulation_steps):
        with torch.amp.autocast(device_type=device, dtype=dtype):
            logits, loss = model(xb, yb)
            # Divide so the accumulated gradients equal the mean over micro-batches.
            loss = loss / gradient_accumulation_steps
        # Fetch the next batch now so the host-side work overlaps backward.
        xb, yb = get_batch('train')
        scaler.scale(loss).backward()
    # Unscale before clipping so the clip threshold applies to true gradient norms.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)  # skips the step if fp16 gradients overflowed
    scaler.update()
    optimizer.zero_grad(set_to_none=True)  # set_to_none frees grad memory between steps
    iter += 1
# After training, generate a sample continuation from the end-of-text token.
print("Training has FINISHED!")
prompt_ids = tokenizer.encode("<|endoftext|>").ids
prompt = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
sample_ids = model.generate(prompt, max_new_tokens=500)[0].tolist()
print(tokenizer.decode(sample_ids))