  1. """Benchmark offline inference throughput."""
  2. import argparse
  3. import base64
  4. import json
  5. import random
  6. import time
  7. from io import BytesIO
  8. from typing import List, Optional, Tuple
  9. import torch
  10. import uvloop
  11. from PIL import Image
  12. from tqdm import tqdm
  13. from transformers import (
  14. AutoModelForCausalLM,
  15. AutoProcessor,
  16. AutoTokenizer,
  17. PreTrainedTokenizerBase,
  18. )
  19. from vllm import TokensPrompt
  20. from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
  21. from vllm.entrypoints.openai.api_server import (
  22. build_async_engine_client_from_engine_args,
  23. )
  24. from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
  25. from vllm.sampling_params import BeamSearchParams
  26. from vllm.utils import FlexibleArgumentParser, merge_async_iterators


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset
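
# The three multimodal samplers below share a JSONL dataset format: each line
# holds a "chat_messages" list whose first message carries the prompt text at
# content[0]["text"] and a base64 data-URL image at
# content[1]["image_url"]["url"].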


def sample_mm_requests_qwen2vl(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
):
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

    with open(dataset_path, "r") as f:
        json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]

    result = []
    for data in tqdm(json_data):
        text = processor.apply_chat_template(data["chat_messages"], tokenize=False, add_generation_prompt=True)
        raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
        _main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1 :])))

        # Process inputs using the processor.
        inputs = processor(
            text=[text],
            # images=[_main_image],  # Don't pad out the image tokens yet, since that happens later inside of birr.
            padding=True,
            return_tensors="np",
        )
        tokens = inputs["input_ids"][0]
        prompt_len = len(tokens)
        result.append(
            (
                TokensPrompt(
                    dict(
                        prompt_token_ids=tokens,
                        multi_modal_data=dict(image=dict(image_embeds=torch.randn(1036, 3584), image_grid_thw=torch.tensor([[1, 74, 56]]))),
                        # multi_modal_data=dict(image=_main_image)
                    )
                ),
                prompt_len,
                fixed_output_len,
            )
        )
        if len(result) >= num_requests:
            break

    return result
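
# Note: the qwen2vl sampler above feeds random image embeddings instead of
# real vision-tower output: a 1x74x56 patch grid merged 2x2 yields the 1036
# visual tokens, each of width 3584 (Qwen2-VL-7B's hidden size). This keeps
# the benchmark focused on the language-model path; swap in the commented-out
# multi_modal_data line to pass the actual image through the vision encoder.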


def sample_mm_requests_phi3(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
):
    processor = AutoProcessor.from_pretrained("microsoft/Phi-3.5-vision-instruct", trust_remote_code=True)

    with open(dataset_path, "r") as f:
        json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]

    result = []
    for data in tqdm(json_data):
        # apply_chat_template with tokenize=True returns the token ids directly.
        tokens = processor.tokenizer.apply_chat_template(
            [{"role": "user", "content": "<|image_1|>\n" + data["chat_messages"][0]["content"][0]["text"]}],
            tokenize=True,
            add_generation_prompt=True,
        )
        raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
        main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1 :])))

        prompt_len = len(tokens)
        result.append(
            (
                TokensPrompt(
                    dict(
                        prompt_token_ids=tokens,
                        multi_modal_data=dict(image=main_image),
                    )
                ),
                prompt_len,
                fixed_output_len,
            )
        )
        if len(result) >= num_requests:
            break

    return result


def sample_mm_requests_molmo(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
):
    processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-D-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

    with open(dataset_path, "r") as f:
        json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]

    result = []
    for data in tqdm(json_data):
        raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
        main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1 :])))
        inputs = processor.process(images=[main_image], text=data["chat_messages"][0]["content"][0]["text"])

        # Molmo has a max context of 4096, which is lower than what our dataset
        # was generated for, so truncate the prompt.
        tokens = inputs["input_ids"][:2000]
        prompt_len = len(tokens)
        result.append(
            (
                TokensPrompt(
                    dict(
                        prompt_token_ids=tokens,
                        multi_modal_data=dict(image=main_image),
                    )
                ),
                prompt_len,
                fixed_output_len,
            )
        )
        if len(result) >= num_requests:
            break

    return result


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
) -> float:
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        # speculative_model="[ngram]",
        # num_speculative_tokens=1,
        # ngram_prompt_lookup_max=5,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        disable_async_output_proc=disable_async_output_proc,
    )

    # Add the requests to the engine.
    prompts: List[str] = []
    sampling_params: List[SamplingParams] = []
    for prompt, _, output_len in requests:
        prompts.append(prompt)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=output_len,
            )
        )

    # Beam search is hardcoded off here; flip this flag to benchmark it.
    use_beam_search = False
    if not use_beam_search:
        start = time.perf_counter()
        llm.generate(prompts, sampling_params, use_tqdm=True)
        end = time.perf_counter()
    else:
        prompts = [prompt for prompt, _, _ in requests]
        # output_len should be the same for all requests.
        output_len = requests[0][2]
        for prompt, input_len, _output_len in requests:
            assert _output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ),
        )
        end = time.perf_counter()
    return end - start


async def run_vllm_async(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    from vllm import SamplingParams

    engine_args = AsyncEngineArgs(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        disable_async_output_proc=disable_async_output_proc,
        worker_use_ray=False,
        disable_log_requests=True,
    )

    async with build_async_engine_client_from_engine_args(engine_args, disable_frontend_multiprocessing) as llm:
        # Add the requests to the engine.
        prompts: List[str] = []
        sampling_params: List[SamplingParams] = []
        for prompt, _, output_len in requests:
            prompts.append(prompt)
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=output_len,
                )
            )

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
            generator = llm.generate(prompt, sp, request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        async for i, res in all_gens:
            pass
        end = time.perf_counter()
        return end - start
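
# Note: the async path submits all requests up front and then drains the
# merged result streams, so the timed interval covers scheduling, prefill,
# and decode for the whole batch -- the same work measured by run_vllm.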


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) + max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start
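
# Note: run_hf batches greedily -- prompts accumulate until max_batch_size is
# reached or the padded batch (max prompt length + max output length) would
# exceed 2048 tokens, at which point the batch is generated and flushed.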


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    from mii import client, serve

    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]

    start = time.perf_counter()
    llm.generate(prompts, max_new_tokens=output_len)
    end = time.perf_counter()
    client = client(model)
    client.terminate_server()
    return end - start


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len) for _ in range(args.num_prompts)]
    else:
        # requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.output_len)
        requests = sample_mm_requests_qwen2vl(args.dataset, args.num_prompts, tokenizer, args.output_len)

    if args.backend == "vllm":
        run_args = [
            requests,
            args.model,
            args.tokenizer,
            args.quantization,
            args.tensor_parallel_size,
            args.seed,
            args.n,
            args.trust_remote_code,
            args.dtype,
            args.max_model_len,
            args.enforce_eager,
            args.kv_cache_dtype,
            args.quantization_param_path,
            args.device,
            args.enable_prefix_caching,
            args.enable_chunked_prefill,
            args.max_num_batched_tokens,
            args.distributed_executor_backend,
            args.gpu_memory_utilization,
            args.num_scheduler_steps,
            args.download_dir,
            args.load_format,
            args.disable_async_output_proc,
        ]
        if args.async_engine:
            run_args.append(args.disable_frontend_multiprocessing)
            elapsed_time = uvloop.run(run_vllm_async(*run_args))
        else:
            elapsed_time = run_vllm(*run_args)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n, args.hf_max_batch_size, args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, {total_num_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified.
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm")
    parser.add_argument("--dataset", type=str, default=None, help="Path to the dataset.")
    parser.add_argument("--input-len", type=int, default=None, help="Input prompt length for each request.")
    parser.add_argument("--output-len", type=int, default=None, help="Output length for each request. Overrides the output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--quantization", "-q", choices=[*QUANTIZATION_METHODS, None], default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n", type=int, default=1, help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size", type=int, default=None, help="Maximum batch size for HF backend.")
    parser.add_argument("--trust-remote-code", action="store_true", help="Trust remote code from Hugging Face.")
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=None,
        help="Maximum length of a sequence (including prompt and output). If None, will be derived from the model.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
        help="Data type for model weights and activations. "
        'The "auto" option will use FP16 precision '
        "for FP32 and FP16 models, and BF16 precision "
        "for BF16 models.",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.9,
        help="The fraction of GPU memory to be used for "
        "the model executor, which can range from 0 to 1. "
        "If unspecified, will use the default value of 0.9.",
    )
    parser.add_argument("--enforce-eager", action="store_true", help="Enforce eager execution.")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3).",
    )
    parser.add_argument(
        "--quantization-param-path",
        type=str,
        default=None,
        help="Path to the JSON file containing the KV cache scaling factors. "
        "This should generally be supplied when the KV cache dtype is FP8. "
        "Otherwise, KV cache scaling factors default to 1.0, which may cause "
        "accuracy issues. FP8_E5M2 (without scaling) is only supported on "
        "CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is "
        "instead supported for common inference criteria.",
    )
    parser.add_argument("--device", type=str, default="auto", choices=DEVICE_OPTIONS, help="Device type for vLLM execution.")
    parser.add_argument("--num-scheduler-steps", type=int, default=1, help="Maximum number of forward steps per scheduler call.")
    parser.add_argument("--enable-prefix-caching", action="store_true", help="Enable automatic prefix caching for the vLLM backend.")
    parser.add_argument("--enable-chunked-prefill", action="store_true", help="Enable chunked prefill for the vLLM backend.")
    parser.add_argument("--max-num-batched-tokens", type=int, default=None, help="Maximum number of batched tokens per iteration.")
    parser.add_argument(
        "--download-dir", type=str, default=None, help="Directory to download and load the weights; defaults to the default Hugging Face cache dir."
    )
    parser.add_argument("--output-json", type=str, default=None, help="Path to save the throughput results in JSON format.")
    parser.add_argument(
        "--distributed-executor-backend",
        choices=["ray", "mp"],
        default=None,
        help="Backend to use for distributed serving. When more than 1 GPU "
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.',
    )
    parser.add_argument(
        "--load-format",
        type=str,
        default=EngineArgs.load_format,
        choices=["auto", "pt", "safetensors", "npcache", "dummy", "tensorizer", "bitsandbytes"],
        help="The format of the model weights to load.\n\n"
        '* "auto" will try to load the weights in the safetensors format '
        "and fall back to the pytorch bin format if the safetensors format "
        "is not available.\n"
        '* "pt" will load the weights in the pytorch bin format.\n'
        '* "safetensors" will load the weights in the safetensors format.\n'
        '* "npcache" will load the weights in pytorch format and store '
        "a numpy cache to speed up the loading.\n"
        '* "dummy" will initialize the weights with random values, '
        "which is mainly for profiling.\n"
        '* "tensorizer" will load the weights using tensorizer from '
        "CoreWeave. See the Tensorize vLLM Model script in the Examples "
        "section for more information.\n"
        '* "bitsandbytes" will load the weights using bitsandbytes '
        "quantization.\n",
    )
    parser.add_argument("--disable-async-output-proc", action="store_true", default=False, help="Disable async output processor for the vLLM backend.")
    parser.add_argument("--async-engine", action="store_true", default=False, help="Use the vLLM async engine rather than the LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing", action="store_true", default=False, help="Disable decoupled async engine frontend.")
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    if args.dataset is None:
        assert args.input_len is not None
        assert args.output_len is not None
    else:
        assert args.input_len is None

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for the HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for the HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for the vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for the MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for the MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for the vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for the HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for the MII backend.")
    main(args)
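
# Example invocation (model and dataset path are illustrative -- adjust to
# your setup; the multimodal samplers expect the JSONL format described
# above):
#
#   python benchmark_throughput.py \
#       --backend vllm \
#       --model Qwen/Qwen2-VL-7B-Instruct \
#       --dataset ./mm_requests.jsonl \
#       --output-len 128 \
#       --num-prompts 100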