pii_rule_comparison.py 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152
  1. #!/usr/bin/env python3
  2. """
  3. Compare PII Detection Rules and Calculate IoU
  4. This script processes documents and their attributes from S3 or local storage,
  5. applies different rules for PII detection, and calculates the
  6. Intersection over Union (IoU) to measure how well they overlap.
  7. How it works:
  8. 1. Documents are stored in one location (--docs-folder)
  9. 2. Attributes are automatically found in ../attributes/ relative to the documents folder
  10. 3. The script merges documents with all available attributes by matching filenames and document IDs
  11. 4. PII detection rules are applied to the merged documents
  12. 5. IoU and other metrics are calculated to compare the results
  13. Expected folder structure:
  14. - s3://bucket/path/documents/ - Contains the main document JSONL files
  15. - s3://bucket/path/attributes/ - Contains attributes that can be matched with documents by ID
  16. Document and attribute matching:
  17. - Files are matched by basename (example.jsonl in documents matches example.jsonl in attributes)
  18. - Within each file, documents are matched by their "id" field
  19. - When a match is found, attributes from the attribute file are merged into the document
  20. Example usage:
  21. python pii_rule_comparison.py \
  22. --docs-folder s3://bucket/path/documents \
  23. --ref-rule "gpt_4_1_contains_pii:any" \
  24. --hyp-rule "gpt_4_1_contains_email_addresses:any" \
25. --output-dir results \
  26. --recursive
  27. Rule expression syntax:
  28. - Simple rule: "attribute_name:rule_type" where rule_type is "any" or "all"
  29. - Boolean expressions: "not rule1:any and rule2:all"
  30. - Parentheses for grouping: "(rule1:any or rule2:any) and not rule3:all"
  31. """
  32. import argparse
  33. import base64
  34. import gzip
  35. import html as pyhtml
  36. import io
  37. import json
  38. import logging
  39. import os
  40. from collections import defaultdict
  41. from enum import Enum, auto
  42. from io import BytesIO
  43. from pathlib import Path
  44. import boto3
  45. import numpy as np
  46. import zstandard as zstd
  47. from matplotlib.figure import Figure
# Initialize logger.
# NOTE: basicConfig is a module-level side effect — it configures the root
# logger for anything that imports this script, not just this module.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)  # shared by every helper below
# Define token types for the rule expression parser
class TokenType(Enum):
    """Token categories emitted when scanning a rule expression.

    Member order is not significant to callers, but values come from
    auto(), so do not reorder members if anything persists them.
    """

    RULE = auto()  # a leaf rule such as "attribute_name:any"
    AND = auto()  # boolean conjunction keyword
    OR = auto()  # boolean disjunction keyword
    NOT = auto()  # unary negation keyword
    LPAREN = auto()  # opening parenthesis "("
    RPAREN = auto()  # closing parenthesis ")"
    EOF = auto()  # end-of-input sentinel
  60. class Token:
  61. """Token for rule expression parsing"""
  62. def __init__(self, type, value=None):
  63. self.type = type
  64. self.value = value
  65. def __repr__(self):
  66. if self.value:
  67. return f"Token({self.type}, {self.value})"
  68. return f"Token({self.type})"
  69. class ExpressionNode:
  70. """Base class for expression tree nodes"""
  71. pass
  72. class RuleNode(ExpressionNode):
  73. """Leaf node representing a single rule"""
  74. def __init__(self, attribute_name, rule_type):
  75. self.attribute_name = attribute_name
  76. self.rule_type = rule_type
  77. def __repr__(self):
  78. return f"Rule({self.attribute_name}:{self.rule_type})"
  79. class NotNode(ExpressionNode):
  80. """Unary NOT operation node"""
  81. def __init__(self, operand):
  82. self.operand = operand
  83. def __repr__(self):
  84. return f"NOT({self.operand})"
  85. class BinaryNode(ExpressionNode):
  86. """Binary operation (AND/OR) node"""
  87. def __init__(self, left, right, operator):
  88. self.left = left
  89. self.right = right
  90. self.operator = operator
  91. def __repr__(self):
  92. return f"{self.operator}({self.left}, {self.right})"
def parse_args():
    """Parse command-line arguments for the rule-comparison script.

    Returns:
        argparse.Namespace with the parsed options. --docs-folder,
        --ref-rule and --hyp-rule are required; everything else has a
        default or is optional.
    """
    parser = argparse.ArgumentParser(description="Compare PII detection rules and calculate IoU")
    parser.add_argument("--docs-folder", required=True, help="Documents folder path containing JSONL files (local or s3://)")
    # When omitted, the attributes folder is derived from --docs-folder
    # (see get_attributes_folder).
    parser.add_argument("--attr-folder", help="Attributes folder path (if different from standard ../attributes/ location)")
    parser.add_argument(
        "--ref-rule",
        required=True,
        help="""Reference rule expression. Can be a simple rule in format 'attribute_name:rule_type',
        where rule_type is 'any' or 'all'. Or a boolean expression like
        'not rule1:any and rule2:all' or '(rule1:any or rule2:any) and not rule3:all'""",
    )
    parser.add_argument(
        "--hyp-rule",
        required=True,
        help="""Hypothesis rule expression. Can be a simple rule in format 'attribute_name:rule_type',
        where rule_type is 'any' or 'all'. Or a boolean expression like
        'not rule1:any and rule2:all' or '(rule1:any or rule2:any) and not rule3:all'""",
    )
    parser.add_argument("--output-dir", default="results", help="Directory to save HTML result files")
    parser.add_argument("--aws-profile", help="AWS profile for S3 access")
    parser.add_argument("--recursive", action="store_true", help="Recursively process folder structure")
    parser.add_argument("--debug", action="store_true", help="Enable debug logging for more detailed output")
    parser.add_argument("--disable-plots", action="store_true", help="Disable CDF plots generation")
    parser.add_argument("--max-plots", type=int, default=200, help="Maximum number of CDF plots to generate (default: 200)")
    return parser.parse_args()
  118. def parse_s3_path(s3_path):
  119. """Parse S3 path into bucket and prefix."""
  120. parts = s3_path.replace("s3://", "").split("/", 1)
  121. bucket = parts[0]
  122. prefix = parts[1] if len(parts) > 1 else ""
  123. return bucket, prefix
  124. def get_attributes_folder(docs_folder, attr_folder=None):
  125. """
  126. Determine the attributes folder path based on the documents folder.
  127. Args:
  128. docs_folder: Path to the documents folder
  129. attr_folder: Manually specified attributes folder (optional)
  130. Returns:
  131. Path to the attributes folder
  132. """
  133. if attr_folder:
  134. return attr_folder
  135. # If no attributes folder specified, derive it from the documents folder
  136. if docs_folder.startswith("s3://"):
  137. # For S3 paths
  138. bucket, prefix = parse_s3_path(docs_folder)
  139. # Remove trailing slashes for consistent path handling
  140. prefix = prefix.rstrip("/")
  141. # Check if the documents folder is in a 'documents' directory
  142. if prefix.endswith("/documents"):
  143. # Replace /documents with /attributes
  144. attr_prefix = prefix[: -len("/documents")] + "/attributes"
  145. else:
  146. # Otherwise, add a parent level and include 'attributes'
  147. path_parts = prefix.split("/")
  148. # Remove the last part (assumed to be the documents directory name)
  149. path_parts.pop()
  150. # Add 'attributes'
  151. path_parts.append("attributes")
  152. attr_prefix = "/".join(path_parts)
  153. return f"s3://{bucket}/{attr_prefix}"
  154. else:
  155. # For local paths
  156. docs_path = Path(docs_folder)
  157. # Check if the documents folder is in a 'documents' directory
  158. if docs_path.name == "documents":
  159. # Replace /documents with /attributes
  160. attr_path = docs_path.parent / "attributes"
  161. else:
  162. # Otherwise, add a parent level and include 'attributes'
  163. attr_path = docs_path.parent / "attributes"
  164. return str(attr_path)
  165. def get_s3_bytes(s3_client, s3_path):
  166. """Get bytes from S3 object."""
  167. bucket, key = parse_s3_path(s3_path)
  168. response = s3_client.get_object(Bucket=bucket, Key=key)
  169. return response["Body"].read()
  170. def list_jsonl_files(path, s3_client=None, recursive=False):
  171. """List all JSONL files in the given path, locally or in S3."""
  172. jsonl_files = []
  173. if path.startswith("s3://"):
  174. bucket, prefix = parse_s3_path(path)
  175. prefix = prefix.rstrip("/") + "/"
  176. # List objects in S3 bucket with given prefix
  177. paginator = s3_client.get_paginator("list_objects_v2")
  178. for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
  179. if "Contents" in page:
  180. for obj in page["Contents"]:
  181. key = obj["Key"]
  182. if (
  183. key.endswith(".jsonl")
  184. or key.endswith(".json")
  185. or key.endswith(".jsonl.gz")
  186. or key.endswith(".jsonl.zst")
  187. or key.endswith(".jsonl.ztd")
  188. or key.endswith(".jsonl.zstd")
  189. ):
  190. jsonl_files.append(f"s3://{bucket}/{key}")
  191. else:
  192. # Local file system
  193. path_obj = Path(path)
  194. if recursive:
  195. for file_path in path_obj.rglob("*"):
  196. if (
  197. file_path.name.endswith(".jsonl")
  198. or file_path.name.endswith(".json")
  199. or file_path.name.endswith(".jsonl.gz")
  200. or file_path.name.endswith(".jsonl.zst")
  201. or file_path.name.endswith(".jsonl.ztd")
  202. or file_path.name.endswith(".jsonl.zstd")
  203. ):
  204. jsonl_files.append(str(file_path))
  205. else:
  206. for file_path in path_obj.glob("*"):
  207. if (
  208. file_path.name.endswith(".jsonl")
  209. or file_path.name.endswith(".json")
  210. or file_path.name.endswith(".jsonl.gz")
  211. or file_path.name.endswith(".jsonl.zst")
  212. or file_path.name.endswith(".jsonl.ztd")
  213. or file_path.name.endswith(".jsonl.zstd")
  214. ):
  215. jsonl_files.append(str(file_path))
  216. return jsonl_files
def load_jsonl_file(file_path, s3_client=None):
    """Load and decompress a JSONL file, either from local or S3.

    Compression is inferred from the file extension (.gz, .zst/.ztd/.zstd).
    For zstd, three strategies are tried in order: one-shot decompression,
    stream decompression with a 2 GiB window, then chunked stream reading.

    Args:
        file_path: Local path or ``s3://`` URI of the file to load.
        s3_client: boto3 S3 client; required when file_path is an S3 URI.

    Returns:
        List of parsed JSON objects, one per non-blank line. Returns []
        on any error (the error is logged, not raised).
    """
    try:
        # Get file content
        if file_path.startswith("s3://"):
            if s3_client is None:
                raise ValueError("S3 client is required for S3 paths")
            raw_data = get_s3_bytes(s3_client, file_path)
        else:
            with open(file_path, "rb") as f:
                raw_data = f.read()
        # Decompress if needed
        if file_path.endswith(".gz"):
            decompressed = gzip.decompress(raw_data)
        elif file_path.endswith((".zst", ".ztd", ".zstd")):
            try:
                # First try with standard one-shot decompression; fails for
                # frames written without a content-size header.
                dctx = zstd.ZstdDecompressor()
                decompressed = dctx.decompress(raw_data)
            except zstd.ZstdError as e:
                # If that fails, try with stream decompression
                logger.warning(f"Standard zstd decompression failed for {file_path}, trying stream decompression: {e}")
                try:
                    # Try with content-size not required
                    dctx = zstd.ZstdDecompressor(max_window_size=2147483648)  # Use a large window size
                    decompressor = dctx.stream_reader(io.BytesIO(raw_data))
                    decompressed = decompressor.read()
                except Exception as inner_e:
                    # If both methods fail, try with chunking
                    logger.warning(f"Stream decompression also failed, trying chunked reading: {inner_e}")
                    # Chunked reading approach: accumulate into a buffer so a
                    # partial failure mid-stream doesn't lose earlier chunks.
                    buffer = io.BytesIO()
                    dctx = zstd.ZstdDecompressor(max_window_size=2147483648)
                    with dctx.stream_reader(io.BytesIO(raw_data)) as reader:
                        while True:
                            chunk = reader.read(16384)  # Read in 16KB chunks
                            if not chunk:
                                break
                            buffer.write(chunk)
                    buffer.seek(0)
                    decompressed = buffer.read()
        else:
            decompressed = raw_data
        # Parse JSON lines; blank lines are skipped.
        lines = decompressed.decode("utf-8").strip().split("\n")
        return [json.loads(line) for line in lines if line.strip()]
    except Exception as e:
        # Deliberate best-effort contract: callers treat [] as "skip file".
        logger.error(f"Error loading file {file_path}: {e}")
        return []
def load_documents_and_attributes(docs_folder, attr_folder, s3_client=None, recursive=False):
    """
    Load documents and merge them with their attributes from all subdirectories.

    Two-phase merge: (1) every document file is loaded into an id -> doc map
    (documents without an "id" field are kept but can never receive
    attributes); (2) each attribute subdirectory is scanned, attribute files
    are matched to document files by basename, and rows are matched to
    documents by "id", with each match's "attributes" dict merged into the
    document in place. Later subdirectories overwrite identical attribute
    keys from earlier ones (dict.update semantics).

    Args:
        docs_folder: Path to the documents folder
        attr_folder: Path to the attributes folder
        s3_client: S3 client for S3 paths
        recursive: Whether to process folders recursively

    Returns:
        List of documents with their attributes merged in

    Raises:
        Exception: re-raises any error not handled by the per-file guards.
    """
    try:
        # List all document files
        logger.info(f"Finding document files in: {docs_folder}")
        doc_files = list_jsonl_files(docs_folder, s3_client, recursive)
        logger.info(f"Found {len(doc_files)} document files")
        if not doc_files:
            logger.warning(f"No document files found in {docs_folder}. Check the path and permissions.")
            return []
        # Get all attribute subdirectories if it's an S3 path
        attr_subdirs = []
        if attr_folder.startswith("s3://"):
            bucket, attr_prefix = parse_s3_path(attr_folder)
            attr_prefix = attr_prefix.rstrip("/") + "/"
            # List top-level directories in the attributes folder
            logger.info(f"Finding attribute subdirectories in: {attr_folder}")
            # Using delimiter parameter to list "directories" in S3
            paginator = s3_client.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=bucket, Prefix=attr_prefix, Delimiter="/"):
                if "CommonPrefixes" in page:
                    for prefix in page["CommonPrefixes"]:
                        subdir = f"s3://{bucket}/{prefix['Prefix']}"
                        attr_subdirs.append(subdir)
                        logger.info(f"Found attribute subdirectory: {subdir}")
            # If no subdirectories, use the main folder
            if not attr_subdirs:
                attr_subdirs = [attr_folder]
                logger.info(f"No subdirectories found, using main attribute folder: {attr_folder}")
        else:
            # For local paths
            attr_path = Path(attr_folder)
            if attr_path.exists() and attr_path.is_dir():
                # Get subdirectories
                subdirs = [str(d) for d in attr_path.iterdir() if d.is_dir()]
                if subdirs:
                    attr_subdirs = subdirs
                    logger.info(f"Found {len(attr_subdirs)} attribute subdirectories")
                else:
                    attr_subdirs = [attr_folder]
                    logger.info(f"No subdirectories found, using main attribute folder: {attr_folder}")
            else:
                # No attributes at all: documents are still returned below,
                # just without any merged attributes.
                logger.warning(f"Attributes folder not found or not a directory: {attr_folder}")
                attr_subdirs = []
        # Load and merge documents with all attributes from all subdirectories
        merged_docs = []  # documents without an "id" go straight here
        docs_by_id = {}  # id -> document, attribute merge target
        total_attr_files = 0
        # First, load all document files and create a document-by-ID mapping
        for doc_path in doc_files:
            try:
                if doc_path.startswith("s3://"):
                    _, doc_key = parse_s3_path(doc_path)
                    basename = os.path.basename(doc_key)
                else:
                    basename = os.path.basename(doc_path)
                # Load documents
                docs = load_jsonl_file(doc_path, s3_client)
                if not docs:
                    logger.warning(f"No documents loaded from {basename} (path: {doc_path})")
                    continue
                logger.info(f"Loaded {len(docs)} documents from {basename}")
                # Add to the merged documents list and create ID mapping
                for doc in docs:
                    if "id" in doc:
                        # If the document already exists, use the one with attributes if possible
                        doc_id = doc["id"]
                        if doc_id in docs_by_id:
                            if "attributes" not in doc and "attributes" in docs_by_id[doc_id]:
                                # Keep the existing document that has attributes
                                continue
                        # Initialize attributes if needed
                        if "attributes" not in doc:
                            doc["attributes"] = {}
                        # Add to the mapping (replaces any earlier duplicate)
                        docs_by_id[doc_id] = doc
                    else:
                        # No ID, can't match with attributes
                        if "attributes" not in doc:
                            doc["attributes"] = {}
                        merged_docs.append(doc)
            except Exception as e:
                # Per-file guard: one bad document file shouldn't abort the run.
                logger.error(f"Error processing document file {doc_path}: {e}")
                continue
        logger.info(f"Loaded {len(docs_by_id)} unique documents with IDs")
        # Now process each attribute subdirectory
        for subdir in attr_subdirs:
            try:
                logger.info(f"Processing attribute directory: {subdir}")
                attr_files = list_jsonl_files(subdir, s3_client, recursive)
                total_attr_files += len(attr_files)
                logger.info(f"Found {len(attr_files)} attribute files in {subdir}")
                # Create a mapping from document basename to attribute file path
                attr_file_map = {}
                for attr_path in attr_files:
                    if attr_path.startswith("s3://"):
                        _, attr_key = parse_s3_path(attr_path)
                        basename = os.path.basename(attr_key)
                    else:
                        basename = os.path.basename(attr_path)
                    attr_file_map[basename] = attr_path
                # Go through the document files again to find matching attributes
                for doc_path in doc_files:
                    try:
                        if doc_path.startswith("s3://"):
                            _, doc_key = parse_s3_path(doc_path)
                            basename = os.path.basename(doc_key)
                        else:
                            basename = os.path.basename(doc_path)
                        # Find matching attribute file (match is by basename only,
                        # so identically named files in different subtrees collide)
                        if basename in attr_file_map:
                            attr_path = attr_file_map[basename]
                            attrs = load_jsonl_file(attr_path, s3_client)
                            if not attrs:
                                logger.warning(f"No attributes loaded from {os.path.basename(attr_path)} (path: {attr_path})")
                                continue
                            logger.info(f"Loaded {len(attrs)} attributes from {os.path.basename(attr_path)}")
                            # Create a mapping from document ID to attributes
                            attr_by_id = {attr["id"]: attr for attr in attrs if "id" in attr}
                            # Count documents with matched attributes
                            docs_matched_in_file = 0
                            # Merge attributes into documents by ID
                            for doc_id, doc in docs_by_id.items():
                                if doc_id in attr_by_id:
                                    docs_matched_in_file += 1
                                    # If attributes document has attributes field, merge them
                                    if "attributes" in attr_by_id[doc_id]:
                                        doc["attributes"].update(attr_by_id[doc_id]["attributes"])
                            logger.info(f"Matched attributes for {docs_matched_in_file} documents from {basename} in {subdir}")
                    except Exception as e:
                        # NOTE(review): attr_path may be stale from a previous
                        # iteration if the failure happened before its assignment.
                        logger.error(f"Error processing attribute file {attr_path}: {e}")
                        continue
            except Exception as e:
                logger.error(f"Error processing attribute subdirectory {subdir}: {e}")
                continue
        # Convert the dictionary to a list for return
        merged_docs.extend(docs_by_id.values())
        logger.info(f"Total documents processed: {len(merged_docs)}")
        logger.info(f"Total attribute files processed: {total_attr_files}")
        logger.info(f"Total attribute subdirectories processed: {len(attr_subdirs)}")
        return merged_docs
    except Exception as e:
        logger.error(f"Error in load_documents_and_attributes: {e}")
        raise
  419. def apply_rule(doc, rule):
  420. """
  421. Apply a rule to determine if a document meets the PII criteria.
  422. Args:
  423. doc: The document JSON object
  424. rule: Either a tuple (attribute_name, rule_type) for simple rules,
  425. or an ExpressionNode for complex boolean expressions
  426. Returns:
  427. True if the document matches the rule, False otherwise
  428. """
  429. # Handle simple rule
  430. if not is_complex_expression(rule):
  431. return apply_simple_rule(doc, rule[0], rule[1])
  432. # Handle complex expression
  433. return evaluate_expression(doc, rule)
  434. def calculate_attribute_aggregate(doc, attribute_name, operation_type):
  435. """
  436. Calculate an aggregate value for a numeric attribute.
  437. Args:
  438. doc: The document JSON object
  439. attribute_name: The attribute field to aggregate (e.g., "pii_tagging_ratio")
  440. operation_type: The type of aggregation to perform (e.g., "avg")
  441. Returns:
  442. The aggregated value, or None if calculation is not possible
  443. """
  444. # Check if document has attributes
  445. if "attributes" not in doc or not doc["attributes"]:
  446. logger.debug(f"Document {doc.get('id', 'unknown')} has no attributes")
  447. return None
  448. attributes = doc["attributes"]
  449. # Check if the specific attribute exists
  450. if attribute_name not in attributes:
  451. logger.debug(f"Document {doc.get('id', 'unknown')} doesn't have attribute: {attribute_name}")
  452. return None
  453. if not attributes[attribute_name]:
  454. logger.debug(f"Document {doc.get('id', 'unknown')} has empty attribute: {attribute_name}")
  455. return None
  456. # Extract the numeric values from the attribute spans
  457. # Each span is formatted as [start_pos, end_pos, value]
  458. values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]
  459. if not values:
  460. logger.debug(f"Document {doc.get('id', 'unknown')} has no valid values for attribute: {attribute_name}")
  461. return None
  462. # Convert all values to float
  463. try:
  464. numeric_values = [float(value) for value in values]
  465. except (ValueError, TypeError):
  466. logger.debug(f"Document {doc.get('id', 'unknown')} has non-numeric values for attribute: {attribute_name}")
  467. return None
  468. # Perform the aggregation
  469. if operation_type == "avg":
  470. if not numeric_values:
  471. return None
  472. return sum(numeric_values) / len(numeric_values)
  473. # Add more aggregation types here as needed
  474. else:
  475. raise ValueError(f"Unknown operation type: {operation_type}")
  476. def apply_simple_rule(doc, attribute_name, rule_type):
  477. """
  478. Apply a simple rule to determine if a document meets the PII criteria.
  479. Args:
  480. doc: The document JSON object
  481. attribute_name: The attribute field to check (e.g., "gpt_4_1_contains_pii")
  482. rule_type: 'any' for any true value, 'all' for all true values,
  483. or a string containing an operation and comparison (e.g., 'avg>0.3')
  484. Returns:
  485. True if the document matches the rule, False otherwise
  486. """
  487. # Check if document has attributes
  488. if "attributes" not in doc or not doc["attributes"]:
  489. logger.debug(f"Document {doc.get('id', 'unknown')} has no attributes")
  490. return False
  491. attributes = doc["attributes"]
  492. # Check if the specific attribute exists
  493. if attribute_name not in attributes:
  494. logger.debug(f"Document {doc.get('id', 'unknown')} doesn't have attribute: {attribute_name}")
  495. return False
  496. if not attributes[attribute_name]:
  497. logger.debug(f"Document {doc.get('id', 'unknown')} has empty attribute: {attribute_name}")
  498. return False
  499. # Handle numeric comparison rules (e.g., 'avg>0.3')
  500. if any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
  501. # Parse the rule type into operation and comparison
  502. operation_parts = rule_type.split(">")
  503. if len(operation_parts) == 2:
  504. operation_type, threshold = operation_parts
  505. comparison_op = ">"
  506. else:
  507. operation_parts = rule_type.split("<")
  508. if len(operation_parts) == 2:
  509. operation_type, threshold = operation_parts
  510. comparison_op = "<"
  511. else:
  512. operation_parts = rule_type.split(">=")
  513. if len(operation_parts) == 2:
  514. operation_type, threshold = operation_parts
  515. comparison_op = ">="
  516. else:
  517. operation_parts = rule_type.split("<=")
  518. if len(operation_parts) == 2:
  519. operation_type, threshold = operation_parts
  520. comparison_op = "<="
  521. else:
  522. operation_parts = rule_type.split("==")
  523. if len(operation_parts) == 2:
  524. operation_type, threshold = operation_parts
  525. comparison_op = "=="
  526. else:
  527. raise ValueError(f"Invalid rule type: {rule_type}")
  528. # Convert threshold to float
  529. try:
  530. threshold = float(threshold)
  531. except ValueError:
  532. raise ValueError(f"Invalid threshold value: {threshold}")
  533. # Calculate the aggregate value
  534. aggregate_value = calculate_attribute_aggregate(doc, attribute_name, operation_type)
  535. if aggregate_value is None:
  536. logger.debug(f"Document {doc.get('id', 'unknown')} has no valid aggregate value for attribute: {attribute_name}")
  537. return False
  538. # Apply the comparison
  539. if comparison_op == ">":
  540. result = aggregate_value > threshold
  541. elif comparison_op == "<":
  542. result = aggregate_value < threshold
  543. elif comparison_op == ">=":
  544. result = aggregate_value >= threshold
  545. elif comparison_op == "<=":
  546. result = aggregate_value <= threshold
  547. elif comparison_op == "==":
  548. result = aggregate_value == threshold
  549. else:
  550. raise ValueError(f"Invalid comparison operator: {comparison_op}")
  551. if result:
  552. logger.debug(f"Document {doc.get('id', 'unknown')} matched numeric rule '{attribute_name}:{rule_type}' with value {aggregate_value}")
  553. return result
  554. # Handle boolean rules (any/all)
  555. if rule_type in ["any", "all"]:
  556. # Extract the boolean values from the attribute spans
  557. # Each span is formatted as [start_pos, end_pos, value]
  558. values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]
  559. if not values:
  560. logger.debug(f"Document {doc.get('id', 'unknown')} has no valid values for attribute: {attribute_name}")
  561. return False
  562. # Apply the rule
  563. if rule_type == "any":
  564. result = any(values)
  565. if result:
  566. logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (found True in {len(values)} values)")
  567. return result
  568. elif rule_type == "all":
  569. result = all(values)
  570. if result:
  571. logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (all {len(values)} values are True)")
  572. return result
  573. raise ValueError(f"Unknown rule type: {rule_type}")
  574. def evaluate_expression(doc, expr):
  575. """
  576. Evaluate a boolean expression on a document.
  577. Args:
  578. doc: The document JSON object
  579. expr: An ExpressionNode representing a boolean expression
  580. Returns:
  581. True if the document matches the expression, False otherwise
  582. """
  583. if isinstance(expr, RuleNode):
  584. # Base case: evaluate a leaf rule node
  585. return apply_simple_rule(doc, expr.attribute_name, expr.rule_type)
  586. elif isinstance(expr, NotNode):
  587. # NOT operator
  588. return not evaluate_expression(doc, expr.operand)
  589. elif isinstance(expr, BinaryNode):
  590. # Binary operators (AND/OR)
  591. if expr.operator == "AND":
  592. # Short-circuit AND evaluation
  593. return evaluate_expression(doc, expr.left) and evaluate_expression(doc, expr.right)
  594. elif expr.operator == "OR":
  595. # Short-circuit OR evaluation
  596. return evaluate_expression(doc, expr.left) or evaluate_expression(doc, expr.right)
  597. # Should not reach here if the expression tree is well-formed
  598. raise ValueError(f"Invalid expression node type: {type(expr)}")
  599. def tokenize_expression(expression):
  600. """
  601. Tokenize a rule expression string into a list of tokens.
  602. Args:
  603. expression: A string containing a boolean rule expression
  604. (e.g., "not rule1:any and rule2:all")
  605. Returns:
  606. A list of Token objects
  607. """
  608. tokens = []
  609. i = 0
  610. expression = expression.strip()
  611. while i < len(expression):
  612. char = expression[i]
  613. # Skip whitespace
  614. if char.isspace():
  615. i += 1
  616. continue
  617. # Handle parentheses
  618. elif char == "(":
  619. tokens.append(Token(TokenType.LPAREN))
  620. i += 1
  621. elif char == ")":
  622. tokens.append(Token(TokenType.RPAREN))
  623. i += 1
  624. # Handle operators
  625. elif i + 2 < len(expression) and expression[i : i + 3].lower() == "and":
  626. # Check if it's a standalone 'and' and not part of a word
  627. if (i == 0 or expression[i - 1].isspace() or expression[i - 1] in "()") and (
  628. i + 3 >= len(expression) or expression[i + 3].isspace() or expression[i + 3] in "()"
  629. ):
  630. tokens.append(Token(TokenType.AND))
  631. i += 3
  632. else:
  633. # It's part of an attribute name
  634. rule_start = i
  635. while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
  636. if i + 1 < len(expression) and expression[i] == ":":
  637. break
  638. i += 1
  639. # Process rule if we found a colon
  640. if i < len(expression) and expression[i] == ":":
  641. rule_end = i
  642. i += 1 # Skip the colon
  643. # Find the rule type
  644. type_start = i
  645. while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
  646. i += 1
  647. rule_name = expression[rule_start:rule_end]
  648. rule_type = expression[type_start:i]
  649. tokens.append(Token(TokenType.RULE, (rule_name, rule_type)))
  650. else:
  651. raise ValueError(f"Invalid rule format at position {rule_start}")
  652. elif i + 1 < len(expression) and expression[i : i + 2].lower() == "or":
  653. # Check if it's a standalone 'or' and not part of a word
  654. if (i == 0 or expression[i - 1].isspace() or expression[i - 1] in "()") and (
  655. i + 2 >= len(expression) or expression[i + 2].isspace() or expression[i + 2] in "()"
  656. ):
  657. tokens.append(Token(TokenType.OR))
  658. i += 2
  659. else:
  660. # Part of an attribute name
  661. rule_start = i
  662. while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
  663. if i + 1 < len(expression) and expression[i] == ":":
  664. break
  665. i += 1
  666. # Process rule if we found a colon
  667. if i < len(expression) and expression[i] == ":":
  668. rule_end = i
  669. i += 1 # Skip the colon
  670. # Find the rule type
  671. type_start = i
  672. while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
  673. i += 1
  674. rule_name = expression[rule_start:rule_end]
  675. rule_type = expression[type_start:i]
  676. tokens.append(Token(TokenType.RULE, (rule_name, rule_type)))
  677. else:
  678. raise ValueError(f"Invalid rule format at position {rule_start}")
  679. elif i + 2 < len(expression) and expression[i : i + 3].lower() == "not":
  680. # Check if it's a standalone 'not' and not part of a word
  681. if (i == 0 or expression[i - 1].isspace() or expression[i - 1] in "()") and (
  682. i + 3 >= len(expression) or expression[i + 3].isspace() or expression[i + 3] in "()"
  683. ):
  684. tokens.append(Token(TokenType.NOT))
  685. i += 3
  686. else:
  687. # Part of an attribute name
  688. rule_start = i
  689. while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
  690. if i + 1 < len(expression) and expression[i] == ":":
  691. break
  692. i += 1
  693. # Process rule if we found a colon
  694. if i < len(expression) and expression[i] == ":":
  695. rule_end = i
  696. i += 1 # Skip the colon
  697. # Find the rule type
  698. type_start = i
  699. while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
  700. i += 1
  701. rule_name = expression[rule_start:rule_end]
  702. rule_type = expression[type_start:i]
  703. tokens.append(Token(TokenType.RULE, (rule_name, rule_type)))
  704. else:
  705. raise ValueError(f"Invalid rule format at position {rule_start}")
  706. # Handle rule (attribute:type)
  707. else:
  708. rule_start = i
  709. while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
  710. if i + 1 < len(expression) and expression[i] == ":":
  711. break
  712. i += 1
  713. # Process rule if we found a colon
  714. if i < len(expression) and expression[i] == ":":
  715. rule_end = i
  716. i += 1 # Skip the colon
  717. # Find the rule type
  718. type_start = i
  719. while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
  720. i += 1
  721. rule_name = expression[rule_start:rule_end]
  722. rule_type = expression[type_start:i]
  723. tokens.append(Token(TokenType.RULE, (rule_name, rule_type)))
  724. else:
  725. raise ValueError(f"Invalid rule format at position {rule_start}")
  726. tokens.append(Token(TokenType.EOF))
  727. return tokens
  728. class Parser:
  729. """
  730. Parser for boolean rule expressions.
  731. Implements a recursive descent parser for expressions with the following grammar:
  732. expression → or_expr
  733. or_expr → and_expr ("or" and_expr)*
  734. and_expr → unary_expr ("and" unary_expr)*
  735. unary_expr → "not" unary_expr | primary
  736. primary → rule | "(" expression ")"
  737. rule → ATTRIBUTE ":" RULE_TYPE
  738. """
  739. def __init__(self, tokens):
  740. self.tokens = tokens
  741. self.current = 0
  742. def parse(self):
  743. """Parse the tokens into an expression tree."""
  744. return self.expression()
  745. def expression(self):
  746. """Parse an expression (top level)."""
  747. return self.or_expr()
  748. def or_expr(self):
  749. """Parse an OR expression."""
  750. expr = self.and_expr()
  751. while self.match(TokenType.OR):
  752. right = self.and_expr()
  753. expr = BinaryNode(expr, right, "OR")
  754. return expr
  755. def and_expr(self):
  756. """Parse an AND expression."""
  757. expr = self.unary_expr()
  758. while self.match(TokenType.AND):
  759. right = self.unary_expr()
  760. expr = BinaryNode(expr, right, "AND")
  761. return expr
  762. def unary_expr(self):
  763. """Parse a unary expression (NOT)."""
  764. if self.match(TokenType.NOT):
  765. operand = self.unary_expr()
  766. return NotNode(operand)
  767. return self.primary()
  768. def primary(self):
  769. """Parse a primary expression (rule or parenthesized expression)."""
  770. if self.match(TokenType.RULE):
  771. rule_tuple = self.previous().value
  772. attribute_name, rule_type = rule_tuple
  773. # Validate rule type
  774. if rule_type not in ["any", "all"] and not any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
  775. raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all', or numeric comparison (e.g., 'avg>0.3')")
  776. return RuleNode(attribute_name, rule_type)
  777. if self.match(TokenType.LPAREN):
  778. expr = self.expression()
  779. self.consume(TokenType.RPAREN, "Expected ')' after expression.")
  780. return expr
  781. raise ValueError(f"Expected rule or '(' at position {self.current}")
  782. def match(self, *types):
  783. """Check if the current token matches any of the given types."""
  784. for type in types:
  785. if self.check(type):
  786. self.advance()
  787. return True
  788. return False
  789. def check(self, type):
  790. """Check if the current token is of the given type without advancing."""
  791. if self.is_at_end():
  792. return False
  793. return self.peek().type == type
  794. def advance(self):
  795. """Advance to the next token and return the previous one."""
  796. if not self.is_at_end():
  797. self.current += 1
  798. return self.previous()
  799. def consume(self, type, message):
  800. """Consume the current token if it matches the expected type."""
  801. if self.check(type):
  802. return self.advance()
  803. raise ValueError(f"{message} at position {self.current}")
  804. def is_at_end(self):
  805. """Check if we've reached the end of the tokens."""
  806. return self.peek().type == TokenType.EOF
  807. def peek(self):
  808. """Return the current token without advancing."""
  809. return self.tokens[self.current]
  810. def previous(self):
  811. """Return the previous token."""
  812. return self.tokens[self.current - 1]
  813. def parse_rule(rule_string):
  814. """
  815. Parse a rule string into an expression tree or a simple attribute-rule_type tuple.
  816. Args:
  817. rule_string: A string containing a rule or boolean expression of rules
  818. Returns:
  819. Either a tuple (attribute_name, rule_type) for simple rules,
  820. or an ExpressionNode for complex boolean expressions
  821. """
  822. # Check if this is a simple rule
  823. if (
  824. "and" not in rule_string.lower()
  825. and "or" not in rule_string.lower()
  826. and "not" not in rule_string.lower()
  827. and "(" not in rule_string
  828. and ")" not in rule_string
  829. ):
  830. # Simple rule format: attribute_name:rule_type
  831. parts = rule_string.split(":", 1)
  832. if len(parts) != 2:
  833. raise ValueError(f"Invalid rule format: {rule_string}. Expected format: 'attribute_name:rule_type'")
  834. attribute_name, rule_type = parts
  835. # Check for numeric comparison rule_type
  836. if any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
  837. # This is a numeric comparison rule - we'll validate it in apply_simple_rule
  838. return attribute_name, rule_type
  839. elif rule_type not in ["any", "all"]:
  840. raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all', or numeric comparison (e.g., 'avg>0.3')")
  841. return attribute_name, rule_type
  842. else:
  843. # Complex rule expression
  844. try:
  845. tokens = tokenize_expression(rule_string)
  846. parser = Parser(tokens)
  847. return parser.parse()
  848. except Exception as e:
  849. raise ValueError(f"Error parsing expression '{rule_string}': {e}")
def is_complex_expression(rule):
    """Return True if *rule* is a parsed boolean expression tree (ExpressionNode)
    rather than a simple (attribute_name, rule_type) tuple from parse_rule."""
    return isinstance(rule, ExpressionNode)
  853. def calculate_iou(ref_ids, hyp_ids):
  854. """Calculate Intersection over Union of two sets of document IDs."""
  855. ref_set = set(ref_ids)
  856. hyp_set = set(hyp_ids)
  857. intersection = ref_set.intersection(hyp_set)
  858. union = ref_set.union(hyp_set)
  859. if not union:
  860. return 0.0
  861. return len(intersection) / len(union)
  862. def collect_rule_stats(expression, doc):
  863. """
  864. Collect statistics for all rules within a complex expression.
  865. Args:
  866. expression: A rule expression (either a tuple or ExpressionNode)
  867. doc: The document to analyze
  868. Returns:
  869. A dictionary with rule statistics
  870. """
  871. rule_stats = defaultdict(int)
  872. # Handle simple rule
  873. if not is_complex_expression(expression):
  874. attribute_name, rule_type = expression
  875. # Only process if document has this attribute
  876. if "attributes" in doc and doc["attributes"] and attribute_name in doc["attributes"] and doc["attributes"][attribute_name]:
  877. # The rule name will be the key for the statistics
  878. rule_name = f"{attribute_name}:{rule_type}"
  879. # Count entries in the attribute
  880. entries = doc["attributes"][attribute_name]
  881. rule_stats[f"{rule_name}_total_entries"] += len(entries)
  882. # Count positive values
  883. for span in entries:
  884. if len(span) >= 3 and span[2] is True:
  885. rule_stats[f"{rule_name}_positive_entries"] += 1
  886. # Check if document matches the rule
  887. if apply_simple_rule(doc, attribute_name, rule_type):
  888. rule_stats[f"{rule_name}_matched_docs"] += 1
  889. return rule_stats
  890. # For complex expressions, traverse the expression tree
  891. if isinstance(expression, RuleNode):
  892. # Base case: leaf node is a simple rule
  893. attribute_name, rule_type = expression.attribute_name, expression.rule_type
  894. if "attributes" in doc and doc["attributes"] and attribute_name in doc["attributes"] and doc["attributes"][attribute_name]:
  895. # The rule name will be the key for the statistics
  896. rule_name = f"{attribute_name}:{rule_type}"
  897. # Count entries in the attribute
  898. entries = doc["attributes"][attribute_name]
  899. rule_stats[f"{rule_name}_total_entries"] += len(entries)
  900. # Count positive values
  901. for span in entries:
  902. if len(span) >= 3 and span[2] is True:
  903. rule_stats[f"{rule_name}_positive_entries"] += 1
  904. # Check if document matches the rule
  905. if apply_simple_rule(doc, attribute_name, rule_type):
  906. rule_stats[f"{rule_name}_matched_docs"] += 1
  907. elif isinstance(expression, NotNode):
  908. # Get stats from the operand
  909. operand_stats = collect_rule_stats(expression.operand, doc)
  910. # Merge with current stats
  911. for key, value in operand_stats.items():
  912. rule_stats[key] += value
  913. elif isinstance(expression, BinaryNode):
  914. # Get stats from both sides
  915. left_stats = collect_rule_stats(expression.left, doc)
  916. right_stats = collect_rule_stats(expression.right, doc)
  917. # Merge with current stats
  918. for key, value in left_stats.items():
  919. rule_stats[key] += value
  920. for key, value in right_stats.items():
  921. rule_stats[key] += value
  922. return rule_stats
  923. def get_expression_summary(expression):
  924. """
  925. Generate a string representation of a rule expression.
  926. Args:
  927. expression: A rule expression (either a tuple or ExpressionNode)
  928. Returns:
  929. A string representation of the expression
  930. """
  931. if not is_complex_expression(expression):
  932. return f"{expression[0]}:{expression[1]}"
  933. if isinstance(expression, RuleNode):
  934. return f"{expression.attribute_name}:{expression.rule_type}"
  935. elif isinstance(expression, NotNode):
  936. return f"not {get_expression_summary(expression.operand)}"
  937. elif isinstance(expression, BinaryNode):
  938. left_summary = get_expression_summary(expression.left)
  939. right_summary = get_expression_summary(expression.right)
  940. return f"({left_summary} {expression.operator.lower()} {right_summary})"
  941. return str(expression)
  942. def compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule):
  943. """
  944. Compare two sets of documents using the specified rules and calculate IoU.
  945. Args:
  946. ref_docs: List of reference documents
  947. hyp_docs: List of hypothesis documents
  948. ref_rule: Rule expression for reference (tuple or ExpressionNode)
  949. hyp_rule: Rule expression for hypothesis (tuple or ExpressionNode)
  950. Returns:
  951. Dictionary with comparison results
  952. """
  953. # Extract document IDs and create ID-to-document maps
  954. ref_id_to_doc = {doc["id"]: doc for doc in ref_docs if "id" in doc}
  955. hyp_id_to_doc = {doc["id"]: doc for doc in hyp_docs if "id" in doc}
  956. # Get common document IDs
  957. common_ids = set(ref_id_to_doc.keys()).intersection(set(hyp_id_to_doc.keys()))
  958. # Apply rules to each document
  959. ref_matches = set()
  960. hyp_matches = set()
  961. # Track rule statistics
  962. ref_rule_stats = defaultdict(int)
  963. hyp_rule_stats = defaultdict(int)
  964. for doc_id in common_ids:
  965. ref_doc = ref_id_to_doc[doc_id]
  966. hyp_doc = hyp_id_to_doc[doc_id]
  967. # Collect statistics for all rules in the expressions
  968. doc_ref_rule_stats = collect_rule_stats(ref_rule, ref_doc)
  969. doc_hyp_rule_stats = collect_rule_stats(hyp_rule, hyp_doc)
  970. # Merge with overall stats
  971. for key, value in doc_ref_rule_stats.items():
  972. ref_rule_stats[key] += value
  973. for key, value in doc_hyp_rule_stats.items():
  974. hyp_rule_stats[key] += value
  975. # Check if document matches the rule expressions
  976. if apply_rule(ref_doc, ref_rule):
  977. ref_matches.add(doc_id)
  978. ref_rule_stats["expression_matched_docs"] += 1
  979. if apply_rule(hyp_doc, hyp_rule):
  980. hyp_matches.add(doc_id)
  981. hyp_rule_stats["expression_matched_docs"] += 1
  982. # Calculate IoU
  983. iou = calculate_iou(ref_matches, hyp_matches)
  984. # Collect detailed statistics
  985. tp = len(ref_matches.intersection(hyp_matches))
  986. fp = len(hyp_matches - ref_matches)
  987. fn = len(ref_matches - hyp_matches)
  988. precision = tp / (tp + fp) if (tp + fp) > 0 else 0
  989. recall = tp / (tp + fn) if (tp + fn) > 0 else 0
  990. f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
  991. # Generate string representations of the expressions
  992. ref_rule_str = get_expression_summary(ref_rule)
  993. hyp_rule_str = get_expression_summary(hyp_rule)
  994. return {
  995. "total_docs": len(common_ids),
  996. "ref_rule": ref_rule_str,
  997. "hyp_rule": hyp_rule_str,
  998. "ref_matches": len(ref_matches),
  999. "hyp_matches": len(hyp_matches),
  1000. "intersection": tp,
  1001. "union": tp + fp + fn,
  1002. "true_positives": tp,
  1003. "false_positives": fp,
  1004. "false_negatives": fn,
  1005. "precision": precision,
  1006. "recall": recall,
  1007. "f1": f1,
  1008. "iou": iou,
  1009. "ref_rule_stats": dict(ref_rule_stats),
  1010. "hyp_rule_stats": dict(hyp_rule_stats),
  1011. }
  1012. def format_rule_stats(rule_stats):
  1013. """Format rule statistics for display."""
  1014. # Group the statistics by rule name
  1015. grouped_stats = defaultdict(dict)
  1016. # Process regular rule stats (format: "{rule_name}_{stat_type}")
  1017. for key, value in rule_stats.items():
  1018. if key == "expression_matched_docs":
  1019. # Special case for the overall expression match count
  1020. continue
  1021. # Extract rule name and stat type
  1022. if "_total_entries" in key:
  1023. rule_name = key.replace("_total_entries", "")
  1024. grouped_stats[rule_name]["total_entries"] = value
  1025. elif "_positive_entries" in key:
  1026. rule_name = key.replace("_positive_entries", "")
  1027. grouped_stats[rule_name]["positive_entries"] = value
  1028. elif "_matched_docs" in key:
  1029. rule_name = key.replace("_matched_docs", "")
  1030. grouped_stats[rule_name]["matched_docs"] = value
  1031. # Format the grouped statistics as a list of strings
  1032. formatted_stats = []
  1033. for rule_name, stats in grouped_stats.items():
  1034. formatted_stats.append(
  1035. f" {rule_name}:\n"
  1036. f" - Total Entries: {stats.get('total_entries', 0)}\n"
  1037. f" - Positive Entries: {stats.get('positive_entries', 0)}\n"
  1038. f" - Matched Documents: {stats.get('matched_docs', 0)}"
  1039. )
  1040. # Add the expression matched count if available
  1041. if "expression_matched_docs" in rule_stats:
  1042. formatted_stats.append(f" Overall Expression Matched Documents: {rule_stats['expression_matched_docs']}")
  1043. return "\n".join(formatted_stats)
  1044. def collect_numeric_attributes(documents):
  1045. """
  1046. Collect all numeric attribute values from documents.
  1047. Args:
  1048. documents: List of documents with attributes
  1049. Returns:
  1050. Dictionary mapping attribute names to lists of numeric values
  1051. """
  1052. numeric_attributes = defaultdict(list)
  1053. for doc in documents:
  1054. if "attributes" not in doc or not doc["attributes"]:
  1055. continue
  1056. for attr_name, attr_values in doc["attributes"].items():
  1057. if not attr_values:
  1058. continue
  1059. # Try to extract numeric values from the attribute spans
  1060. # Each span is formatted as [start_pos, end_pos, value]
  1061. for span in attr_values:
  1062. if len(span) >= 3 and span[2] is not None:
  1063. try:
  1064. # Convert to float if it's a numeric value
  1065. value = float(span[2])
  1066. numeric_attributes[attr_name].append(value)
  1067. except (ValueError, TypeError):
  1068. # Not a numeric value, skip
  1069. pass
  1070. # Filter out attributes with no or too few numeric values
  1071. return {k: v for k, v in numeric_attributes.items() if len(v) > 5}
  1072. def generate_cdf_plot(values, attribute_name):
  1073. """
  1074. Generate a CDF plot for the given numeric values.
  1075. Args:
  1076. values: List of numeric values
  1077. attribute_name: Name of the attribute (for plot title)
  1078. Returns:
  1079. Base64-encoded PNG image of the plot or None if there's an error
  1080. """
  1081. try:
  1082. # Ensure we have enough data points
  1083. if len(values) < 5:
  1084. logger.warning(f"Not enough data points to generate CDF for {attribute_name}")
  1085. return None
  1086. # Remove any NaN or infinite values
  1087. values = np.array([v for v in values if np.isfinite(v)])
  1088. if len(values) < 5:
  1089. logger.warning(f"Not enough finite data points to generate CDF for {attribute_name}")
  1090. return None
  1091. # Handle extreme values by removing outliers (optional)
  1092. # if len(values) > 30: # Only apply if we have enough data points
  1093. # q1, q3 = np.percentile(values, [25, 75])
  1094. # iqr = q3 - q1
  1095. # lower_bound = q1 - 3 * iqr
  1096. # upper_bound = q3 + 3 * iqr
  1097. # values = values[(values >= lower_bound) & (values <= upper_bound)]
  1098. # Sort values for CDF calculation
  1099. values = np.sort(values)
  1100. # Create a Figure object (no interactive display)
  1101. fig = Figure(figsize=(10, 6))
  1102. ax = fig.add_subplot(1, 1, 1)
  1103. # Calculate CDF (y-values are 0 to 1 for cumulative probability)
  1104. y = np.arange(1, len(values) + 1) / len(values)
  1105. # Plot the CDF
  1106. ax.plot(values, y, "b-", linewidth=2)
  1107. ax.grid(True, linestyle="--", alpha=0.7)
  1108. # Add labels and title
  1109. ax.set_xlabel("Value", fontsize=12)
  1110. ax.set_ylabel("Cumulative Probability", fontsize=12)
  1111. ax.set_title(f"CDF of {attribute_name}", fontsize=14)
  1112. # Ensure the y-axis goes from 0 to 1 for probability
  1113. ax.set_ylim(0, 1.05)
  1114. # Add some statistics to the plot
  1115. if len(values) > 0:
  1116. mean_val = np.mean(values)
  1117. median_val = np.median(values)
  1118. min_val = np.min(values)
  1119. max_val = np.max(values)
  1120. stats_text = f"n={len(values)}\nmin={min_val:.2f}\nmax={max_val:.2f}\nmean={mean_val:.2f}\nmedian={median_val:.2f}"
  1121. ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, verticalalignment="top", bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
  1122. # Make layout tight
  1123. fig.tight_layout()
  1124. # Convert to base64 for embedding in HTML
  1125. buf = BytesIO()
  1126. fig.savefig(buf, format="png", dpi=100)
  1127. buf.seek(0)
  1128. img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
  1129. return img_base64
  1130. except Exception as e:
  1131. logger.error(f"Error generating CDF plot for {attribute_name}: {e}")
  1132. return None
def generate_attribute_plots_html(numeric_attributes, max_plots=20):
    """
    Generate HTML section with CDF plots for all numeric attributes.

    Args:
        numeric_attributes: Dictionary mapping attribute names to lists of numeric values
        max_plots: Maximum number of plots to generate

    Returns:
        HTML string with embedded CDF plots, or "" when there is nothing to plot
    """
    if not numeric_attributes:
        return ""
    html = """
<h2>Numeric Attribute Distributions</h2>
<div class="attribute-plots">
"""
    plot_count = 0
    # Sort attributes by number of values (most values first) so the most
    # data-rich attributes are plotted before the max_plots cap kicks in.
    sorted_attrs = sorted(numeric_attributes.items(), key=lambda x: len(x[1]), reverse=True)
    for attr_name, values in sorted_attrs:
        # Skip attributes with too few values for a meaningful plot.
        if len(values) < 10:
            continue
        # Cap the number of plots to keep report size/render time bounded.
        if plot_count >= max_plots:
            logger.info(f"Limiting CDF plots to {max_plots} attributes to avoid performance issues")
            break
        # Generate the CDF plot (returns base64 PNG or None on failure).
        img_base64 = generate_cdf_plot(values, attr_name)
        # Only add to HTML (and count) if plot generation was successful.
        if img_base64:
            html += f"""
<div class="plot-container">
<h3>{attr_name}</h3>
<img src="data:image/png;base64,{img_base64}" alt="CDF plot for {attr_name}" class="cdf-plot">
<p>Number of values: {len(values)}</p>
</div>
"""
            plot_count += 1
    # Don't emit the section at all if no plots were generated.
    if plot_count == 0:
        return ""
    html += """
</div>
"""
    return html
  1175. def generate_html_report(docs, title, summary, output_path):
  1176. """
  1177. Generate an HTML report file with document texts
  1178. Args:
  1179. docs: List of documents to include in the report
  1180. title: Title of the report
  1181. summary: Summary statistics to include at the top
  1182. output_path: Path to save the HTML file
  1183. Returns:
  1184. None
  1185. """
  1186. # Create header with CSS styling
  1187. html = f"""<!DOCTYPE html>
  1188. <html>
  1189. <head>
  1190. <meta charset="UTF-8">
  1191. <title>{title}</title>
  1192. <style>
  1193. body {{
  1194. font-family: Arial, sans-serif;
  1195. line-height: 1.6;
  1196. margin: 0;
  1197. padding: 0;
  1198. scroll-behavior: smooth;
  1199. }}
  1200. /* Header bar styles */
  1201. .header {{
  1202. background-color: #f8f9fa;
  1203. padding: 8px 20px;
  1204. box-shadow: 0 2px 5px rgba(0,0,0,0.1);
  1205. position: fixed;
  1206. top: 0;
  1207. left: 0;
  1208. right: 0;
  1209. z-index: 100;
  1210. display: flex;
  1211. justify-content: space-between;
  1212. align-items: center;
  1213. height: 40px;
  1214. }}
  1215. .title {{
  1216. font-size: 1.2em;
  1217. font-weight: bold;s
  1218. color: #333;
  1219. white-space: nowrap;
  1220. overflow: hidden;
  1221. text-overflow: ellipsis;
  1222. max-width: 60%;
  1223. }}
  1224. .controls {{
  1225. display: flex;
  1226. align-items: center;
  1227. }}
  1228. .keyboard-controls {{
  1229. font-size: 0.85em;
  1230. margin-right: 15px;
  1231. }}
  1232. .toggle-summary {{
  1233. background-color: #e9ecef;
  1234. border: 1px solid #ced4da;
  1235. padding: 4px 10px;
  1236. border-radius: 4px;
  1237. cursor: pointer;
  1238. font-size: 0.85em;
  1239. }}
  1240. /* Summary panel styles */
  1241. #summary-panel {{
  1242. position: fixed;
  1243. top: 57px;
  1244. left: 0;
  1245. right: 0;
  1246. background-color: #f8f9fa;
  1247. border-bottom: 1px solid #ddd;
  1248. padding: 15px 20px;
  1249. z-index: 90;
  1250. display: none;
  1251. max-height: 300px;
  1252. overflow-y: auto;
  1253. box-shadow: 0 2px 5px rgba(0,0,0,0.1);
  1254. }}
  1255. /* Main content styles */
  1256. .container {{
  1257. max-width: 1200px;
  1258. margin: 0 auto;
  1259. padding: 60px 20px 20px 20px;
  1260. }}
  1261. /* Document styles */
  1262. .document {{
  1263. background-color: #fff;
  1264. padding: 15px;
  1265. margin-bottom: 15px;
  1266. border: 1px solid #ddd;
  1267. border-radius: 5px;
  1268. box-shadow: 0 1px 3px rgba(0,0,0,0.1);
  1269. transition: all 0.2s ease-in-out;
  1270. scroll-margin-top: 60px;
  1271. }}
  1272. .document:hover {{
  1273. box-shadow: 0 2px 5px rgba(0,0,0,0.2);
  1274. }}
  1275. .document.selected {{
  1276. border: 2px solid #007bff;
  1277. box-shadow: 0 0 8px rgba(0, 123, 255, 0.5);
  1278. background-color: #f8f9fa;
  1279. }}
  1280. .document-id {{
  1281. color: #007bff;
  1282. font-weight: bold;
  1283. margin-bottom: 5px;
  1284. font-size: 0.9em;
  1285. }}
  1286. .document-text {{
  1287. white-space: pre-wrap;
  1288. overflow-wrap: break-word;
  1289. }}
  1290. /* Helper styles */
  1291. h2 {{
  1292. margin-top: 0;
  1293. font-size: 1.2em;
  1294. color: #333;
  1295. }}
  1296. pre {{
  1297. font-size: 0.9em;
  1298. white-space: pre-wrap;
  1299. }}
  1300. .stats {{
  1301. color: #666;
  1302. font-size: 0.8em;
  1303. font-weight: normal;
  1304. }}
  1305. .keyboard-shortcut {{
  1306. display: inline-block;
  1307. padding: 1px 4px;
  1308. margin: 0 1px;
  1309. border-radius: 3px;
  1310. background-color: #f1f3f5;
  1311. border: 1px solid #ced4da;
  1312. font-family: monospace;
  1313. font-size: 0.9em;
  1314. }}
  1315. </style>
  1316. </head>
  1317. <body>
  1318. <!-- Fixed header -->
  1319. <div class="header">
  1320. <div class="title">{title} <span class="stats">({len(docs)} documents)</span></div>
  1321. <div class="controls">
  1322. <div class="keyboard-controls">
  1323. <span class="keyboard-shortcut">↑</span>/<span class="keyboard-shortcut">↓</span> to navigate
  1324. &nbsp;<span class="keyboard-shortcut">Home</span>/<span class="keyboard-shortcut">End</span>
  1325. </div>
  1326. <button class="toggle-summary" onclick="toggleSummary()">Show Summary</button>
  1327. </div>
  1328. </div>
  1329. <!-- Summary panel (initially hidden) -->
  1330. <div id="summary-panel">
  1331. <h2>Summary</h2>
  1332. <pre>{summary}</pre>
  1333. </div>
  1334. <!-- Main content -->
  1335. <div class="container">
  1336. <div id="document-container">
  1337. """
  1338. # Add each document with a unique ID
  1339. for i, doc in enumerate(docs, 1):
  1340. doc_id = doc.get("id", f"unknown_{i}")
  1341. # Get document text, falling back to JSON representation if not available
  1342. doc_text = doc.get("text", json.dumps(doc, indent=2))
  1343. # The first document gets the "selected" class
  1344. selected_class = " selected" if i == 1 else ""
  1345. html += f"""
  1346. <div id="doc-{i}" class="document{selected_class}" tabindex="0">
  1347. <div class="document-id">Document ID: {doc_id}</div>
  1348. <pre class="document-text">{pyhtml.escape(doc_text)}</pre>
  1349. </div>
  1350. """
  1351. # Add JavaScript for keyboard navigation and summary toggle
  1352. html += """
  1353. </div>
  1354. </div>
  1355. <script>
  1356. // Get all documents
  1357. const documents = document.querySelectorAll('.document');
  1358. let selectedIndex = 0; // First document is selected by default
  1359. let summaryVisible = false;
  1360. // Function to toggle summary panel
  1361. function toggleSummary() {
  1362. const panel = document.getElementById('summary-panel');
  1363. const button = document.querySelector('.toggle-summary');
  1364. if (summaryVisible) {
  1365. panel.style.display = 'none';
  1366. button.textContent = 'Show Summary';
  1367. } else {
  1368. panel.style.display = 'block';
  1369. button.textContent = 'Hide Summary';
  1370. }
  1371. summaryVisible = !summaryVisible;
  1372. }
  1373. // Function to select a document
  1374. function selectDocument(index) {
  1375. // Validate index
  1376. if (index < 0) index = 0;
  1377. if (index >= documents.length) index = documents.length - 1;
  1378. // Store current index for use in setTimeout
  1379. const targetIndex = index;
  1380. // Remove selected class from all documents
  1381. documents.forEach(doc => doc.classList.remove('selected'));
  1382. // Add selected class to the current document
  1383. documents[targetIndex].classList.add('selected');
  1384. // Update selected index
  1385. selectedIndex = targetIndex;
  1386. // Use a more direct approach for scrolling
  1387. // Get the element's offset from the top of the document
  1388. const headerHeight = 60; // Fixed header height
  1389. const element = documents[targetIndex];
  1390. const elementPosition = element.offsetTop;
  1391. // Scroll the element to the top of the viewport, accounting for header
  1392. window.scrollTo({
  1393. top: elementPosition - headerHeight,
  1394. behavior: 'smooth'
  1395. });
  1396. // Focus the selected document for accessibility
  1397. documents[targetIndex].focus();
  1398. }
  1399. // Add keyboard event listener to the document
  1400. document.addEventListener('keydown', function(event) {
  1401. // Arrow up
  1402. if (event.key === 'ArrowUp') {
  1403. event.preventDefault();
  1404. selectDocument(selectedIndex - 1);
  1405. }
  1406. // Arrow down
  1407. else if (event.key === 'ArrowDown') {
  1408. event.preventDefault();
  1409. selectDocument(selectedIndex + 1);
  1410. }
  1411. // Home key - go to first document
  1412. else if (event.key === 'Home') {
  1413. event.preventDefault();
  1414. selectDocument(0);
  1415. }
  1416. // End key - go to last document
  1417. else if (event.key === 'End') {
  1418. event.preventDefault();
  1419. selectDocument(documents.length - 1);
  1420. }
  1421. // Escape key - hide summary if visible
  1422. else if (event.key === 'Escape' && summaryVisible) {
  1423. toggleSummary();
  1424. }
  1425. // S key - toggle summary
  1426. else if (event.key === 's' || event.key === 'S') {
  1427. toggleSummary();
  1428. }
  1429. });
  1430. // Make documents clickable to select them
  1431. documents.forEach((doc, index) => {
  1432. doc.addEventListener('click', () => {
  1433. selectDocument(index);
  1434. });
  1435. });
  1436. // Select the first document when the page loads
  1437. window.addEventListener('load', () => {
  1438. // If there are documents, select the first one
  1439. if (documents.length > 0) {
  1440. selectDocument(0);
  1441. }
  1442. });
  1443. </script>
  1444. </body>
  1445. </html>
  1446. """
  1447. # Create directory if it doesn't exist
  1448. os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
  1449. # Write HTML to file
  1450. with open(output_path, "w", encoding="utf-8") as f:
  1451. f.write(html)
  1452. logger.info(f"Generated HTML report: {output_path}")
def main() -> None:
    """CLI entry point: evaluate a reference PII rule vs. a hypothesis rule.

    Loads documents (local or S3) with their attributes, applies both parsed
    rule expressions to every document, computes confusion-matrix metrics
    (precision / recall / F1 / IoU), and writes per-category HTML reports plus
    an index.html into ``args.output_dir``.
    """
    # Parsed CLI options are stored in a module-level global so that other
    # helpers in this file can read them without explicit plumbing.
    global args
    args = parse_args()
    # Set up logging based on arguments
    if args.debug:
        logger.setLevel(logging.DEBUG)
        logger.debug("Debug logging enabled")
    # Set up S3 client if needed — only when either folder is an s3:// URI.
    s3_client = None
    if args.docs_folder.startswith("s3://") or (args.attr_folder and args.attr_folder.startswith("s3://")):
        # An explicit AWS profile is optional; fall back to the default session.
        session = boto3.Session(profile_name=args.aws_profile) if args.aws_profile else boto3.Session()
        s3_client = session.client("s3")
    # Parse the two rule expressions into evaluable rule objects.
    logger.info(f"Parsing reference rule expression: {args.ref_rule}")
    ref_rule = parse_rule(args.ref_rule)
    logger.info(f"Parsing hypothesis rule expression: {args.hyp_rule}")
    hyp_rule = parse_rule(args.hyp_rule)
    # Generate string representations of the expressions (for logs/reports).
    ref_rule_str = get_expression_summary(ref_rule)
    hyp_rule_str = get_expression_summary(hyp_rule)
    logger.info(f"Reference rule parsed as: {ref_rule_str}")
    logger.info(f"Hypothesis rule parsed as: {hyp_rule_str}")
    # Determine attributes folder (derived from docs folder unless overridden).
    attr_folder = get_attributes_folder(args.docs_folder, args.attr_folder)
    logger.info(f"Using attributes folder: {attr_folder}")
    # Load documents and merge with attributes from all subdirectories
    logger.info("Loading documents and merging with all attributes...")
    all_docs = load_documents_and_attributes(args.docs_folder, attr_folder, s3_client, args.recursive)
    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)
    # Use the same documents for both reference and hypothesis evaluation
    # since we've loaded all attributes into each document
    ref_docs = all_docs
    hyp_docs = all_docs
    # Compare the documents (also produces per-rule match statistics).
    logger.info("Comparing documents using reference and hypothesis rules...")
    comparison_result = compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule)
    # Get document IDs for each category
    ref_matches = set()
    hyp_matches = set()
    # Create mappings from document IDs to documents.
    # NOTE(review): documents lacking an "id" key are silently dropped here
    # and excluded from all metrics — confirm that is intended.
    doc_map = {doc["id"]: doc for doc in all_docs if "id" in doc}
    # Find documents that match the reference and hypothesis rules
    for doc_id, doc in doc_map.items():
        if apply_rule(doc, ref_rule):
            ref_matches.add(doc_id)
        if apply_rule(doc, hyp_rule):
            hyp_matches.add(doc_id)
    # Calculate document ID sets for each confusion-matrix category,
    # treating the reference rule as ground truth.
    true_positives_ids = ref_matches.intersection(hyp_matches)
    true_negatives_ids = set(doc_map.keys()) - ref_matches - hyp_matches
    false_positives_ids = hyp_matches - ref_matches
    false_negatives_ids = ref_matches - hyp_matches
    # Create document lists for each category
    true_positives = [doc_map[doc_id] for doc_id in true_positives_ids]
    true_negatives = [doc_map[doc_id] for doc_id in true_negatives_ids]
    false_positives = [doc_map[doc_id] for doc_id in false_positives_ids]
    false_negatives = [doc_map[doc_id] for doc_id in false_negatives_ids]
    # Calculate metrics (each guarded against a zero denominator).
    tp = len(true_positives)
    tn = len(true_negatives)
    fp = len(false_positives)
    fn = len(false_negatives)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    # IoU (Jaccard index) of the two match sets.
    iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
    # Prepare overall statistics
    overall_stats = {
        "total_docs": len(doc_map),
        "ref_matches": len(ref_matches),
        "hyp_matches": len(hyp_matches),
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "iou": iou,
        "ref_rule_stats": comparison_result["ref_rule_stats"],
        "hyp_rule_stats": comparison_result["hyp_rule_stats"],
    }
    # Prepare the plain-text summary embedded in every generated HTML report.
    summary = f"""Reference Rule: {args.ref_rule}
Hypothesis Rule: {args.hyp_rule}
Total Documents: {overall_stats['total_docs']}
Reference Matches: {overall_stats['ref_matches']}
Hypothesis Matches: {overall_stats['hyp_matches']}
True Positives: {tp}
True Negatives: {tn}
False Positives: {fp}
False Negatives: {fn}
Precision: {precision:.4f}
Recall: {recall:.4f}
F1 Score: {f1:.4f}
IoU: {iou:.4f}
"""
    # Generate HTML reports for each category.
    # Each report is capped at the first 1000 documents to bound file size.
    logger.info("Generating HTML reports...")
    # True Positives
    generate_html_report(
        true_positives[:1000],
        "True Positives - Documents matching both Reference and Hypothesis Rules",
        summary,
        os.path.join(args.output_dir, "true_positives.html"),
    )
    # True Negatives
    generate_html_report(
        true_negatives[:1000], "True Negatives - Documents not matching either Rule", summary, os.path.join(args.output_dir, "true_negatives.html")
    )
    # False Positives
    generate_html_report(
        false_positives[:1000],
        "False Positives - Documents matching Hypothesis but not Reference Rule",
        summary,
        os.path.join(args.output_dir, "false_positives.html"),
    )
    # False Negatives
    generate_html_report(
        false_negatives[:1000],
        "False Negatives - Documents matching Reference but not Hypothesis Rule",
        summary,
        os.path.join(args.output_dir, "false_negatives.html"),
    )
    # Collect numeric attributes and generate CDF plots if not disabled
    attribute_plots_html = ""
    if not args.disable_plots:
        logger.info("Collecting numeric attributes for CDF plots...")
        numeric_attributes = collect_numeric_attributes(all_docs)
        if numeric_attributes:
            logger.info(f"Found {len(numeric_attributes)} numeric attributes suitable for CDF plots")
            # Generate CDF plots HTML with the specified maximum number of plots
            attribute_plots_html = generate_attribute_plots_html(numeric_attributes, args.max_plots)
        else:
            logger.info("No numeric attributes found for CDF plots")
    else:
        logger.info("CDF plot generation disabled by --disable-plots flag")
    # Generate index.html file that links to all reports.
    # Doubled braces ({{ }}) below are f-string escapes for literal CSS braces.
    index_html = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>PII Rule Comparison Results</title>
<style>
body {{
font-family: Arial, sans-serif;
line-height: 1.6;
margin: 0;
padding: 20px;
max-width: 1000px;
margin: 0 auto;
}}
.summary {{
background-color: #f8f9fa;
padding: 15px;
border-radius: 5px;
margin-bottom: 20px;
border-left: 5px solid #007bff;
}}
.category {{
margin-bottom: 20px;
padding: 15px;
border-radius: 5px;
}}
.true-positives {{
background-color: #d4edda;
border-left: 5px solid #28a745;
}}
.true-negatives {{
background-color: #e2e3e5;
border-left: 5px solid #6c757d;
}}
.false-positives {{
background-color: #f8d7da;
border-left: 5px solid #dc3545;
}}
.false-negatives {{
background-color: #fff3cd;
border-left: 5px solid #ffc107;
}}
h1 {{
border-bottom: 2px solid #007bff;
padding-bottom: 10px;
color: #333;
}}
a {{
color: #007bff;
text-decoration: none;
font-weight: bold;
}}
a:hover {{
text-decoration: underline;
}}
.attribute-plots {{
margin-top: 30px;
}}
.plot-container {{
margin-bottom: 30px;
padding: 15px;
background-color: #fff;
border-radius: 5px;
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}}
.cdf-plot {{
max-width: 100%;
height: auto;
}}
h2 {{
color: #333;
border-bottom: 1px solid #eee;
padding-bottom: 10px;
margin-top: 30px;
}}
h3 {{
color: #007bff;
}}
</style>
</head>
<body>
<h1>PII Rule Comparison Results</h1>
<div class="summary">
<h2>Summary</h2>
<pre>{summary}</pre>
</div>
<h2>Result Categories</h2>
<div class="category true-positives">
<h3>True Positives: {tp}</h3>
<p>Documents that match both the reference and hypothesis rules.</p>
<a href="true_positives.html">View True Positives</a>
</div>
<div class="category true-negatives">
<h3>True Negatives: {tn}</h3>
<p>Documents that don't match either the reference or hypothesis rules.</p>
<a href="true_negatives.html">View True Negatives</a>
</div>
<div class="category false-positives">
<h3>False Positives: {fp}</h3>
<p>Documents that match the hypothesis rule but not the reference rule.</p>
<a href="false_positives.html">View False Positives</a>
</div>
<div class="category false-negatives">
<h3>False Negatives: {fn}</h3>
<p>Documents that match the reference rule but not the hypothesis rule.</p>
<a href="false_negatives.html">View False Negatives</a>
</div>
{attribute_plots_html}
</body>
</html>
"""
    with open(os.path.join(args.output_dir, "index.html"), "w", encoding="utf-8") as f:
        f.write(index_html)
    # Print summary
    logger.info("\n--- COMPARISON SUMMARY ---")
    logger.info(f"Documents Folder: {args.docs_folder}")
    logger.info(f"Attributes Folder: {attr_folder}")
    logger.info(f"Reference Rule Expression: {args.ref_rule}")
    logger.info(f" Parsed as: {ref_rule_str}")
    logger.info(f"Hypothesis Rule Expression: {args.hyp_rule}")
    logger.info(f" Parsed as: {hyp_rule_str}")
    logger.info(f"Total Documents: {overall_stats['total_docs']}")
    # Print rule statistics
    logger.info("\n--- RULE MATCH STATISTICS ---")
    logger.info("\nReference Rules:")
    logger.info(format_rule_stats(overall_stats["ref_rule_stats"]))
    logger.info("\nHypothesis Rules:")
    logger.info(format_rule_stats(overall_stats["hyp_rule_stats"]))
    # Print comparison metrics
    logger.info("\n--- COMPARISON METRICS ---")
    logger.info(f"True Positives: {tp}")
    logger.info(f"True Negatives: {tn}")
    logger.info(f"False Positives: {fp}")
    logger.info(f"False Negatives: {fn}")
    logger.info(f"Precision: {precision:.4f}")
    logger.info(f"Recall: {recall:.4f}")
    logger.info(f"F1 Score: {f1:.4f}")
    logger.info(f"IoU: {iou:.4f}")
    # Output all available attributes that have been loaded
    logger.info("\n--- AVAILABLE ATTRIBUTES ---")
    all_attributes = set()
    for doc in all_docs:
        if "attributes" in doc and doc["attributes"]:
            all_attributes.update(doc["attributes"].keys())
    if all_attributes:
        logger.info(f"Found {len(all_attributes)} unique attributes:")
        for attr in sorted(all_attributes):
            logger.info(f" - {attr}")
    else:
        logger.info("No attributes found in any documents.")
    logger.info(f"\nResults saved to: {args.output_dir}/index.html")
# Run the comparison only when executed as a script (not on import).
if __name__ == "__main__":
    main()
# Example commands with actual S3 paths:
# (kept as a bare string literal so it is never executed; it serves as
# copy-pasteable usage documentation at the bottom of the file)
"""
# Example for AI2 OE data with resume detection:
python scripts/pii_rule_comparison.py \
--docs-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/documents/ \
--ref-rule "gpt_4_1_contains_pii:any and not gpt_4_1_is_public_document:all" \
--hyp-rule "google_gemma-3-4b-it_is_resume_cv:any" \
--output-dir results/resume_detection \
--recursive \
--debug
# Example for PII detection comparison:
python scripts/pii_rule_comparison.py \
--docs-folder s3://allenai-dolma/documents/v1.5 \
--ref-rule "contains_pii:any" \
--hyp-rule "(contains_email_addresses:any or contains_phone_numbers:any) and not false_positive:any" \
--output-dir results/pii_detection \
--recursive \
--aws-profile dolma
# Example with custom attributes folder:
python scripts/pii_rule_comparison.py \
--docs-folder s3://bucket/path/documents \
--attr-folder s3://bucket/custom/location/attributes \
--ref-rule "gpt_4_1_contains_pii:any" \
--hyp-rule "custom_model_pii_detection:any" \
--output-dir results/custom_comparison \
--recursive
"""