-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCLI.py
More file actions
34 lines (29 loc) · 1.27 KB
/
CLI.py
File metadata and controls
34 lines (29 loc) · 1.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import json
import onnxruntime as ort
import numpy as np
# Load the run configuration and the character-level vocabulary.
# Fix: the original used json.load(open(...)), which never closes the
# file handles — use context managers instead.
# NOTE(review): paths are relative to the CWD — presumably the repo root.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)
with open("vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)

chars = vocab["chars"]
# Bidirectional char <-> token-id maps for the character tokenizer.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
def decode(ids):
    """Map a sequence of token ids back to the string they encode via *itos*."""
    characters = [itos[token_id] for token_id in ids]
    return "".join(characters)
# Resolve the exported ONNX model file from the configured model name.
model_path = f"{config['model_name']}.onnx"
# Fix: default to CUDA *with a CPU fallback* so the script still runs on
# machines without a GPU; a "providers" entry in config.json overrides
# this list entirely, exactly as before.
session = ort.InferenceSession(
    model_path,
    providers=config.get(
        "providers", ["CUDAExecutionProvider", "CPUExecutionProvider"]
    ),
)
def generate(prompt: str, max_new_tokens: int = 120, temperature: float = 0.8,
             context_len: int = 128) -> str:
    """Autoregressively sample up to *max_new_tokens* characters after *prompt*.

    Args:
        prompt: seed text; characters not in the vocab map to id 0.
        max_new_tokens: number of tokens to sample.
        temperature: softmax temperature (>0); lower is greedier.
        context_len: maximum number of trailing tokens fed to the model
            (previously hard-coded to 128 — assumed to match the model's
            training context window; TODO confirm against config).

    Returns:
        The decoded prompt plus the sampled continuation.
    """
    input_ids = [stoi.get(c, 0) for c in prompt]
    for _ in range(max_new_tokens):
        # Feed at most the last `context_len` tokens to respect the window.
        input_array = np.array([input_ids[-context_len:]], dtype=np.int64)
        logits = session.run(['logits'], {'input_ids': input_array})[0]
        next_token_logits = logits[0, -1, :] / temperature
        # Fix: numerically stable softmax — subtract the max before
        # exponentiating so large logits cannot overflow to inf/nan
        # (np.random.choice raises if `p` does not sum to 1).
        next_token_logits -= np.max(next_token_logits)
        exp_logits = np.exp(next_token_logits)
        probs = exp_logits / exp_logits.sum()
        # Cast to a plain int so input_ids stays homogeneous.
        next_id = int(np.random.choice(len(probs), p=probs))
        input_ids.append(next_id)
    return decode(input_ids)
print(f"LLM Chat with {config['model_name']} (type 'exit' to quit)\n")
while True:
    # Fix: exit cleanly on Ctrl-D (EOFError) / Ctrl-C (KeyboardInterrupt)
    # instead of dying with a traceback.
    try:
        user_input = input("> ").strip()
    except (EOFError, KeyboardInterrupt):
        break
    if user_input.lower() in {"exit", "quit"}:
        break
    # Fix: skip empty input — a zero-length prompt would hand the model an
    # empty (1, 0) input array.
    if not user_input:
        continue
    output = generate(user_input, max_new_tokens=config.get("max_new_tokens", 120))
    print("\n" + output + "\n")