qwen_ocr_vllm.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. import base64
  2. from io import BytesIO
  3. from PIL import Image
  4. import json
  5. import os
  6. from pathlib import Path
  7. from typing import List, Dict, Any, Optional
  8. from qwen_vl_utils import process_vision_info
  9. from transformers import AutoProcessor
  10. import time
  11. import torch
  12. from config import MODEL_PATH, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER
  13. # vLLM imports
  14. from vllm import LLM, SamplingParams
  15. from vllm.multimodal.utils import fetch_image
  16. def image_to_base64(pil_image, image_format="JPEG"):
  17. """将PIL Image图像转换为Base64编码"""
  18. buffered = BytesIO()
  19. pil_image.save(buffered, format=image_format)
  20. img_byte_array = buffered.getvalue()
  21. encode_image = base64.b64encode(img_byte_array).decode('utf-8')
  22. return encode_image
  23. class QwenOcrVLLM:
  24. """基于vLLM加速框架的Qwen OCR推理类
  25. vLLM优势:
  26. 1. PagedAttention技术 - 高效的KV cache管理
  27. 2. 连续批处理 - 优化GPU利用率
  28. 3. 快速模型执行 - CUDA/cuDNN kernel优化
  29. 4. 支持量化 - AWQ, GPTQ等量化格式
  30. 5. 张量并行 - 支持多GPU推理
  31. """
  32. def __init__(
  33. self,
  34. icon_dir: str = "./icon",
  35. tensor_parallel_size: int = 1,
  36. gpu_memory_utilization: float = 0.9,
  37. max_model_len: int = 8192,
  38. dtype: str = "bfloat16",
  39. trust_remote_code: bool = True,
  40. ):
  41. """初始化vLLM模型
  42. Args:
  43. icon_dir: icon参考图像目录
  44. tensor_parallel_size: 张量并行大小(多GPU推理)
  45. gpu_memory_utilization: GPU显存利用率(0.0-1.0)
  46. max_model_len: 最大模型序列长度
  47. dtype: 数据类型("auto", "half", "float16", "bfloat16", "float", "float32")
  48. trust_remote_code: 是否信任远程代码
  49. """
  50. print("=" * 60)
  51. print("初始化 vLLM 加速推理引擎...")
  52. print("=" * 60)
  53. # 初始化vLLM引擎
  54. self.llm = LLM(
  55. model=MODEL_PATH,
  56. tensor_parallel_size=tensor_parallel_size,
  57. gpu_memory_utilization=gpu_memory_utilization,
  58. max_model_len=max_model_len,
  59. dtype=dtype,
  60. trust_remote_code=trust_remote_code,
  61. # 视觉模型特定参数
  62. limit_mm_per_prompt={"image": 10}, # 每个prompt最多支持10张图像
  63. )
  64. # 加载processor用于消息模板处理
  65. self.processor = AutoProcessor.from_pretrained(
  66. MODEL_PATH,
  67. trust_remote_code=trust_remote_code
  68. )
  69. # 加载icon参考图像
  70. self.icon_dir = icon_dir
  71. # self.icon_images = self._load_icon_images()
  72. # 默认采样参数
  73. self.default_sampling_params = SamplingParams(
  74. temperature=0.0, # 使用贪婪解码
  75. top_p=1.0,
  76. max_tokens=512, # 最大生成token数
  77. stop_token_ids=None,
  78. skip_special_tokens=True,
  79. )
  80. print("=" * 60)
  81. print("vLLM 引擎初始化完成!")
  82. print(f"- 模型路径: {MODEL_PATH}")
  83. print(f"- 张量并行: {tensor_parallel_size} GPU(s)")
  84. print(f"- 显存利用率: {gpu_memory_utilization * 100:.1f}%")
  85. print(f"- 数据类型: {dtype}")
  86. print(f"- 最大序列长度: {max_model_len}")
  87. print("=" * 60)
  88. # 模型预热
  89. print("模型预热中...")
  90. self._warmup()
  91. print("模型预热完成!")
  92. print("=" * 60)
  93. def _load_icon_images(self) -> Dict[str, Image.Image]:
  94. """加载icon目录下的所有参考图像"""
  95. icon_images = {}
  96. icon_path = Path(self.icon_dir)
  97. if not icon_path.exists():
  98. print(f"警告: icon目录 {self.icon_dir} 不存在")
  99. return icon_images
  100. # 加载所有png图像文件
  101. for icon_file in icon_path.glob("*.png"):
  102. icon_name = icon_file.stem # 获取文件名(不含扩展名), 如 GHS01
  103. try:
  104. icon_image = Image.open(icon_file).convert("RGB")
  105. icon_images[icon_name] = icon_image
  106. print(f"已加载icon参考图像: {icon_name}")
  107. except Exception as e:
  108. print(f"加载icon图像 {icon_file} 失败: {e}")
  109. return icon_images
  110. def _warmup(self):
  111. """预热模型以触发编译和优化"""
  112. dummy_image = Image.new('RGB', (224, 224), color='white')
  113. prompt = PROMPT_EXTRACT_NAME
  114. try:
  115. self.inference(dummy_image, prompt, warmup=True)
  116. except Exception as e:
  117. print(f"预热过程中出现警告(可忽略): {e}")
  118. def _build_messages(self, image: Image.Image, prompt: str) -> List[Dict]:
  119. """构建消息格式
  120. Args:
  121. image: PIL Image对象
  122. prompt: 提示词文本
  123. Returns:
  124. 消息列表
  125. """
  126. messages = [
  127. {
  128. "role": "user",
  129. "content": [
  130. {
  131. "type": "image",
  132. "image": image,
  133. },
  134. {
  135. "type": "text",
  136. "text": prompt
  137. },
  138. ],
  139. }
  140. ]
  141. return messages
  142. def _prepare_inputs(
  143. self,
  144. messages: List[Dict]
  145. ) -> Dict[str, Any]:
  146. """准备vLLM输入格式
  147. Args:
  148. messages: 消息列表
  149. Returns:
  150. 包含prompt和multi_modal_data的字典
  151. """
  152. # 应用chat模板
  153. text = self.processor.apply_chat_template(
  154. messages,
  155. tokenize=False,
  156. add_generation_prompt=True
  157. )
  158. # 处理视觉信息
  159. image_inputs, video_inputs = process_vision_info(messages)
  160. # vLLM 0.6.0+ 新版API格式
  161. # 直接返回包含文本和多模态数据的字典
  162. inputs = {
  163. "prompt": text,
  164. "multi_modal_data": {
  165. "image": image_inputs[0] if image_inputs else None
  166. }
  167. }
  168. return inputs
  169. def inference(
  170. self,
  171. image: Image.Image,
  172. prompt: str,
  173. warmup: bool = False,
  174. sampling_params: Optional[SamplingParams] = None
  175. ) -> List[str]:
  176. """OCR推理
  177. Args:
  178. image: PIL Image对象
  179. prompt: 提示词
  180. warmup: 是否为预热模式
  181. sampling_params: 自定义采样参数
  182. Returns:
  183. 生成的文本列表
  184. """
  185. # 构建消息
  186. messages = self._build_messages(image, prompt)
  187. # 准备输入
  188. inputs = self._prepare_inputs(messages)
  189. # 使用默认或自定义采样参数
  190. params = sampling_params if sampling_params else self.default_sampling_params
  191. # vLLM 0.6.0+ 新版API:直接传递inputs字典
  192. outputs = self.llm.generate(
  193. inputs,
  194. params
  195. )
  196. # 提取生成的文本
  197. generated_texts = [output.outputs[0].text for output in outputs]
  198. if not warmup:
  199. return generated_texts
  200. return generated_texts
  201. def batch_inference(
  202. self,
  203. images: List[Image.Image],
  204. prompts: List[str],
  205. sampling_params: Optional[SamplingParams] = None
  206. ) -> List[str]:
  207. """批量OCR推理(vLLM的核心优势)
  208. Args:
  209. images: PIL Image对象列表
  210. prompts: 提示词列表
  211. sampling_params: 自定义采样参数
  212. Returns:
  213. 生成的文本列表
  214. """
  215. if len(images) != len(prompts):
  216. raise ValueError(f"images数量({len(images)})和prompts数量({len(prompts)})不匹配")
  217. # 准备所有输入
  218. all_inputs = []
  219. for image, prompt in zip(images, prompts):
  220. messages = self._build_messages(image, prompt)
  221. inputs = self._prepare_inputs(messages)
  222. all_inputs.append(inputs)
  223. # 使用默认或自定义采样参数
  224. params = sampling_params if sampling_params else self.default_sampling_params
  225. # vLLM 0.6.0+ 新版API:直接传递inputs列表
  226. outputs = self.llm.generate(
  227. all_inputs,
  228. params
  229. )
  230. # 提取生成的文本
  231. generated_texts = [output.outputs[0].text for output in outputs]
  232. return generated_texts
  233. def extract_icons(self, image: Image.Image) -> List[str]:
  234. """识别图像中的象形图标识
  235. Args:
  236. image: PIL Image对象 - 待识别的化学品标签图像
  237. Returns:
  238. 生成的文本列表
  239. """
  240. # 构建包含所有参考图像的messages
  241. messages = [
  242. {
  243. "role": "user",
  244. "content": []
  245. }
  246. ]
  247. # 添加所有icon参考图像
  248. content_list = messages[0]["content"]
  249. # 按GHS编号顺序添加参考图像
  250. sorted_icons = sorted(self.icon_images.items(), key=lambda x: x[0])
  251. for icon_name, icon_image in sorted_icons:
  252. content_list.append({
  253. "type": "image",
  254. "image": icon_image,
  255. })
  256. content_list.append({
  257. "type": "text",
  258. "text": f"参考图像:{icon_name}"
  259. })
  260. # 添加待识别的图像
  261. content_list.append({
  262. "type": "image",
  263. "image": image,
  264. })
  265. # 添加提示词
  266. content_list.append({
  267. "type": "text",
  268. "text": PROMPT_EXTRACT_ICON
  269. })
  270. # 准备输入
  271. inputs = self._prepare_inputs(messages)
  272. # vLLM 0.6.0+ 新版API:直接传递inputs字典
  273. outputs = self.llm.generate(
  274. inputs,
  275. self.default_sampling_params
  276. )
  277. # 提取生成的文本
  278. generated_texts = [output.outputs[0].text for output in outputs]
  279. return generated_texts
  280. def __del__(self):
  281. """析构函数,清理资源"""
  282. try:
  283. if hasattr(self, 'llm'):
  284. del self.llm
  285. print("vLLM引擎已释放")
  286. except:
  287. pass
  288. if __name__ == '__main__':
  289. # 测试代码
  290. print("初始化 QwenOcrVLLM...")
  291. qwen_ocr = QwenOcrVLLM(
  292. tensor_parallel_size=1, # 单GPU
  293. gpu_memory_utilization=0.9, # 使用90%显存
  294. max_model_len=8192, # 最大序列长度
  295. dtype="bfloat16" # 使用bfloat16精度
  296. )
  297. # 测试单张图像推理
  298. print("\n" + "=" * 60)
  299. print("测试单张图像推理...")
  300. print("=" * 60)
  301. test_image_path = "./test3.jpg"
  302. if os.path.exists(test_image_path):
  303. image = Image.open(test_image_path).convert("RGB")
  304. # 测试提取名称
  305. start_time = time.time()
  306. result = qwen_ocr.inference(image, PROMPT_EXTRACT_PREVENTION)
  307. elapsed = time.time() - start_time
  308. print(f"\n推理耗时: {elapsed:.3f}秒")
  309. print(f"提取结果:\n{result[0]}")
  310. else:
  311. print(f"测试图像 {test_image_path} 不存在,跳过测试")
  312. # 测试批量推理
  313. print("\n" + "=" * 60)
  314. print("测试批量推理...")
  315. print("=" * 60)
  316. if os.path.exists(test_image_path):
  317. # 创建3张相同的测试图像
  318. images = [Image.open(test_image_path).convert("RGB") for _ in range(6)]
  319. prompts = [PROMPT_EXTRACT_ICON, PROMPT_EXTRACT_NAME, PROMPT_EXTRACT_COMPONENTS, PROMPT_EXTRACT_KEYWORD, PROMPT_EXTRACT_PREVENTION, PROMPT_EXTRACT_SUPPLIER]
  320. start_time = time.time()
  321. results = qwen_ocr.batch_inference(images, prompts)
  322. elapsed = time.time() - start_time
  323. print(results[0])
  324. print(results[1])
  325. print(results[2])
  326. print(results[3])
  327. print(results[4])
  328. print(results[5])
  329. print(f"\n批量推理耗时: {elapsed:.3f}秒")
  330. print(f"平均每张: {elapsed/3:.3f}秒")
  331. print(f"批量推理加速比: {(elapsed/3):.3f}秒/张 vs 单张推理")
  332. print("\n" + "=" * 60)
  333. print("测试完成!")
  334. print("=" * 60)