tagging_pipeline_v2.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. #!/usr/bin/env python3
  2. """
  3. Tagging pipeline for Dolma JSONL datasets.
  4. For each .jsonl, .jsonl.gz, or .jsonl.ztd file under the dataset/documents folder,
  5. this script issues a model prompt completion for each document,
  6. collects the classification answers, and writes corresponding Dolma attributes JSONL files under
  7. <dataset>/attributes/<attribute_name>/, mirroring the input structure.
  8. """
  9. import argparse
  10. import asyncio
  11. import atexit
  12. import gzip
  13. import json
  14. import logging
  15. import os
  16. import re
  17. import sys
  18. import time
  19. from typing import Optional
  20. from urllib.parse import urlparse
  21. import boto3
  22. import httpx
  23. import zstandard as zstd
  24. from huggingface_hub import snapshot_download
  25. from pydantic import BaseModel, Field, ValidationError
  26. from olmocr.check import (
  27. check_torch_gpu_available,
  28. )
  29. from olmocr.metrics import MetricsKeeper
  30. from olmocr.s3_utils import (
  31. download_directory,
  32. expand_s3_glob,
  33. get_s3_bytes_with_backoff,
  34. parse_s3_path,
  35. )
  36. from olmocr.version import VERSION
  37. from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
  38. # Initialize logger
  39. logger = logging.getLogger(__name__)
  40. logger.setLevel(logging.DEBUG)
  41. logger.propagate = False
  42. server_logger = logging.getLogger("vllm")
  43. server_logger.propagate = False
  44. file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
  45. file_handler.setLevel(logging.DEBUG)
  46. file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
  47. console_handler = logging.StreamHandler()
  48. console_handler.setLevel(logging.INFO)
  49. console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
  50. # Add handlers to the logger
  51. logger.addHandler(file_handler)
  52. logger.addHandler(console_handler)
  53. server_logger.addHandler(file_handler)
  54. # Default port; overridden by --port
  55. SERVER_PORT = 30026
  56. # Global variables for token statistics
  57. metrics = MetricsKeeper(window=60 * 5)
  58. class PIIClassification(BaseModel):
  59. primary_language: str = Field(..., description="Primary language as a two-letter code")
  60. document_type: str = Field(..., description="Basic summary of document type classification")
  61. is_resume_cv: Optional[bool] = Field(None, description="True if the document is a page from a resume or cv")
  62. is_academic_paper: Optional[bool] = None
  63. is_textbook: Optional[bool] = None
  64. is_news_article: Optional[bool] = None
  65. is_test_or_quiz: Optional[bool] = None
  66. is_homework_assignment: Optional[bool] = None
  67. is_class_syllabus: Optional[bool] = None
  68. is_meeting_minutes: Optional[bool] = None
  69. is_legal_contract: Optional[bool] = None
  70. is_form: Optional[bool] = None
  71. is_correspondence_or_letter: Optional[bool] = None
  72. is_public_order: Optional[bool] = None
  73. is_court_notice: Optional[bool] = None
  74. is_completion_certificate: Optional[bool] = None
  75. contains_pii: Optional[bool] = Field(None, description="True if document contains PII")
  76. async def _process_single_page(page_text: str) -> PIIClassification:
  77. """Helper function to process a single document or page."""
  78. text = page_text
  79. query = {
  80. "model": "google/gemma-3-4b-it",
  81. "messages": [
  82. {
  83. "role": "user",
  84. "content": [
  85. {
  86. "type": "text",
  87. "text": (
  88. f"{text}\n\n-----------\n"
  89. "Given the text above, determine what type of document it is. Answer in JSON. The format of your json object should be {'primary_language': str, 'document_type': str, 'is_resume_cv': bool, 'is_academic_paper': bool, 'is_textbook': bool, 'is_news_article': bool, 'is_test_or_quiz': bool, 'is_homework_assignment': bool, 'is_class_syllabus': bool, 'is_meeting_minutes': bool, 'is_legal_contract': bool, 'is_form': bool, 'is_correspondence_or_letter': bool, 'is_public_order': bool, 'is_court_notice': bool, 'is_completion_certificate': bool, 'contains_pii': bool}"
  90. ),
  91. }
  92. ],
  93. }
  94. ],
  95. "max_tokens": 400,
  96. "temperature": 0.0,
  97. "response_format": {"type": "json_schema", "json_schema": {"name": "PIIClassification", "schema": PIIClassification.model_json_schema()}},
  98. }
  99. url = f"http://localhost:{SERVER_PORT}/v1/chat/completions"
  100. # ---------- HTTP call ---------------------------------------------------
  101. try:
  102. status, body = await apost(url, json_data=query)
  103. except Exception as e:
  104. logger.warning(f"Server network error: {e!s}")
  105. metrics.add_metrics(server_errors=1)
  106. return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
  107. metrics.add_metrics(server_requests=1)
  108. if status != 200:
  109. logger.warning(f"Server HTTP {status}: {body[:250]!r}")
  110. metrics.add_metrics(server_errors=1)
  111. return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
  112. # ---------- Parse base JSON --------------------------------------------
  113. try:
  114. base = json.loads(body)
  115. except json.JSONDecodeError:
  116. logger.warning(f"Server response is not valid JSON: {body[:250]!r}")
  117. metrics.add_metrics(server_errors=1)
  118. return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
  119. # Token accounting if available
  120. if "usage" in base:
  121. metrics.add_metrics(
  122. server_input_tokens=base["usage"].get("prompt_tokens", 0),
  123. server_output_tokens=base["usage"].get("completion_tokens", 0),
  124. )
  125. # ---------- Extract the model message ----------------------------------
  126. try:
  127. content = base["choices"][0]["message"].get("content")
  128. except (KeyError, IndexError, AttributeError) as e:
  129. logger.warning(f"Missing fields in Server response: {e!s}")
  130. metrics.add_metrics(server_errors=1)
  131. return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
  132. if not isinstance(content, str):
  133. logger.warning("Server `content` is not a string; treating as error.")
  134. metrics.add_metrics(server_errors=1)
  135. return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
  136. try:
  137. pii_classification: PIIClassification = PIIClassification.model_validate_json(content)
  138. return pii_classification
  139. except ValidationError as e:
  140. logger.warning(f"Unable to parse pii classification object: {e!s}")
  141. metrics.add_metrics(server_errors=1)
  142. return PIIClassification(primary_language="en", document_type="unknown", is_resume_cv=None, contains_pii=None)
  143. # Manual simple implementation of HTTP Post
  144. # It feels strange perhaps, but httpx and aiohttp are very complex beasts
  145. # Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
  146. # that at the scale of 100M+ requests, that they deadlock in different strange ways
  147. async def apost(url, json_data):
  148. parsed_url = urlparse(url)
  149. host = parsed_url.hostname
  150. port = parsed_url.port or 80
  151. path = parsed_url.path or "/"
  152. writer = None
  153. try:
  154. reader, writer = await asyncio.open_connection(host, port)
  155. json_payload = json.dumps(json_data)
  156. request = (
  157. f"POST {path} HTTP/1.1\r\n"
  158. f"Host: {host}\r\n"
  159. f"Content-Type: application/json\r\n"
  160. f"Content-Length: {len(json_payload)}\r\n"
  161. f"Connection: close\r\n\r\n"
  162. f"{json_payload}"
  163. )
  164. writer.write(request.encode())
  165. await writer.drain()
  166. # Read status line
  167. status_line = await reader.readline()
  168. if not status_line:
  169. raise ConnectionError("No response from server")
  170. status_parts = status_line.decode().strip().split(" ", 2)
  171. if len(status_parts) < 2:
  172. raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
  173. status_code = int(status_parts[1])
  174. # Read headers
  175. headers = {}
  176. while True:
  177. line = await reader.readline()
  178. if line in (b"\r\n", b"\n", b""):
  179. break
  180. key, _, value = line.decode().partition(":")
  181. headers[key.strip().lower()] = value.strip()
  182. # Read response body
  183. if "content-length" in headers:
  184. body_length = int(headers["content-length"])
  185. response_body = await reader.readexactly(body_length)
  186. else:
  187. raise ConnectionError("Anything other than fixed content length responses are not implemented yet")
  188. return status_code, response_body
  189. except Exception as e:
  190. # Pass through errors
  191. raise e
  192. finally:
  193. # But just make sure to close the socket on your way out
  194. if writer is not None:
  195. try:
  196. writer.close()
  197. await writer.wait_closed()
  198. except:
  199. pass
  200. async def process_dolma_document(args, dolma_doc, sem):
  201. """
  202. Query model to detect PII, enforcing a JSON schema.
  203. Resilient to:
  204. • Transport / HTTP errors
  205. • Missing or malformed fields in the response
  206. • Non-string or None `content`
  207. • Bad JSON in the model's answer
  208. Always returns: (doc_id, contains_pii: bool, text_length: int)
  209. """
  210. doc_id = dolma_doc.get("id")
  211. text = dolma_doc.get("text", "") or ""
  212. # Create keys for all fields in PIIClassification
  213. prefix = args.model.replace("/", "_") + "_v2tag_"
  214. result_attributes = {}
  215. # Initialize attribute lists for all PIIClassification fields
  216. for field_name in PIIClassification.model_fields:
  217. key_name = f"{prefix}_{field_name}"
  218. result_attributes[key_name] = []
  219. # If pdf_page_numbers is present, sample first 5000 characters of the document
  220. if "attributes" in dolma_doc and "pdf_page_numbers" in dolma_doc["attributes"]:
  221. page_numbers = dolma_doc["attributes"]["pdf_page_numbers"]
  222. logger.info(f"Document {doc_id} has {len(page_numbers)} pages, processing first 5000 characters")
  223. # Take first 5000 characters of the document
  224. sample_text = text[:5000]
  225. text_length = len(text)
  226. span_end = min(5000, text_length)
  227. # Process the sample with the semaphore to limit concurrent requests
  228. async with sem:
  229. pii_class = await _process_single_page(sample_text)
  230. # Add all classification attributes to results
  231. for field_name in PIIClassification.model_fields:
  232. key_name = f"{prefix}_{field_name}"
  233. attribute_value = getattr(pii_class, field_name)
  234. # Create a span from 0 to min(5000, len(text)) with the attribute value
  235. result_attributes[key_name].append([0, span_end, attribute_value])
  236. # If the document is longer than 5000 characters, add a null span for the rest
  237. if text_length > 5000:
  238. result_attributes[key_name].append([span_end, text_length, None])
  239. return result_attributes
  240. else:
  241. raise NotImplementedError("Missing code here, expecting this to be dolma docs made by olmocr....")
  242. async def process_file(args, worker_id: int, file_uri: str):
  243. """
  244. Download a JSONL file, query model per record, and collect attributes.
  245. """
  246. # Fetch raw bytes (S3 or local)
  247. if file_uri.startswith("s3://"):
  248. raw = await asyncio.to_thread(get_s3_bytes_with_backoff, dataset_s3, file_uri)
  249. else:
  250. with open(file_uri, "rb") as f:
  251. raw = f.read()
  252. # Decompress if needed
  253. if file_uri.endswith(".gz"):
  254. file_bytes = gzip.decompress(raw)
  255. elif file_uri.endswith(".ztd") or file_uri.endswith(".zst") or file_uri.endswith(".zstd"):
  256. dctx = zstd.ZstdDecompressor()
  257. file_bytes = dctx.decompress(raw, max_output_size=1_000_000_000)
  258. else:
  259. file_bytes = raw
  260. lines = file_bytes.decode("utf-8").splitlines()
  261. page_tasks = {}
  262. # Send all records in parallel, max N queued at a time
  263. sem = asyncio.Semaphore(args.parallel_requests)
  264. async with asyncio.TaskGroup() as tg:
  265. for line in lines:
  266. dolma_doc = json.loads(line)
  267. task = tg.create_task(process_dolma_document(args, dolma_doc, sem))
  268. page_tasks[dolma_doc["id"]] = (task, dolma_doc)
  269. logger.info(f"Finished taskgroup with {len(page_tasks)} items for {file_uri}")
  270. # Collect results and build attributes
  271. attributes = []
  272. for doc_id, (task, dolma_doc) in page_tasks.items():
  273. doc_attributes = task.result()
  274. attributes.append({"id": doc_id, "attributes": doc_attributes})
  275. return attributes
  276. async def worker(args, work_queue: WorkQueue, semaphore: asyncio.Semaphore, worker_id: int):
  277. """
  278. Pop work-items off the queue, run PII tagging, write the attributes file
  279. next to the dataset (keeping the original compression), mark the item done,
  280. and drop an empty sentinel file in <workspace>/results/.
  281. """
  282. while True:
  283. await semaphore.acquire()
  284. work_item = await work_queue.get_work()
  285. if work_item is None:
  286. logger.info(f"Worker {worker_id} exiting – queue empty")
  287. semaphore.release()
  288. break
  289. file_uri = work_item.work_paths[0]
  290. logger.info(f"Worker {worker_id} processing {file_uri}")
  291. try:
  292. # ------------------------------------------------------------------
  293. # Run the per-file pipeline
  294. # ------------------------------------------------------------------
  295. attributes = await process_file(args, worker_id, file_uri)
  296. # 1. Build the relative path that mirrors documents/…
  297. if file_uri.startswith("s3://"):
  298. _, key = parse_s3_path(file_uri)
  299. _, docs_prefix = parse_s3_path(args.dataset)
  300. rel_path = key[len(os.path.join(docs_prefix, "documents/")) :]
  301. else:
  302. docs_root = os.path.join(args.dataset, "documents")
  303. rel_path = os.path.relpath(file_uri, docs_root)
  304. out_rel = os.path.join("attributes", args.attribute_name, rel_path)
  305. out_jsonl = "\n".join(json.dumps(x) for x in attributes) + "\n"
  306. # 2. Preserve compression type
  307. if rel_path.endswith(".gz"):
  308. payload = gzip.compress(out_jsonl.encode("utf-8"))
  309. elif rel_path.endswith((".zst", ".ztd")):
  310. payload = zstd.ZstdCompressor().compress(out_jsonl.encode("utf-8"))
  311. else:
  312. payload = out_jsonl.encode("utf-8")
  313. # 3. Write to args.dataset (local or S3)
  314. if args.dataset.startswith("s3://"):
  315. bucket, prefix = parse_s3_path(args.dataset)
  316. key = os.path.join(prefix, out_rel)
  317. workspace_s3.put_object(Bucket=bucket, Key=key, Body=payload)
  318. else:
  319. out_path = os.path.join(args.dataset, out_rel)
  320. os.makedirs(os.path.dirname(out_path), exist_ok=True)
  321. with open(out_path, "wb") as fh:
  322. fh.write(payload)
  323. # 4. Mark queue item done
  324. await work_queue.mark_done(work_item)
  325. # 5. Drop empty sentinel file in <workspace>/results/
  326. sentinel_rel = os.path.join("results", f"output_{work_item.hash}.jsonl")
  327. if args.scratch.startswith("s3://"):
  328. bkt, pfx = parse_s3_path(args.scratch)
  329. key = os.path.join(pfx, sentinel_rel)
  330. workspace_s3.put_object(Bucket=bkt, Key=key, Body=b"")
  331. else:
  332. sentinel_path = os.path.join(args.scratch, sentinel_rel)
  333. os.makedirs(os.path.dirname(sentinel_path), exist_ok=True)
  334. open(sentinel_path, "w").close()
  335. except Exception as exc:
  336. logger.exception(f"Worker {worker_id} exception: {exc!s}")
  337. finally:
  338. semaphore.release()
  339. async def server_task(model_name_or_path, args, semaphore):
  340. # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
  341. # mem_fraction_arg = ["--mem-fraction-static", "0.80"]
  342. cmd = [
  343. "vllm",
  344. "serve",
  345. model_name_or_path,
  346. "--port",
  347. str(SERVER_PORT),
  348. "--uvicorn-log-level",
  349. "warning",
  350. "--disable-log-requests",
  351. ]
  352. proc = await asyncio.create_subprocess_exec(
  353. *cmd,
  354. stdout=asyncio.subprocess.PIPE,
  355. stderr=asyncio.subprocess.PIPE,
  356. )
  357. # Ensure the subprocess is terminated on exit
  358. def _kill_proc():
  359. proc.terminate()
  360. atexit.register(_kill_proc)
  361. # Shared variables between tasks
  362. last_running_req, last_queue_req = 0, 0
  363. server_printed_ready_message = False
  364. last_semaphore_release = time.time()
  365. async def process_line(line):
  366. nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
  367. server_logger.info(line)
  368. # if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
  369. # can see any warnings/errors more easily
  370. if not server_printed_ready_message:
  371. logger.info(line)
  372. if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
  373. server_printed_ready_message = True
  374. last_semaphore_release = time.time()
  375. match = re.search(r"Running: (\d+) reqs", line)
  376. if match:
  377. last_running_req = int(match.group(1))
  378. match = re.search(r"Waiting: (\d+) reqs", line)
  379. if match:
  380. last_queue_req = int(match.group(1))
  381. logger.info(f"running req: {last_running_req} queue req: {last_queue_req}")
  382. async def read_stream(stream):
  383. while True:
  384. line = await stream.readline()
  385. if not line:
  386. break
  387. try:
  388. line = line.decode("utf-8").rstrip()
  389. await process_line(line)
  390. except Exception as ex:
  391. logger.warning(f"Got {ex} when reading log line from inference server, skipping")
  392. async def timeout_task():
  393. nonlocal last_running_req, last_queue_req, last_semaphore_release
  394. try:
  395. while True:
  396. await asyncio.sleep(1)
  397. if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
  398. semaphore.release()
  399. last_semaphore_release = time.time()
  400. logger.info("Semaphore released, allowing a worker to proceed.")
  401. except asyncio.CancelledError:
  402. pass # Clean up if the task is cancelled
  403. # Start tasks to read stdout, stderr, and handle timeout logic
  404. stdout_task = asyncio.create_task(read_stream(proc.stdout))
  405. stderr_task = asyncio.create_task(read_stream(proc.stderr))
  406. timeout_task = asyncio.create_task(timeout_task())
  407. try:
  408. await proc.wait()
  409. except asyncio.CancelledError:
  410. logger.info("Got cancellation request for server")
  411. proc.terminate()
  412. raise
  413. timeout_task.cancel()
  414. await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
  415. async def server_host(model_name_or_path, args, semaphore):
  416. MAX_RETRIES = 5
  417. retry = 0
  418. while retry < MAX_RETRIES:
  419. await server_task(model_name_or_path, args, semaphore)
  420. logger.warning("Server task ended")
  421. retry += 1
  422. if retry >= MAX_RETRIES:
  423. logger.error(f"Ended up starting the server more than {retry} times, cancelling pipeline")
  424. logger.error("")
  425. logger.error("Please make sure vllm is installed according to the latest instructions for 0.8.4")
  426. sys.exit(1)
  427. async def check_server_ready():
  428. max_attempts = 300
  429. delay_sec = 1
  430. url = f"http://localhost:{SERVER_PORT}/v1/models"
  431. for attempt in range(1, max_attempts + 1):
  432. try:
  433. async with httpx.AsyncClient() as session:
  434. response = await session.get(url)
  435. if response.status_code == 200:
  436. logger.info("server is ready.")
  437. return
  438. else:
  439. logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
  440. except Exception:
  441. logger.warning(f"Attempt {attempt}: Please wait for model server to become ready...")
  442. await asyncio.sleep(delay_sec)
  443. raise Exception("model server did not become ready after waiting.")
  444. async def download_model(model_name_or_path: str):
  445. if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
  446. logger.info(f"Downloading model directory from '{model_name_or_path}'")
  447. model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
  448. download_directory([model_name_or_path], model_cache_dir)
  449. return model_cache_dir
  450. elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
  451. logger.info(f"Using local model path at '{model_name_or_path}'")
  452. return model_name_or_path
  453. else:
  454. logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
  455. snapshot_download(repo_id=model_name_or_path)
  456. return model_name_or_path
  457. async def metrics_reporter(work_queue):
  458. while True:
  459. # Leading newlines preserve table formatting in logs
  460. logger.info(f"Queue remaining: {work_queue.size}")
  461. logger.info("\n" + str(metrics))
  462. await asyncio.sleep(10)
  463. def submit_beaker_job(args):
  464. from beaker import ( # type: ignore
  465. Beaker,
  466. Constraints,
  467. EnvVar,
  468. ExperimentSpec,
  469. ImageSource,
  470. Priority,
  471. ResultSpec,
  472. SecretNotFound,
  473. TaskContext,
  474. TaskResources,
  475. TaskSpec,
  476. )
  477. b = Beaker.from_env(default_workspace=args.beaker_workspace)
  478. account = b.account.whoami()
  479. owner = account.name
  480. beaker_image = f"jakep/olmocr-tagging-{VERSION}"
  481. task_name = f"olmocr-{os.path.basename(args.dataset.rstrip('/'))}"
  482. # Take out --beaker flag so the workers will just run things
  483. args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]
  484. # Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
  485. args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))]
  486. try:
  487. b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
  488. b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
  489. b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
  490. except SecretNotFound:
  491. print(
  492. f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]"
  493. )
  494. if input().strip().lower() != "y":
  495. print("Exiting...")
  496. sys.exit(1)
  497. b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
  498. b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
  499. b.secret.write(
  500. f"{owner}-AWS_CREDENTIALS_FILE",
  501. open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(),
  502. args.beaker_workspace,
  503. )
  504. env_var_secrets = [
  505. EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
  506. EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
  507. EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
  508. ]
  509. try:
  510. b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace)
  511. env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN"))
  512. except SecretNotFound:
  513. pass
  514. try:
  515. b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
  516. env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
  517. except SecretNotFound:
  518. print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
  519. lines = []
  520. prev_empty = False
  521. for line in iter(input, None):
  522. if not line and prev_empty:
  523. break
  524. prev_empty = not line
  525. lines.append(line)
  526. gcs_sa_key = "\n".join(lines[:-1]).strip() # Remove the last empty line
  527. if gcs_sa_key:
  528. b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
  529. env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
  530. # Create the experiment spec
  531. experiment_spec = ExperimentSpec(
  532. budget="ai2/oe-data",
  533. description=task_name,
  534. tasks=[
  535. TaskSpec(
  536. name=task_name,
  537. propagate_failure=False,
  538. propagate_preemption=False,
  539. replicas=args.beaker_gpus,
  540. context=TaskContext(
  541. priority=Priority(args.beaker_priority),
  542. preemptible=True,
  543. ),
  544. image=ImageSource(beaker=beaker_image),
  545. command=["python", "scripts/tagging_pipeline_v2.py"] + args_list,
  546. env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
  547. resources=TaskResources(gpu_count=1),
  548. constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
  549. result=ResultSpec(path="/noop-results"),
  550. )
  551. ],
  552. )
  553. experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)
  554. print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")
  555. async def main():
  556. parser = argparse.ArgumentParser(description="Tagging pipeline for Dolma JSONL dataset")
  557. parser.add_argument("dataset", help="Dolma dataset root (local or s3://) with documents/ folder")
  558. parser.add_argument("scratch", help="Scratch workspace (local dir or s3://)")
  559. parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers")
  560. parser.add_argument("--parallel_requests", type=int, default=800, help="Max number of parallel requests to send to model")
  561. parser.add_argument("--model", default="google/gemma-3-4b-it", help="Model path or name, hugging face or local path format")
  562. parser.add_argument("--attribute_name", default="model_pii_tagging_v2", help="Path to use for attribute naming")
  563. # Beaker/job running stuff
  564. parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
  565. parser.add_argument("--beaker_workspace", help="Beaker workspace to submit to", default="ai2/olmocr")
  566. parser.add_argument(
  567. "--beaker_cluster",
  568. help="Beaker clusters you want to run on",
  569. default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"],
  570. )
  571. parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
  572. parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
  573. parser.add_argument("--port", type=int, default=30026, help="Port for Model server")
  574. args = parser.parse_args()
  575. global SERVER_PORT, workspace_s3, dataset_s3
  576. SERVER_PORT = args.port
  577. workspace_s3 = boto3.client("s3")
  578. dataset_s3 = boto3.client("s3")
  579. # setup the job to work in beaker environment, load secrets, adjust logging, etc.
  580. if "BEAKER_JOB_ID" in os.environ:
  581. server_logger.addHandler(console_handler)
  582. if "AWS_CREDENTIALS_FILE" in os.environ:
  583. cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
  584. os.makedirs(os.path.dirname(cred_path), exist_ok=True)
  585. with open(cred_path, "w") as f:
  586. f.write(os.environ.get("AWS_CREDENTIALS_FILE"))
  587. if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ:
  588. cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
  589. os.makedirs(os.path.dirname(cred_path), exist_ok=True)
  590. with open(cred_path, "w") as f:
  591. f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
  592. os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
  593. workspace_s3 = boto3.client("s3")
  594. dataset_s3 = boto3.client("s3")
  595. # Wait a little bit so that not all beaker jobs in a task start at the same time and download the model at the same time
  596. replica_count = int(os.environ.get("BEAKER_REPLICA_COUNT", "1"))
  597. interval = 10 if (replica_count - 1) * 10 <= 240 else 240 / max(1, replica_count - 1)
  598. sleep_time = int(int(os.environ.get("BEAKER_REPLICA_RANK", "0")) * interval)
  599. logger.info(f"Beaker job sleeping for {sleep_time} seconds to stagger model downloads")
  600. await asyncio.sleep(sleep_time)
  601. # Initialize work queue
  602. if args.scratch.startswith("s3://"):
  603. work_queue = S3WorkQueue(workspace_s3, args.scratch)
  604. else:
  605. work_queue = LocalWorkQueue(args.scratch)
  606. # Discover input files
  607. files = set()
  608. if args.dataset.startswith("s3://"):
  609. pattern = args.dataset.rstrip("/") + "/documents/*.jsonl*"
  610. matched = expand_s3_glob(dataset_s3, pattern)
  611. files = set(matched.keys())
  612. else:
  613. docs_dir = os.path.join(args.dataset, "documents")
  614. for root, _, fns in os.walk(docs_dir):
  615. for fn in fns:
  616. if fn.endswith((".jsonl", ".jsonl.gz", ".jsonl.ztd")):
  617. files.add(os.path.join(root, fn))
  618. # Populate the work queue if needed
  619. await work_queue.populate_queue(list(files), items_per_group=1)
  620. if args.beaker:
  621. submit_beaker_job(args)
  622. return
  623. # If you get this far, then you are doing inference and need a GPU
  624. check_torch_gpu_available()
  625. logger.info(f"Starting pipeline with PID {os.getpid()}")
  626. # Download the model before you do anything else
  627. model_name_or_path = await download_model(args.model)
  628. # Initialize the work queue
  629. qsize = await work_queue.initialize_queue()
  630. if qsize == 0:
  631. logger.info("No work to do, exiting")
  632. return
  633. # Create a semaphore to control worker access
  634. # We only allow one worker to move forward with requests, until the server has no more requests in its queue
  635. # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
  636. # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
  637. semaphore = asyncio.Semaphore(1)
  638. model_server = asyncio.create_task(server_host(model_name_or_path, args, semaphore))
  639. await check_server_ready()
  640. metrics_task = asyncio.create_task(metrics_reporter(work_queue))
  641. # Create worker tasks to process the queue concurrently.
  642. worker_tasks = []
  643. for i in range(args.workers):
  644. task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
  645. worker_tasks.append(task)
  646. # Wait for all worker tasks to finish
  647. await asyncio.gather(*worker_tasks)
  648. model_server.cancel()
  649. metrics_task.cancel()
  650. logger.info("Work done")
  651. if __name__ == "__main__":
  652. asyncio.run(main())