autoscan_dolmadocs.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
import argparse
import json
import os
import random
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional

import boto3
import pydantic
from openai import OpenAI
from tqdm import tqdm

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.s3_utils import get_s3_bytes, parse_s3_path
  16. LanguageCode = Enum(
  17. "LanguageCode",
  18. {
  19. "en": "English",
  20. "zh": "Chinese",
  21. "hi": "Hindi",
  22. "es": "Spanish",
  23. "fr": "French",
  24. "ar": "Arabic",
  25. "bn": "Bengali",
  26. "ru": "Russian",
  27. "pt": "Portuguese",
  28. "ur": "Urdu",
  29. "id": "Indonesian",
  30. "de": "German",
  31. "ja": "Japanese",
  32. "sw": "Swahili",
  33. "mr": "Marathi",
  34. "te": "Telugu",
  35. "tr": "Turkish",
  36. "vi": "Vietnamese",
  37. "ta": "Tamil",
  38. "ko": "Korean",
  39. "other": "Other",
  40. },
  41. )
  42. class PIIAnnotation(pydantic.BaseModel):
  43. """Structured model for PII annotations returned by ChatGPT"""
  44. document_description: str
  45. language_code: LanguageCode
  46. cannot_read: bool
  47. inappropriate_content: bool
  48. is_public_document: bool
  49. # PII identifiers
  50. contains_names: bool
  51. contains_email_addresses: bool
  52. contains_phone_numbers: bool
  53. # PII that must co-occur with identifiers
  54. contains_addresses: bool
  55. contains_biographical_info: bool # DOB, gender, etc.
  56. contains_location_info: bool
  57. contains_employment_info: bool
  58. contains_education_info: bool
  59. contains_medical_info: bool
  60. # Always sensitive PII
  61. contains_government_ids: bool
  62. contains_financial_info: bool
  63. contains_biometric_data: bool
  64. contains_login_info: bool
  65. other_pii: str
  66. @property
  67. def has_pii(self) -> bool:
  68. """Check if the document contains any PII"""
  69. pii_fields = [
  70. self.contains_names,
  71. self.contains_email_addresses,
  72. self.contains_phone_numbers,
  73. self.contains_addresses,
  74. self.contains_biographical_info,
  75. self.contains_location_info,
  76. self.contains_employment_info,
  77. self.contains_education_info,
  78. self.contains_medical_info,
  79. self.contains_government_ids,
  80. self.contains_financial_info,
  81. self.contains_biometric_data,
  82. self.contains_login_info,
  83. ]
  84. return any(pii_fields) or bool(self.other_pii.strip())
  85. def get_pii_types(self) -> List[str]:
  86. """Get a list of all PII types found in the document"""
  87. pii_types = []
  88. if self.contains_names:
  89. pii_types.append("names")
  90. if self.contains_email_addresses:
  91. pii_types.append("email")
  92. if self.contains_phone_numbers:
  93. pii_types.append("phone")
  94. if self.contains_addresses:
  95. pii_types.append("addresses")
  96. if self.contains_biographical_info:
  97. pii_types.append("biographical")
  98. if self.contains_location_info:
  99. pii_types.append("location")
  100. if self.contains_employment_info:
  101. pii_types.append("employment")
  102. if self.contains_education_info:
  103. pii_types.append("education")
  104. if self.contains_medical_info:
  105. pii_types.append("medical")
  106. if self.contains_government_ids:
  107. pii_types.append("government-id")
  108. if self.contains_financial_info:
  109. pii_types.append("financial")
  110. if self.contains_biometric_data:
  111. pii_types.append("biometric")
  112. if self.contains_login_info:
  113. pii_types.append("login-info")
  114. if self.other_pii.strip():
  115. pii_types.append("other")
  116. return pii_types
  117. def parse_args():
  118. parser = argparse.ArgumentParser(description="Automatically scan OLMO OCR workspace results using ChatGPT")
  119. parser.add_argument("workspace", help="OLMO OCR workspace path (s3://bucket/workspace)")
  120. parser.add_argument("--pages_per_run", type=int, default=30, help="Number of pages per run")
  121. parser.add_argument("--pdf_profile", help="AWS profile for accessing PDFs")
  122. parser.add_argument("--output_dir", default="dolma_samples", help="Directory to save output files")
  123. parser.add_argument("--max_workers", type=int, default=4, help="Maximum number of worker threads")
  124. parser.add_argument("--openai_api_key", help="OpenAI API key (or set OPENAI_API_KEY env var)")
  125. parser.add_argument("--openai_model", default="gpt-4.1", help="OpenAI model to use")
  126. return parser.parse_args()
  127. def list_result_files(s3_client, workspace_path):
  128. """List all JSON result files in the workspace results directory."""
  129. bucket, prefix = parse_s3_path(workspace_path)
  130. results_prefix = os.path.join(prefix, "results").rstrip("/") + "/"
  131. all_files = []
  132. paginator = s3_client.get_paginator("list_objects_v2")
  133. for page in paginator.paginate(Bucket=bucket, Prefix=results_prefix):
  134. if "Contents" in page:
  135. all_files.extend([f"s3://{bucket}/{obj['Key']}" for obj in page["Contents"] if obj["Key"].endswith(".jsonl") or obj["Key"].endswith(".json")])
  136. # if len(all_files) > 1000:
  137. # break
  138. return all_files
  139. def get_random_pages(s3_client, result_files, count=30):
  140. """Get random pages from the result files."""
  141. random_pages = []
  142. # Try to collect the requested number of pages
  143. attempts = 0
  144. max_attempts = count * 3 # Allow extra attempts to handle potential failures
  145. while len(random_pages) < count and attempts < max_attempts:
  146. attempts += 1
  147. # Pick a random result file
  148. if not result_files:
  149. print("No result files found!")
  150. break
  151. result_file = random.choice(result_files)
  152. try:
  153. # Get the content of the file
  154. content = get_s3_bytes(s3_client, result_file)
  155. lines = content.decode("utf-8").strip().split("\n")
  156. if not lines:
  157. continue
  158. # Pick a random line (which contains a complete document)
  159. line = random.choice(lines)
  160. doc = json.loads(line)
  161. # A Dolma document has "text", "metadata", and "attributes" fields
  162. if "text" not in doc or "metadata" not in doc or "attributes" not in doc:
  163. print(f"Document in {result_file} is not a valid Dolma document")
  164. continue
  165. # Get the original PDF path from metadata
  166. pdf_path = doc["metadata"].get("Source-File")
  167. if not pdf_path:
  168. continue
  169. # Get page spans from attributes
  170. page_spans = doc["attributes"].get("pdf_page_numbers", [])
  171. if not page_spans:
  172. continue
  173. # Pick a random page span
  174. page_span = random.choice(page_spans)
  175. if len(page_span) >= 3:
  176. # Page spans are [start_pos, end_pos, page_num]
  177. page_num = page_span[2]
  178. # Extract text for this page
  179. start_pos, end_pos = page_span[0], page_span[1]
  180. page_text = doc["text"][start_pos:end_pos].strip()
  181. # Include the text snippet with the page info
  182. random_pages.append((pdf_path, page_num, page_text, result_file))
  183. if len(random_pages) >= count:
  184. break
  185. except Exception as e:
  186. print(f"Error processing {result_file}: {e}")
  187. continue
  188. print(f"Found {len(random_pages)} random pages from Dolma documents")
  189. return random_pages
  190. def chatgpt_analyze_page(pdf_path: str, page_num: int, pdf_s3_client, openai_api_key: str, openai_model: str) -> Optional[PIIAnnotation]:
  191. """Analyze a page using the ChatGPT vision model with structured outputs."""
  192. try:
  193. # Download PDF to temp file and render to image
  194. bucket, key = parse_s3_path(pdf_path)
  195. with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
  196. pdf_data = pdf_s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
  197. temp_file.write(pdf_data)
  198. temp_file_path = temp_file.name
  199. # Render PDF to base64 image
  200. base64_image = render_pdf_to_base64png(temp_file_path, page_num, target_longest_image_dim=2048)
  201. # Clean up temp file
  202. os.unlink(temp_file_path)
  203. # Create OpenAI client
  204. client = OpenAI(api_key=openai_api_key)
  205. # Prepare the user message with all instructions
  206. user_message = """
  207. You are a document analyzer that identifies Personally Identifiable Information (PII) in documents.
  208. Your task is to analyze the provided document image and determine:
  209. 1. Whether the document is intended for public release or dissemination (e.g., research paper, public report, etc.)
  210. 2. If the document contains any PII
  211. For PII identification, follow these specific guidelines:
  212. IDENTIFIERS FOR PII:
  213. The following are considered identifiers that can make information PII:
  214. - Names (full names, first names, last names, nicknames)
  215. - Email addresses
  216. - Phone numbers
  217. PII THAT MUST CO-OCCUR WITH AN IDENTIFIER:
  218. The following types of information should ONLY be marked as PII if they occur ALONGSIDE an identifier (commonly, a person's name):
  219. - Addresses (street address, postal code, etc.)
  220. - Biographical Information (date of birth, place of birth, gender, sexual orientation, race, ethnicity, citizenship/immigration status, religion)
  221. - Location Information (geolocations, specific coordinates)
  222. - Employment Information (job titles, workplace names, employment history)
  223. - Education Information (school names, degrees, transcripts)
  224. - Medical Information (health records, diagnoses, genetic or neural data)
  225. PII THAT OCCURS EVEN WITHOUT AN IDENTIFIER:
  226. The following should ALWAYS be marked as PII even if they do not occur alongside an identifier:
  227. - Government IDs (Social Security Numbers, passport numbers, driver's license numbers, tax IDs)
  228. - Financial Information (credit card numbers, bank account/routing numbers)
  229. - Biometric Data (fingerprints, retina scans, facial recognition data, voice signatures)
  230. - Login information (ONLY mark as PII when a username, password, and login location are present together)
  231. If the document is a form, then only consider fields which are filled out with specific values as potential PII.
  232. If this page does not itself contain PII, but references documents (such as curriculum vitae, personal statements) that typically contain PII, then do not mark it as PII.
  233. Only consider actual occurrences of the PII within the document shown.
  234. """
  235. # Use the chat completions API with the custom schema
  236. completion = client.beta.chat.completions.parse(
  237. model=openai_model,
  238. messages=[
  239. {
  240. "role": "user",
  241. "content": [{"type": "text", "text": user_message}, {"type": "image_url", "image_url": {"url": f"data:image/webp;base64,{base64_image}"}}],
  242. }
  243. ],
  244. response_format=PIIAnnotation,
  245. max_tokens=1000,
  246. )
  247. return completion.choices[0].message.parsed
  248. except Exception as e:
  249. print(f"Error analyzing page {pdf_path} (page {page_num}): {e}")
  250. return None
  251. def create_presigned_url(s3_client, pdf_path, expiration=3600 * 24 * 7):
  252. """Create a presigned URL for the given S3 path."""
  253. try:
  254. bucket, key = parse_s3_path(pdf_path)
  255. url = s3_client.generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=expiration)
  256. return url
  257. except Exception as e:
  258. print(f"Error creating presigned URL for {pdf_path}: {e}")
  259. return None
  260. def process_pages(random_pages, pdf_s3_client, openai_api_key, openai_model, max_workers):
  261. """Process multiple pages in parallel using ThreadPoolExecutor."""
  262. results = []
  263. # First generate presigned URLs for all PDFs
  264. print("Generating presigned URLs for PDFs...")
  265. presigned_urls = {}
  266. for pdf_path, page_num, _, _ in random_pages:
  267. if pdf_path not in presigned_urls and pdf_path.startswith("s3://"):
  268. url = create_presigned_url(pdf_s3_client, pdf_path)
  269. if url:
  270. presigned_urls[pdf_path] = url
  271. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  272. futures = {}
  273. # Submit all tasks
  274. for pdf_path, page_num, page_text, result_file in tqdm(random_pages, desc="Submitting pages for analysis"):
  275. future = executor.submit(chatgpt_analyze_page, pdf_path, page_num, pdf_s3_client, openai_api_key, openai_model)
  276. futures[future] = (pdf_path, page_num, page_text, result_file)
  277. # Process results as they complete
  278. for future in tqdm(futures, desc="Processing results"):
  279. pdf_path, page_num, page_text, result_file = futures[future]
  280. try:
  281. annotation = future.result()
  282. if annotation:
  283. # Get presigned URL with page number
  284. presigned_url = None
  285. if pdf_path in presigned_urls:
  286. presigned_url = f"{presigned_urls[pdf_path]}#page={page_num}"
  287. results.append((pdf_path, page_num, page_text, result_file, annotation, presigned_url))
  288. else:
  289. print(f"Failed to get annotation for {pdf_path} (page {page_num})")
  290. except Exception as e:
  291. print(f"Error processing {pdf_path} (page {page_num}): {e}")
  292. return results
  293. def categorize_results(all_results):
  294. """Categorize results for reporting."""
  295. categories = {
  296. "public_document": [],
  297. "private_document": [],
  298. "cannot_read": [],
  299. "report_content": [],
  300. "no_annotation": [],
  301. }
  302. for pdf_path, page_num, page_text, result_file, annotation, presigned_url in all_results:
  303. if annotation.cannot_read or annotation.language_code != LanguageCode.en:
  304. categories["cannot_read"].append({"pdf_path": pdf_path, "pdf_page": page_num, "result_file": result_file, "presigned_url": presigned_url})
  305. elif annotation.inappropriate_content:
  306. categories["report_content"].append({"pdf_path": pdf_path, "pdf_page": page_num, "result_file": result_file, "presigned_url": presigned_url})
  307. elif annotation.is_public_document:
  308. categories["public_document"].append(
  309. {
  310. "pdf_path": pdf_path,
  311. "pdf_page": page_num,
  312. "result_file": result_file,
  313. "pii_types": annotation.get_pii_types(),
  314. "has_pii": annotation.has_pii,
  315. "description": annotation.other_pii,
  316. "presigned_url": presigned_url,
  317. }
  318. )
  319. else:
  320. # Private document
  321. categories["private_document"].append(
  322. {
  323. "pdf_path": pdf_path,
  324. "pdf_page": page_num,
  325. "result_file": result_file,
  326. "pii_types": annotation.get_pii_types(),
  327. "has_pii": annotation.has_pii,
  328. "description": annotation.other_pii,
  329. "presigned_url": presigned_url,
  330. }
  331. )
  332. return categories
  333. def print_annotation_report(annotation_results: Dict[str, List[Dict[str, Any]]]):
  334. """Print a summary report of annotations."""
  335. total_pages = sum(len(items) for items in annotation_results.values())
  336. print("\n" + "=" * 80)
  337. print(f"ANNOTATION REPORT - Total Pages: {total_pages}")
  338. print("=" * 80)
  339. # Count pages with PII in public documents
  340. public_with_pii = [page for page in annotation_results["public_document"] if page.get("has_pii", False)]
  341. public_without_pii = [page for page in annotation_results["public_document"] if not page.get("has_pii", False)]
  342. # Count pages with PII in private documents
  343. private_with_pii = [page for page in annotation_results["private_document"] if page.get("has_pii", False)]
  344. private_without_pii = [page for page in annotation_results["private_document"] if not page.get("has_pii", False)]
  345. # Print summary statistics
  346. print("\nSummary:")
  347. print(
  348. f" Public documents (total): {len(annotation_results['public_document'])} ({len(annotation_results['public_document'])/total_pages*100:.1f}% of all pages)"
  349. )
  350. print(f" - With PII: {len(public_with_pii)} ({len(public_with_pii)/max(1, len(annotation_results['public_document']))*100:.1f}% of public docs)")
  351. print(
  352. f" - Without PII: {len(public_without_pii)} ({len(public_without_pii)/max(1, len(annotation_results['public_document']))*100:.1f}% of public docs)"
  353. )
  354. print(
  355. f" Private documents (total): {len(annotation_results['private_document'])} ({len(annotation_results['private_document'])/total_pages*100:.1f}% of all pages)"
  356. )
  357. print(f" - With PII: {len(private_with_pii)} ({len(private_with_pii)/max(1, len(annotation_results['private_document']))*100:.1f}% of private docs)")
  358. print(
  359. f" - Without PII: {len(private_without_pii)} ({len(private_without_pii)/max(1, len(annotation_results['private_document']))*100:.1f}% of private docs)"
  360. )
  361. print(f" Unreadable pages: {len(annotation_results['cannot_read'])} ({len(annotation_results['cannot_read'])/total_pages*100:.1f}%)")
  362. print(f" Pages with reported content: {len(annotation_results['report_content'])} ({len(annotation_results['report_content'])/total_pages*100:.1f}%)")
  363. print(f" Pages without annotation: {len(annotation_results['no_annotation'])} ({len(annotation_results['no_annotation'])/total_pages*100:.1f}%)")
  364. # Analyze PII types in private documents
  365. if private_with_pii:
  366. # Categorize the PII types for clearer reporting
  367. pii_categories = {
  368. "Identifiers": ["names", "email", "phone"],
  369. "PII requiring identifiers": ["addresses", "biographical", "location", "employment", "education", "medical"],
  370. "Always sensitive PII": ["government-id", "financial", "biometric", "login-info"],
  371. }
  372. # Dictionary to track all PII counts
  373. pii_counts_private = {}
  374. for page in private_with_pii:
  375. for pii_type in page.get("pii_types", []):
  376. pii_counts_private[pii_type] = pii_counts_private.get(pii_type, 0) + 1
  377. # Print categorized PII counts
  378. print("\nPII Types in Private Documents:")
  379. # Print each category
  380. for category, pii_types in pii_categories.items():
  381. print(f"\n {category}:")
  382. for pii_type in pii_types:
  383. count = pii_counts_private.get(pii_type, 0)
  384. if count > 0:
  385. print(f" - {pii_type}: {count} ({count/len(private_with_pii)*100:.1f}%)")
  386. # Print any other PII types not in our categories (like "other")
  387. other_pii = [pii_type for pii_type in pii_counts_private.keys() if not any(pii_type in types for types in pii_categories.values())]
  388. if other_pii:
  389. print("\n Other PII types:")
  390. for pii_type in other_pii:
  391. count = pii_counts_private.get(pii_type, 0)
  392. print(f" - {pii_type}: {count} ({count/len(private_with_pii)*100:.1f}%)")
  393. # Print detailed report for private documents with PII
  394. if private_with_pii:
  395. print("\nDetailed Report - Private Documents with PII:")
  396. print("-" * 80)
  397. for i, item in enumerate(private_with_pii, 1):
  398. pdf_path = item["pdf_path"]
  399. pdf_page = item["pdf_page"]
  400. presigned_url = item.get("presigned_url")
  401. print(f"{i}. PDF: {pdf_path}")
  402. print(f" Page: {pdf_page}")
  403. if presigned_url:
  404. print(f" Presigned URL: {presigned_url}")
  405. print(f" PII Types: {', '.join(item['pii_types'])}")
  406. if item.get("description"):
  407. print(f" Description: {item['description']}")
  408. print("-" * 80)
  409. # Print links to unreadable pages
  410. # if annotation_results["cannot_read"]:
  411. # print("\nUnreadable Pages:")
  412. # print("-" * 80)
  413. # for i, item in enumerate(annotation_results["cannot_read"], 1):
  414. # pdf_path = item["pdf_path"]
  415. # pdf_page = item["pdf_page"]
  416. # presigned_url = item.get("presigned_url")
  417. # print(f"{i}. PDF: {pdf_path}")
  418. # print(f" Page: {pdf_page}")
  419. # if presigned_url:
  420. # print(f" Presigned URL: {presigned_url}")
  421. # print("-" * 80)
  422. # Print links to inappropriate content
  423. if annotation_results["report_content"]:
  424. print("\nReported Content:")
  425. print("-" * 80)
  426. for i, item in enumerate(annotation_results["report_content"], 1):
  427. pdf_path = item["pdf_path"]
  428. pdf_page = item["pdf_page"]
  429. presigned_url = item.get("presigned_url")
  430. print(f"{i}. PDF: {pdf_path}")
  431. print(f" Page: {pdf_page}")
  432. if presigned_url:
  433. print(f" Presigned URL: {presigned_url}")
  434. print("-" * 80)
  435. print("\nReport complete.")
  436. def save_results(results, output_dir):
  437. """Save the results to a JSON file."""
  438. output_path = Path(output_dir) / "autoscan_results.json"
  439. # Convert results to serializable format
  440. serializable_results = []
  441. for pdf_path, page_num, page_text, result_file, annotation, presigned_url in results:
  442. serializable_results.append(
  443. {
  444. "pdf_path": pdf_path,
  445. "page_num": page_num,
  446. "page_text": page_text,
  447. "result_file": result_file,
  448. "annotation": annotation.dict(),
  449. "presigned_url": presigned_url,
  450. }
  451. )
  452. with open(output_path, "w") as f:
  453. json.dump(serializable_results, f, indent=2, default=lambda o: o.value if isinstance(o, Enum) else o)
  454. print(f"Results saved to {output_path}")
  455. def main():
  456. args = parse_args()
  457. # Get OpenAI API key from args or environment
  458. openai_api_key = args.openai_api_key or os.environ.get("OPENAI_API_KEY")
  459. if not openai_api_key:
  460. raise ValueError("OpenAI API key must be provided via --openai_api_key or OPENAI_API_KEY environment variable")
  461. # Set up S3 clients
  462. s3_client = boto3.client("s3")
  463. # Set up PDF S3 client with profile if specified
  464. if args.pdf_profile:
  465. pdf_session = boto3.Session(profile_name=args.pdf_profile)
  466. pdf_s3_client = pdf_session.client("s3")
  467. else:
  468. pdf_s3_client = s3_client
  469. # Create output directory
  470. output_dir = Path(args.output_dir)
  471. output_dir.mkdir(exist_ok=True, parents=True)
  472. # List all result files
  473. print(f"Listing result files in {args.workspace}/results...")
  474. result_files = list_result_files(s3_client, args.workspace)
  475. print(f"Found {len(result_files)} result files")
  476. # Get random pages
  477. random_pages = get_random_pages(s3_client, result_files, args.pages_per_run)
  478. # Process pages with ChatGPT
  479. print(f"Processing {len(random_pages)} pages with ChatGPT...")
  480. all_results = process_pages(random_pages, pdf_s3_client, openai_api_key, args.openai_model, args.max_workers)
  481. # Save results
  482. save_results(all_results, args.output_dir)
  483. # Categorize and report results
  484. categorized_results = categorize_results(all_results)
  485. print_annotation_report(categorized_results)
  486. if __name__ == "__main__":
  487. main()