test_sglang.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. # The idea is that you have a Qwen2-VL-7B model located here:s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"
  2. # You need to load it in both hugging face transformers, and send page 1 of edgar.pdf to it from tests/gnarly_pdfs
  3. # Compare that the temperature 0 sampled result is the same
import asyncio
import base64
import json
import math
import os
import unittest
from io import BytesIO
from pathlib import Path
from unittest.mock import AsyncMock, patch

import numpy as np
import pytest
import torch
import torch.nn.functional as F
from httpx import AsyncClient
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration

from olmocr.pipeline import (
    SGLANG_SERVER_PORT,
    build_page_query,
    get_anchor_text,
    render_pdf_to_base64png,
    sglang_server_ready,
    sglang_server_task,
)
from olmocr.prompts import PageResponse
from olmocr.s3_utils import download_directory
  29. MODEL_FINETUNED_PATH = (
  30. "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"
  31. )
  32. @pytest.mark.nonci
  33. class TestSglangServer(unittest.IsolatedAsyncioTestCase):
  34. async def asyncSetUp(self):
  35. # Mock arguments
  36. self.args = AsyncMock()
  37. self.args.workspace = "/tmp/test_workspace"
  38. self.args.model = [MODEL_FINETUNED_PATH]
  39. self.args.model_chat_template = "qwen2-vl"
  40. self.args.target_longest_image_dim = 1024
  41. self.args.target_anchor_text_len = 6000
  42. self.args.model_max_context = 8192
  43. # Create a temporary workspace directory
  44. os.makedirs(self.args.workspace, exist_ok=True)
  45. # Set up a semaphore for server tasks
  46. self.semaphore = asyncio.Semaphore(1)
  47. self.maxDiff = None
  48. # # Start the sglang server
  49. # self.my_server_task = asyncio.create_task(sglang_server_task(self.args, self.semaphore))
  50. # # Wait for the server to become ready
  51. # await sglang_server_ready()
  52. async def test_sglang_server_initialization_and_request(self):
  53. # Mock data paths
  54. self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "ambiguous.pdf"))
  55. # Send a single request to the sglang server for page 1
  56. async with AsyncClient(timeout=600) as session:
  57. query = await build_page_query(
  58. str(self.test_pdf_path),
  59. page=1,
  60. target_longest_image_dim=self.args.target_longest_image_dim,
  61. target_anchor_text_len=self.args.target_anchor_text_len,
  62. )
  63. COMPLETION_URL = f"http://localhost:{30000}/v1/chat/completions"
  64. query["temperature"] = 0.0
  65. query["logprobs"] = True
  66. query["top_logprobs"] = 5
  67. response = await session.post(COMPLETION_URL, json=query)
  68. print(response.text)
  69. # Check the server response
  70. self.assertEqual(response.status_code, 200)
  71. response_data = response.json()
  72. self.assertIn("choices", response_data)
  73. self.assertGreater(len(response_data["choices"]), 0)
  74. model_response_json = json.loads(response_data["choices"][0]["message"]["content"])
  75. page_response = PageResponse(**model_response_json)
  76. print(page_response)
  77. self.assertEqual(page_response.natural_text, EDGAR_TEXT)
  78. async def asyncTearDown(self):
  79. pass
  80. # # Shut down the server
  81. # self.my_server_task.cancel()
  82. # with self.assertRaises(asyncio.CancelledError):
  83. # await self.my_server_task
  84. # # Cleanup temporary workspace
  85. # if os.path.exists(self.args.workspace):
  86. # for root, _, files in os.walk(self.args.workspace):
  87. # for file in files:
  88. # os.unlink(os.path.join(root, file))
  89. # os.rmdir(self.args.workspace)
  90. @pytest.mark.nonci
  91. class TestHuggingFaceModel(unittest.IsolatedAsyncioTestCase):
  92. async def asyncSetUp(self):
  93. # Set up the Hugging Face model and tokenizer
  94. model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
  95. download_directory([MODEL_FINETUNED_PATH], model_cache_dir)
  96. # Check the rope config and make sure it's got the proper key
  97. with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
  98. config_data = json.load(cfin)
  99. if "rope_type" in config_data["rope_scaling"]:
  100. del config_data["rope_scaling"]["rope_type"]
  101. config_data["rope_scaling"]["type"] = "mrope"
  102. with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
  103. json.dump(config_data, cfout)
  104. self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True)
  105. self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]
  106. self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
  107. self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
  108. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  109. self.model.to(self.device)
  110. # Path to the test PDF
  111. self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "ambiguous.pdf"))
  112. self.maxDiff = None
  113. async def test_hugging_face_generation(self):
  114. query = await build_page_query(
  115. str(self.test_pdf_path),
  116. page=1,
  117. target_longest_image_dim=1024,
  118. target_anchor_text_len=6000,
  119. )
  120. messages = query["messages"]
  121. # Apply chat template to get the text
  122. text = self.processor.apply_chat_template(query["messages"], tokenize=False, add_generation_prompt=True)
  123. image_url = query["messages"][0]["content"][1]["image_url"]["url"]
  124. # Remove the "data:image/png;base64," prefix
  125. base64_image = image_url.split(",")[1]
  126. # Decode the base64 string into bytes
  127. image_data = base64.b64decode(base64_image)
  128. # Create a BytesIO object and load it into a PIL image
  129. main_image = Image.open(BytesIO(image_data))
  130. # Process inputs using processor
  131. inputs = self.processor(
  132. text=[text],
  133. images=[main_image],
  134. padding=True,
  135. return_tensors="pt",
  136. )
  137. image_indices = [idx for idx, token in enumerate(inputs["input_ids"][0]) if token.item() == self.image_token_id]
  138. print("IMAGE INDICES", image_indices)
  139. print(f"image_grid_thw - {inputs['image_grid_thw'].shape} {inputs['image_grid_thw']}")
  140. print(f"pixel_values - {inputs['pixel_values'].shape} {inputs['pixel_values'].detach().cpu().numpy()}")
  141. np.save("/root/pixel_values.npy", inputs["pixel_values"].detach().cpu().numpy())
  142. inputs = {key: value.to(self.device) for (key, value) in inputs.items()}
  143. generated_tokens = []
  144. max_steps = 50
  145. top_logprobs_hf = []
  146. for step in range(max_steps):
  147. # Generate the output with temperature=0
  148. generation_output = self.model.generate(
  149. **inputs,
  150. temperature=0.0,
  151. max_new_tokens=1,
  152. # max_length=8192,
  153. num_return_sequences=1,
  154. do_sample=False,
  155. output_scores=True,
  156. return_dict_in_generate=True,
  157. )
  158. # Extract the generated token's log probabilities
  159. scores = generation_output.scores # Tuple of length 1
  160. logits = scores[0] # Tensor of shape (batch_size, vocab_size)
  161. log_probs = F.log_softmax(logits, dim=-1) # Apply log softmax to get log probabilities
  162. # Get top 5 tokens and their log probabilities
  163. topk_log_probs, topk_indices = torch.topk(log_probs[0], k=5)
  164. topk_tokens = self.tokenizer.convert_ids_to_tokens(topk_indices.tolist())
  165. top_logprobs_hf.append((topk_tokens, topk_log_probs.tolist()))
  166. # Pick the top token
  167. next_token_id = topk_indices[0].unsqueeze(0).unsqueeze(0) # Shape: (1, 1)
  168. next_token_str = self.tokenizer.convert_ids_to_tokens([next_token_id.item()])[0]
  169. generated_tokens.append(next_token_id.item())
  170. # Append the next token to input_ids and update attention_mask
  171. inputs["input_ids"] = torch.cat([inputs["input_ids"], next_token_id], dim=-1)
  172. inputs["attention_mask"] = torch.cat([inputs["attention_mask"], torch.ones((1, 1), dtype=inputs["attention_mask"].dtype).to(self.device)], dim=-1)
  173. print(self.tokenizer.decode(generated_tokens))
  174. # Now take all the input ids and run them through sglang as a comparison
  175. async with AsyncClient(timeout=600) as session:
  176. query["temperature"] = 0.0
  177. query["max_tokens"] = max_steps
  178. query["logprobs"] = True
  179. query["top_logprobs"] = 5
  180. COMPLETION_URL = f"http://localhost:{30000}/v1/chat/completions"
  181. response = await session.post(COMPLETION_URL, json=query)
  182. response_data = response.json()
  183. for step, lptok in enumerate(response_data["choices"][0]["logprobs"]["content"]):
  184. print("\nTop 5 tokens and their log probabilities:")
  185. (topk_tokens, topk_log_probs) = top_logprobs_hf[step]
  186. for token, log_prob, lptokcur in zip(topk_tokens, topk_log_probs, lptok["top_logprobs"]):
  187. print(
  188. f"HF Token: {token} Log Prob: {log_prob:.2f} Prob {math.exp(log_prob)*100:.2f}% SGLANG Token {lptokcur['token']} Logprob {lptokcur['logprob']:.2f} Prob {math.exp(lptokcur['logprob'])*100:.2f}%"
  189. )
  190. async def asyncTearDown(self):
  191. # Clean up the model and tokenizer
  192. del self.model
  193. del self.tokenizer
  194. torch.cuda.empty_cache()
  195. @pytest.mark.nonci
  196. class RawSGLangTest(unittest.IsolatedAsyncioTestCase):
  197. def setUp(self):
  198. # Set up the Hugging Face model and tokenizer
  199. model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
  200. download_directory([MODEL_FINETUNED_PATH], model_cache_dir)
  201. # Check the rope config and make sure it's got the proper key
  202. with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
  203. config_data = json.load(cfin)
  204. if "rope_type" in config_data["rope_scaling"]:
  205. del config_data["rope_scaling"]["rope_type"]
  206. config_data["rope_scaling"]["type"] = "mrope"
  207. with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
  208. json.dump(config_data, cfout)
  209. self.model_cache_dir = model_cache_dir
  210. self.tokenizer = AutoTokenizer.from_pretrained(model_cache_dir, trust_remote_code=True)
  211. self.image_token_id = self.tokenizer.encode("<|image_pad|>")[0]
  212. self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_cache_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
  213. self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
  214. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  215. self.model.to(self.device)
  216. # Path to the test PDF
  217. self.test_pdf_path = Path(os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "ambiguous.pdf"))
  218. self.maxDiff = None
  219. async def test_vision_encoder(self):
  220. query = await build_page_query(
  221. str(self.test_pdf_path),
  222. page=1,
  223. target_longest_image_dim=1024,
  224. target_anchor_text_len=6000,
  225. )
  226. messages = query["messages"]
  227. # Apply chat template to get the text
  228. text = self.processor.apply_chat_template(query["messages"], tokenize=False, add_generation_prompt=True)
  229. image_url = query["messages"][0]["content"][1]["image_url"]["url"]
  230. # Remove the "data:image/png;base64," prefix
  231. base64_image = image_url.split(",")[1]
  232. # Decode the base64 string into bytes
  233. image_data = base64.b64decode(base64_image)
  234. # Create a BytesIO object and load it into a PIL image
  235. main_image = Image.open(BytesIO(image_data))
  236. # Process inputs using processor
  237. inputs = self.processor(
  238. text=[text],
  239. images=[main_image],
  240. padding=True,
  241. return_tensors="pt",
  242. )
  243. with torch.no_grad():
  244. hf_output = self.model.visual(inputs["pixel_values"].to(self.device), grid_thw=inputs["image_grid_thw"].to(self.device))
  245. print("HF", hf_output, hf_output.shape)
  246. from sglang.srt.configs.model_config import ModelConfig
  247. from sglang.srt.hf_transformers_utils import get_tokenizer
  248. from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
  249. from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  250. from sglang.srt.model_executor.model_runner import ModelRunner
  251. from sglang.srt.sampling.sampling_params import SamplingParams
  252. from sglang.srt.server_args import PortArgs, ServerArgs
  253. model_config = ModelConfig(self.model_cache_dir, model_override_args="{}")
  254. server_args = ServerArgs(model_path=self.model_cache_dir)
  255. # Initialize model runner
  256. model_runner = ModelRunner(
  257. model_config=model_config,
  258. mem_fraction_static=0.8,
  259. gpu_id=0,
  260. tp_rank=0,
  261. tp_size=1,
  262. nccl_port=12435,
  263. server_args=server_args,
  264. )
  265. print(model_runner)
  266. with torch.no_grad():
  267. sglang_output = model_runner.model.visual(inputs["pixel_values"].to(self.device), grid_thw=inputs["image_grid_thw"].to(self.device))
  268. print("SGLANG", sglang_output, sglang_output.shape)
  269. # Convert to float32 for numerical stability if needed
  270. hf = hf_output.float()
  271. sg = sglang_output.float()
  272. # Basic shape and dtype comparison
  273. print("\n=== Basic Properties ===")
  274. print(f"Shapes match: {hf.shape == sg.shape}")
  275. print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
  276. print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")
  277. # Move tensors to CPU for numpy operations
  278. hf_np = hf.cpu().numpy()
  279. sg_np = sg.cpu().numpy()
  280. # Statistical metrics
  281. print("\n=== Statistical Metrics ===")
  282. print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}")
  283. print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}")
  284. print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
  285. print(f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}")
  286. # Cosine similarity (across feature dimension)
  287. cos_sim = F.cosine_similarity(hf, sg)
  288. print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
  289. print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")
  290. # Find largest absolute differences
  291. print("\n=== Largest Absolute Differences ===")
  292. diffs = torch.abs(hf - sg)
  293. flat_diffs = diffs.flatten()
  294. # Get indices of top 10 differences
  295. top_k = 10
  296. top_values, top_flat_indices = torch.topk(flat_diffs, top_k)
  297. # Convert flat indices to multidimensional indices
  298. top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)
  299. print(f"\nTop {top_k} largest absolute differences:")
  300. print("Index".ljust(30) + "Difference".ljust(15) + "HF Value".ljust(15) + "SGLang Value")
  301. print("-" * 75)
  302. for i in range(top_k):
  303. # Get the index tuple for this difference
  304. idx = tuple(dim[i] for dim in top_indices)
  305. diff_val = top_values[i].item()
  306. hf_val = hf[idx].item()
  307. sg_val = sg[idx].item()
  308. # Format the index tuple and values
  309. idx_str = str(idx)
  310. print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")