Skip to content

Commit c00ff92

Browse files
authored
scripts: add script to compare logprobs of llama.cpp against other frameworks (#17947)
* scripts: add script to compare logits of llama.cpp against other frameworks * accept custom prompt file * fix code style * clarify endpoint * fix displaying * use abs for diff * fix vllm case * rm output file * rename to compare-logprobs * add "pattern"
1 parent 4ed2bae commit c00ff92

File tree

1 file changed

+281
-0
lines changed

1 file changed

+281
-0
lines changed

scripts/compare-logprobs.py

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
import argparse
2+
import requests
3+
import json
4+
from pathlib import Path
5+
import logging
6+
7+
# Set the root log level to INFO so progress messages are emitted, then create
# the named logger used by the functions in this script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("compare-logprobs")
9+
10+
11+
# Help text shown verbatim by argparse (the parser uses RawTextHelpFormatter),
# so the layout below is exactly what the user sees with --help.
DESCRIPTION = """
Compare logits between llama.cpp and another inference engine using OpenAI-compatible server endpoints.

Unlike compare-logits.py, it allows dumping logits from a hosted API endpoint. Useful when it's not possible to run both models locally.

Example usage:
    Step 1: Dump logits from two different servers
        python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions
        python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions

        (optionally, you can add --api-key <key> if the endpoint requires authentication)

    Step 2: Compare the dumped logits
        python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md
"""
26+
27+
28+
def generate_input_prompt(length: int) -> list[str]:
    """Return exactly *length* words taken from a built-in corpus.

    The corpus is split on any whitespace, so word boundaries do not depend
    on how the literal below is indented and no word carries an embedded
    newline. When more words are requested than the corpus holds, the corpus
    is repeated from the start.

    Args:
        length: Number of words to return.

    Returns:
        A list of ``length`` words, or an empty list when ``length <= 0``.
    """
    CORPUS = """
    You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.

    ### Tool Call Format:
    When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.

    You can make multiple calls in one go by placing them one after another.
    """
    if length <= 0:
        return []
    words = CORPUS.split()
    # Ceiling division: the smallest number of corpus copies covering `length`.
    repeats = -(-length // len(words))
    return (words * repeats)[:length]
42+
43+
44+
def dump_logits(
    endpoint: str,
    output_path: Path,
    input_words: list[str],
    pattern: list[tuple[bool, int]],
    api_key=None,
):
    """Query *endpoint* with a growing prompt and dump each response as JSONL.

    The prompt is built by appending the words of ``input_words`` one at a
    time. ``pattern`` is a sequence of ``(get, n)`` steps: for a "get" step,
    a completion request is sent after each of the next ``n`` words; for a
    "skip" step, the next ``n`` words are appended without sending requests.
    Each response is written as one JSON line, with an extra ``"__index"``
    key holding the zero-based index of the last word in the prompt.

    Args:
        endpoint: URL the completion requests are POSTed to.
        output_path: File that receives one JSON document per line.
        input_words: Words forming the prompt. The list is not modified.
        pattern: ``(get, n)`` steps as described above.
        api_key: Optional key sent as a ``Bearer`` Authorization header.

    Raises:
        ValueError: If ``pattern`` consumes more words than ``input_words``
            provides (checked before any request is sent).
        requests.HTTPError: If the endpoint answers with an error status.
    """
    logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...")
    n_total = sum(n for get, n in pattern if get)
    n_required = sum(max(n, 0) for _, n in pattern)
    i_total = len(input_words)
    # Fail fast instead of running out of words after a partial dump.
    if n_required > i_total:
        raise ValueError(
            f"Pattern requires {n_required} words, but only {i_total} were provided."
        )

    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    curr_text = ""
    n_done = 0
    i_cur = 0
    with output_path.open("w") as f:
        for get, n in pattern:
            for _ in range(n):
                # Index into the list rather than pop(0): the caller's list
                # stays intact and each step is O(1).
                curr_text += input_words[i_cur] + " "
                if not get:
                    # skip step: extend the prompt without querying
                    i_cur += 1
                    continue
                payload = {
                    "prompt": curr_text.strip(),
                    "temperature": 0.0,
                    "top_k": 1,
                    "max_tokens": 1,
                    "logprobs": 1,
                    "stream": False,
                }
                # NOTE: no timeout is passed, so a stalled server blocks here.
                response = requests.post(endpoint, json=payload, headers=headers)
                response.raise_for_status()
                data = response.json()
                data["__index"] = i_cur  # add index for easier debugging later
                data = json.dumps(data)
                f.write(f"{data}\n")
                n_done += 1
                i_cur += 1
                logger.info(
                    f"\n\n{data}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]"
                )
    logger.info(f"Logits dumped to {output_path}")
93+
94+
95+
def get_token_logprobs(data: dict):
    """Extract ``(token, logprob)`` for the first generated token.

    Two response layouts are handled, both read from
    ``data["choices"][0]["logprobs"]``:

    * with a ``"content"`` key (llama.cpp): the first entry of the first
      content item's ``"top_logprobs"`` list is used;
    * otherwise (vllm): the first elements of the parallel ``"tokens"`` and
      ``"token_logprobs"`` lists are used.
    """
    info = data["choices"][0]["logprobs"]
    if "content" not in info:
        # vllm case
        all_tokens = info["tokens"]
        all_logprobs = info["token_logprobs"]
        return all_tokens[0], all_logprobs[0]
    # llama.cpp case
    best = info["content"][0]["top_logprobs"][0]
    return best["token"], best["logprob"]
106+
107+
108+
def clean_text(text: str) -> str:
    """Quote *text* and escape characters that would break a table cell.

    Newline, tab and carriage return become their two-character backslash
    forms, ``|`` becomes ``\\|``, and the result is wrapped in single quotes.
    """
    # One translation pass; none of the replacements contain a character that
    # is itself escaped, so this matches applying the replacements in sequence.
    escapes = str.maketrans({"\n": "\\n", "\t": "\\t", "\r": "\\r", "|": "\\|"})
    return "'" + text.translate(escapes) + "'"
117+
118+
119+
def _render_table(header: list[str], rows: list[tuple[str, ...]]) -> str:
    """Render *header* and *rows* as a column-aligned Markdown table."""
    widths = [len(h) for h in header]
    for row in rows:
        for j, cell in enumerate(row):
            widths[j] = max(widths[j], len(cell))

    def render_row(cells) -> str:
        return "".join(f"| {c:<{w}} " for c, w in zip(cells, widths)) + "|\n"

    separator = "".join(f"|{'-' * (w + 2)}" for w in widths) + "|\n"
    return render_row(header) + separator + "".join(render_row(r) for r in rows)


def compare_logits(input1: Path, input2: Path, output_path: Path):
    """Compare two dumped JSONL files line by line and write a Markdown report.

    Lines at the same position in both files are paired; a pair is skipped
    when either line is blank. For each pair the first token and its logprob
    are extracted from both records and the absolute logprob difference is
    computed. A warning is logged when the ``"__index"`` values of a pair
    differ. The resulting table is logged and written to ``output_path``.

    Both inputs are read and validated before the output file is opened, so a
    failed comparison does not leave a truncated report behind.

    Args:
        input1: First dump file.
        input2: Second dump file.
        output_path: Destination of the Markdown report.

    Raises:
        ValueError: If the two inputs have a different number of lines.
    """
    with input1.open("r") as f1, input2.open("r") as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()

    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if len(lines1) != len(lines2):
        raise ValueError("Input files must have the same number of lines.")

    tab_header = [
        "idx",
        input1.name,
        "logprob_1",
        input2.name,
        "logprob_2",
        "diff (abs)",
    ]
    tab_entries = []

    for i, (line1, line2) in enumerate(zip(lines1, lines2)):
        if not line1.strip() or not line2.strip():
            continue  # skip empty lines

        data1 = json.loads(line1)
        data2 = json.loads(line2)

        idx1 = data1.get("__index", -1)
        idx2 = data2.get("__index", -1)
        if idx1 != idx2:
            logger.warning(
                f"Warning: Mismatched indices at line {i}: {idx1} vs {idx2}"
            )

        token1, logprob1 = get_token_logprobs(data1)
        token2, logprob2 = get_token_logprobs(data2)
        abs_diff = abs(logprob1 - logprob2)

        tab_entries.append(
            (
                str(idx1 + 1),  # shown one-based; taken from the first file
                clean_text(token1),
                f"{logprob1:.4f}",
                clean_text(token2),
                f"{logprob2:.4f}",
                f"{(abs_diff):.4f}",
            )
        )

    output = _render_table(tab_header, tab_entries)

    logger.info("\n" + output)
    with output_path.open("w") as fout:
        fout.write("# Logits Comparison Report\n\n")
        fout.write(output)
    logger.info(f"Report written to {output_path}")
191+
192+
193+
def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
    """Parse a comma-separated ``n_get,n_skip,n_get,...`` pattern.

    Entries at even positions (0, 2, ...) are "get" counts and entries at odd
    positions are "skip" counts.

    Args:
        pattern: Comma-separated non-negative integers, e.g. ``"10,1000,10"``.

    Returns:
        A list of ``(get, n)`` tuples, where ``get`` is True for a "get" step
        and False for a "skip" step.

    Raises:
        ValueError: If an entry is not an integer or is negative.
    """
    result = []
    for i, part in enumerate(pattern.split(",")):
        try:
            n = int(part)
        except ValueError:
            raise ValueError(
                f"Invalid pattern entry {part!r} in {pattern!r}: expected an integer"
            ) from None
        # A negative count would be silently ignored by range() yet still
        # reduce the total word count computed from the pattern.
        if n < 0:
            raise ValueError(
                f"Invalid pattern entry {part!r} in {pattern!r}: count must be >= 0"
            )
        result.append((i % 2 == 0, n))  # even position: get, odd: skip
    return result
203+
204+
205+
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Two subcommands are defined, selected through the required ``verb``
    argument:

    * ``dump OUTPUT ENDPOINT [--api-key KEY] [--file PATH] [--pattern P]``
    * ``compare INPUT1 INPUT2 OUTPUT``

    Returns:
        The parsed arguments namespace.
    """
    cli = argparse.ArgumentParser(
        description=DESCRIPTION,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    verbs = cli.add_subparsers(dest="verb", required=True, help="action to perform")

    # dump subcommand
    dump_cmd = verbs.add_parser("dump", help="dump logits from an endpoint")
    dump_cmd.add_argument(
        "output", type=Path, help="output path for dumped logits (.log)"
    )
    dump_cmd.add_argument("endpoint", type=str, help="OAI-compat /completions endpoint")
    dump_cmd.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="API key for authentication (if required)",
    )
    dump_cmd.add_argument(
        "--file",
        type=Path,
        default=None,
        help="File containing prompt to use instead of the default",
    )
    dump_cmd.add_argument(
        "--pattern",
        type=str,
        default="10,1000,10,4000,10",
        help="Pattern n_get,n_skip,... where n_get is number of words to get and n_skip is number of words to skip (num of words, NOT num of tokens)",
    )

    # compare subcommand
    compare_cmd = verbs.add_parser("compare", help="compare two dumped logits files")
    compare_cmd.add_argument("input1", type=Path, help="first input file (.log)")
    compare_cmd.add_argument("input2", type=Path, help="second input file (.log)")
    compare_cmd.add_argument(
        "output", type=Path, help="output path for comparison report (.md)"
    )

    # NOTE: argparse reports usage errors by exiting (SystemExit), which is
    # not an Exception subclass, so this handler only sees unexpected failures.
    try:
        parsed = cli.parse_args()
    except Exception as e:
        cli.print_help()
        raise e
    return parsed
255+
256+
257+
def main():
    """Entry point: dispatch to the ``dump`` or ``compare`` subcommand.

    For ``dump``, the prompt words come from ``--file`` when given (split on
    single spaces), otherwise from the built-in corpus sized to the pattern.

    Raises:
        ValueError: If the prompt file holds fewer words than the pattern
            consumes, or if the parsed verb is not recognised.
    """
    args = parse_args()

    if args.verb == "dump":
        pattern = parse_pattern(args.pattern)
        n_required = sum(n for _, n in pattern)
        if args.file is not None:
            with args.file.open("r") as f:
                input_words = f.read().strip().split(" ")
            # Compare the words actually read against what the pattern needs.
            if len(input_words) < n_required:
                raise ValueError(
                    f"Input file has only {len(input_words)} words, but pattern requires at least {n_required} words."
                )
        else:
            # Only build the default corpus when no prompt file replaces it.
            input_words = generate_input_prompt(n_required)
        input_length = len(input_words)
        logger.info(f"Using {input_length} words")
        dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
    elif args.verb == "compare":
        compare_logits(args.input1, args.input2, args.output)
    else:
        raise ValueError(f"Unknown verb: {args.verb}")
278+
279+
280+
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)