-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtraining.py
More file actions
64 lines (50 loc) · 2.73 KB
/
training.py
File metadata and controls
64 lines (50 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import tqdm
from logs import init_wandb, log_wandb, log_model_performance, save_checkpoint
def train_sae(sae, activation_store, model, cfg):
    """Train a single sparse autoencoder on streamed activations.

    Args:
        sae: sparse-autoencoder module; called on a batch, returns a dict
            with at least "loss", "l0_norm", "l2_loss", "l1_loss",
            "l1_norm"; must expose make_decoder_weights_and_grad_unit_norm().
        activation_store: source of training batches via next_batch().
        model: base model, used only for periodic performance logging.
        cfg: hyperparameter dict; reads "num_tokens", "batch_size", "lr",
            "beta1", "beta2", "perf_log_freq", "checkpoint_freq",
            "max_grad_norm".

    Side effects: logs to wandb and writes checkpoints every
    cfg["checkpoint_freq"] steps plus once at the end of training.
    """
    num_batches = cfg["num_tokens"] // cfg["batch_size"]
    optimizer = torch.optim.Adam(
        sae.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"])
    )
    pbar = tqdm.trange(num_batches)
    wandb_run = init_wandb(cfg)

    for i in pbar:
        batch = activation_store.next_batch()
        sae_output = sae(batch)
        log_wandb(sae_output, i, wandb_run)
        if i % cfg["perf_log_freq"] == 0:
            log_model_performance(wandb_run, i, model, activation_store, sae)
        if i % cfg["checkpoint_freq"] == 0:
            save_checkpoint(wandb_run, sae, cfg, i)

        loss = sae_output["loss"]
        pbar.set_postfix(
            {
                "Loss": f"{loss.item():.4f}",
                "L0": f"{sae_output['l0_norm']:.4f}",
                "L2": f"{sae_output['l2_loss']:.4f}",
                "L1": f"{sae_output['l1_loss']:.4f}",
                "L1_norm": f"{sae_output['l1_norm']:.4f}",
            }
        )
        loss.backward()
        # Clip first, then re-normalize decoder weights/grads so the unit-norm
        # constraint is the last thing applied before the optimizer step.
        torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg["max_grad_norm"])
        sae.make_decoder_weights_and_grad_unit_norm()
        optimizer.step()
        optimizer.zero_grad()

    # Final checkpoint. Guarded: with num_batches == 0 the loop never runs
    # and `i` is undefined (the original raised NameError here).
    if num_batches > 0:
        save_checkpoint(wandb_run, sae, cfg, i)
def train_sae_group(saes, activation_store, model, cfgs):
    """Train several sparse autoencoders in lockstep on the same batches.

    Args:
        saes: sequence of SAE modules (same contract as in train_sae).
        activation_store: source of shared batches via next_batch(); also
            provides get_batch_tokens() for performance logging.
        model: base model, used only for periodic performance logging.
        cfgs: per-SAE hyperparameter dicts, parallel to `saes`. The schedule
            (num_tokens, batch_size) and the wandb run come from cfgs[0].

    Side effects: logs each SAE to a shared wandb run (distinguished by its
    index) and checkpoints each SAE periodically plus once at the end.
    """
    num_batches = cfgs[0]["num_tokens"] // cfgs[0]["batch_size"]
    optimizers = [
        torch.optim.Adam(sae.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
        for sae, cfg in zip(saes, cfgs)
    ]
    pbar = tqdm.trange(num_batches)
    wandb_run = init_wandb(cfgs[0])
    # NOTE(review): batch_tokens is fetched once and reused for every
    # perf-log call — presumably a fixed evaluation batch; confirm intended.
    batch_tokens = activation_store.get_batch_tokens()

    for i in pbar:
        batch = activation_store.next_batch()
        for idx, (sae, cfg, optimizer) in enumerate(zip(saes, cfgs, optimizers)):
            sae_output = sae(batch)
            loss = sae_output["loss"]
            log_wandb(sae_output, i, wandb_run, index=idx)
            if i % cfg["perf_log_freq"] == 0:
                log_model_performance(
                    wandb_run, i, model, activation_store, sae,
                    index=idx, batch_tokens=batch_tokens,
                )
            if i % cfg["checkpoint_freq"] == 0:
                save_checkpoint(wandb_run, sae, cfg, i)
            pbar.set_postfix(
                {
                    "Loss": f"{loss.item():.4f}",
                    "L0": f"{sae_output['l0_norm']:.4f}",
                    "L2": f"{sae_output['l2_loss']:.4f}",
                    "L1": f"{sae_output['l1_loss']:.4f}",
                    "L1_norm": f"{sae_output['l1_norm']:.4f}",
                }
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(sae.parameters(), cfg["max_grad_norm"])
            sae.make_decoder_weights_and_grad_unit_norm()
            optimizer.step()
            optimizer.zero_grad()

    # Final checkpoint for every SAE. Guarded: with num_batches == 0 the loop
    # never runs and `i` is undefined (the original raised NameError here).
    if num_batches > 0:
        for sae, cfg in zip(saes, cfgs):
            save_checkpoint(wandb_run, sae, cfg, i)