- #!/usr/bin/env python3
- """
- Compare PII Detection Rules and Calculate IoU
- This script processes documents and their attributes from S3 or local storage,
- applies different rules for PII detection, and calculates the
- Intersection over Union (IoU) of the matched document sets to measure how well the rules agree.
- How it works:
- 1. Documents are stored in one location (--docs-folder)
- 2. Attributes are automatically found in ../attributes/ relative to the documents folder
- 3. The script merges documents with all available attributes by matching filenames and document IDs
- 4. PII detection rules are applied to the merged documents
- 5. IoU and other metrics are calculated to compare the results
- Expected folder structure:
- - s3://bucket/path/documents/ - Contains the main document JSONL files
- - s3://bucket/path/attributes/ - Contains attributes that can be matched with documents by ID
- Document and attribute matching:
- - Files are matched by basename (example.jsonl in documents matches example.jsonl in attributes)
- - Within each file, documents are matched by their "id" field
- - When a match is found, attributes from the attribute file are merged into the document
- Example usage:
- python pii_rule_comparison.py \
- --docs-folder s3://bucket/path/documents \
- --ref-rule "gpt_4_1_contains_pii:any" \
- --hyp-rule "gpt_4_1_contains_email_addresses:any" \
- --output-dir results \
- --recursive
- Rule expression syntax:
- - Simple rule: "attribute_name:rule_type" where rule_type is "any" or "all"
- - Numeric comparison rule: "attribute_name:avg>0.3" (aggregates the attribute's span values and compares against the threshold; >, <, >=, <= and == are supported)
- - Boolean expressions: "not rule1:any and rule2:all"
- - Parentheses for grouping: "(rule1:any or rule2:any) and not rule3:all"
- """
- import argparse
- import base64
- import gzip
- import html as pyhtml
- import io
- import json
- import logging
- import os
- from collections import defaultdict
- from enum import Enum, auto
- from io import BytesIO
- from pathlib import Path
- import boto3
- import numpy as np
- import zstandard as zstd
- from matplotlib.figure import Figure
- # Initialize logger
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
- # Define token types for the rule expression parser
- class TokenType(Enum):
- RULE = auto()
- AND = auto()
- OR = auto()
- NOT = auto()
- LPAREN = auto()
- RPAREN = auto()
- EOF = auto()
- class Token:
- """Token for rule expression parsing"""
- def __init__(self, type, value=None):
- self.type = type
- self.value = value
- def __repr__(self):
- if self.value:
- return f"Token({self.type}, {self.value})"
- return f"Token({self.type})"
- class ExpressionNode:
- """Base class for expression tree nodes"""
- pass
- class RuleNode(ExpressionNode):
- """Leaf node representing a single rule"""
- def __init__(self, attribute_name, rule_type):
- self.attribute_name = attribute_name
- self.rule_type = rule_type
- def __repr__(self):
- return f"Rule({self.attribute_name}:{self.rule_type})"
- class NotNode(ExpressionNode):
- """Unary NOT operation node"""
- def __init__(self, operand):
- self.operand = operand
- def __repr__(self):
- return f"NOT({self.operand})"
- class BinaryNode(ExpressionNode):
- """Binary operation (AND/OR) node"""
- def __init__(self, left, right, operator):
- self.left = left
- self.right = right
- self.operator = operator
- def __repr__(self):
- return f"{self.operator}({self.left}, {self.right})"
- def parse_args():
- parser = argparse.ArgumentParser(description="Compare PII detection rules and calculate IoU")
- parser.add_argument("--docs-folder", required=True, help="Documents folder path containing JSONL files (local or s3://)")
- parser.add_argument("--attr-folder", help="Attributes folder path (if different from standard ../attributes/ location)")
- parser.add_argument(
- "--ref-rule",
- required=True,
- help="""Reference rule expression. Can be a simple rule in format 'attribute_name:rule_type',
- where rule_type is 'any' or 'all'. Or a boolean expression like
- 'not rule1:any and rule2:all' or '(rule1:any or rule2:any) and not rule3:all'""",
- )
- parser.add_argument(
- "--hyp-rule",
- required=True,
- help="""Hypothesis rule expression. Can be a simple rule in format 'attribute_name:rule_type',
- where rule_type is 'any' or 'all'. Or a boolean expression like
- 'not rule1:any and rule2:all' or '(rule1:any or rule2:any) and not rule3:all'""",
- )
- parser.add_argument("--output-dir", default="results", help="Directory to save HTML result files")
- parser.add_argument("--aws-profile", help="AWS profile for S3 access")
- parser.add_argument("--recursive", action="store_true", help="Recursively process folder structure")
- parser.add_argument("--debug", action="store_true", help="Enable debug logging for more detailed output")
- parser.add_argument("--disable-plots", action="store_true", help="Disable CDF plots generation")
- parser.add_argument("--max-plots", type=int, default=200, help="Maximum number of CDF plots to generate (default: 200)")
- return parser.parse_args()
- def parse_s3_path(s3_path):
- """Parse S3 path into bucket and prefix."""
- parts = s3_path.replace("s3://", "").split("/", 1)
- bucket = parts[0]
- prefix = parts[1] if len(parts) > 1 else ""
- return bucket, prefix
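- # Example (illustrative bucket/prefix): parse_s3_path("s3://my-bucket/data/documents")
- # returns ("my-bucket", "data/documents").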
- def get_attributes_folder(docs_folder, attr_folder=None):
- """
- Determine the attributes folder path based on the documents folder.
- Args:
- docs_folder: Path to the documents folder
- attr_folder: Manually specified attributes folder (optional)
- Returns:
- Path to the attributes folder
- """
- if attr_folder:
- return attr_folder
- # If no attributes folder specified, derive it from the documents folder
- if docs_folder.startswith("s3://"):
- # For S3 paths
- bucket, prefix = parse_s3_path(docs_folder)
- # Remove trailing slashes for consistent path handling
- prefix = prefix.rstrip("/")
- # Check if the documents folder is in a 'documents' directory
- if prefix.endswith("/documents"):
- # Replace /documents with /attributes
- attr_prefix = prefix[: -len("/documents")] + "/attributes"
- else:
- # Otherwise, add a parent level and include 'attributes'
- path_parts = prefix.split("/")
- # Remove the last part (assumed to be the documents directory name)
- path_parts.pop()
- # Add 'attributes'
- path_parts.append("attributes")
- attr_prefix = "/".join(path_parts)
- return f"s3://{bucket}/{attr_prefix}"
- else:
- # For local paths
- docs_path = Path(docs_folder)
- # Whether or not the folder is named 'documents', the attributes live in a
- # sibling 'attributes' directory (same resolution as the S3 branch above)
- attr_path = docs_path.parent / "attributes"
- return str(attr_path)
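- # Example (illustrative paths): get_attributes_folder("s3://bucket/path/documents")
- # returns "s3://bucket/path/attributes"; a local "corpus/documents" resolves to "corpus/attributes".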
- def get_s3_bytes(s3_client, s3_path):
- """Get bytes from S3 object."""
- bucket, key = parse_s3_path(s3_path)
- response = s3_client.get_object(Bucket=bucket, Key=key)
- return response["Body"].read()
- def list_jsonl_files(path, s3_client=None, recursive=False):
- """List all JSONL files in the given path, locally or in S3."""
- jsonl_files = []
- if path.startswith("s3://"):
- bucket, prefix = parse_s3_path(path)
- prefix = prefix.rstrip("/") + "/"
- # List objects in S3 bucket with given prefix
- paginator = s3_client.get_paginator("list_objects_v2")
- for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
- if "Contents" in page:
- for obj in page["Contents"]:
- key = obj["Key"]
- if key.endswith((".jsonl", ".json", ".jsonl.gz", ".jsonl.zst", ".jsonl.ztd", ".jsonl.zstd")):
- jsonl_files.append(f"s3://{bucket}/{key}")
- else:
- # Local file system
- path_obj = Path(path)
- if recursive:
- for file_path in path_obj.rglob("*"):
- if file_path.name.endswith((".jsonl", ".json", ".jsonl.gz", ".jsonl.zst", ".jsonl.ztd", ".jsonl.zstd")):
- jsonl_files.append(str(file_path))
- else:
- for file_path in path_obj.glob("*"):
- if file_path.name.endswith((".jsonl", ".json", ".jsonl.gz", ".jsonl.zst", ".jsonl.ztd", ".jsonl.zstd")):
- jsonl_files.append(str(file_path))
- return jsonl_files
- def load_jsonl_file(file_path, s3_client=None):
- """Load and decompress a JSONL file, either from local or S3."""
- try:
- # Get file content
- if file_path.startswith("s3://"):
- if s3_client is None:
- raise ValueError("S3 client is required for S3 paths")
- raw_data = get_s3_bytes(s3_client, file_path)
- else:
- with open(file_path, "rb") as f:
- raw_data = f.read()
- # Decompress if needed
- if file_path.endswith(".gz"):
- decompressed = gzip.decompress(raw_data)
- elif file_path.endswith((".zst", ".ztd", ".zstd")):
- try:
- # First try with standard decompression
- dctx = zstd.ZstdDecompressor()
- decompressed = dctx.decompress(raw_data)
- except zstd.ZstdError as e:
- # If that fails, try with stream decompression
- logger.warning(f"Standard zstd decompression failed for {file_path}, trying stream decompression: {e}")
- try:
- # Try with content-size not required
- dctx = zstd.ZstdDecompressor(max_window_size=2147483648) # Use a large window size
- decompressor = dctx.stream_reader(io.BytesIO(raw_data))
- decompressed = decompressor.read()
- except Exception as inner_e:
- # If both methods fail, try with chunking
- logger.warning(f"Stream decompression also failed, trying chunked reading: {inner_e}")
- # Chunked reading approach
- buffer = io.BytesIO()
- dctx = zstd.ZstdDecompressor(max_window_size=2147483648)
- with dctx.stream_reader(io.BytesIO(raw_data)) as reader:
- while True:
- chunk = reader.read(16384) # Read in 16KB chunks
- if not chunk:
- break
- buffer.write(chunk)
- buffer.seek(0)
- decompressed = buffer.read()
- else:
- decompressed = raw_data
- # Parse JSON lines
- lines = decompressed.decode("utf-8").strip().split("\n")
- return [json.loads(line) for line in lines if line.strip()]
- except Exception as e:
- logger.error(f"Error loading file {file_path}: {e}")
- return []
- def load_documents_and_attributes(docs_folder, attr_folder, s3_client=None, recursive=False):
- """
- Load documents and merge them with their attributes from all subdirectories.
- Args:
- docs_folder: Path to the documents folder
- attr_folder: Path to the attributes folder
- s3_client: S3 client for S3 paths
- recursive: Whether to process folders recursively
- Returns:
- List of documents with their attributes merged in
- """
- try:
- # List all document files
- logger.info(f"Finding document files in: {docs_folder}")
- doc_files = list_jsonl_files(docs_folder, s3_client, recursive)
- logger.info(f"Found {len(doc_files)} document files")
- if not doc_files:
- logger.warning(f"No document files found in {docs_folder}. Check the path and permissions.")
- return []
- # Get all attribute subdirectories if it's an S3 path
- attr_subdirs = []
- if attr_folder.startswith("s3://"):
- bucket, attr_prefix = parse_s3_path(attr_folder)
- attr_prefix = attr_prefix.rstrip("/") + "/"
- # List top-level directories in the attributes folder
- logger.info(f"Finding attribute subdirectories in: {attr_folder}")
- # Using delimiter parameter to list "directories" in S3
- paginator = s3_client.get_paginator("list_objects_v2")
- for page in paginator.paginate(Bucket=bucket, Prefix=attr_prefix, Delimiter="/"):
- if "CommonPrefixes" in page:
- for prefix in page["CommonPrefixes"]:
- subdir = f"s3://{bucket}/{prefix['Prefix']}"
- attr_subdirs.append(subdir)
- logger.info(f"Found attribute subdirectory: {subdir}")
- # If no subdirectories, use the main folder
- if not attr_subdirs:
- attr_subdirs = [attr_folder]
- logger.info(f"No subdirectories found, using main attribute folder: {attr_folder}")
- else:
- # For local paths
- attr_path = Path(attr_folder)
- if attr_path.exists() and attr_path.is_dir():
- # Get subdirectories
- subdirs = [str(d) for d in attr_path.iterdir() if d.is_dir()]
- if subdirs:
- attr_subdirs = subdirs
- logger.info(f"Found {len(attr_subdirs)} attribute subdirectories")
- else:
- attr_subdirs = [attr_folder]
- logger.info(f"No subdirectories found, using main attribute folder: {attr_folder}")
- else:
- logger.warning(f"Attributes folder not found or not a directory: {attr_folder}")
- attr_subdirs = []
- # Load and merge documents with all attributes from all subdirectories
- merged_docs = []
- docs_by_id = {}
- total_attr_files = 0
- # First, load all document files and create a document-by-ID mapping
- for doc_path in doc_files:
- try:
- if doc_path.startswith("s3://"):
- _, doc_key = parse_s3_path(doc_path)
- basename = os.path.basename(doc_key)
- else:
- basename = os.path.basename(doc_path)
- # Load documents
- docs = load_jsonl_file(doc_path, s3_client)
- if not docs:
- logger.warning(f"No documents loaded from {basename} (path: {doc_path})")
- continue
- logger.info(f"Loaded {len(docs)} documents from {basename}")
- # Add to the merged documents list and create ID mapping
- for doc in docs:
- if "id" in doc:
- # If the document already exists, use the one with attributes if possible
- doc_id = doc["id"]
- if doc_id in docs_by_id:
- if "attributes" not in doc and "attributes" in docs_by_id[doc_id]:
- # Keep the existing document that has attributes
- continue
- # Initialize attributes if needed
- if "attributes" not in doc:
- doc["attributes"] = {}
- # Add to the mapping
- docs_by_id[doc_id] = doc
- else:
- # No ID, can't match with attributes
- if "attributes" not in doc:
- doc["attributes"] = {}
- merged_docs.append(doc)
- except Exception as e:
- logger.error(f"Error processing document file {doc_path}: {e}")
- continue
- logger.info(f"Loaded {len(docs_by_id)} unique documents with IDs")
- # Now process each attribute subdirectory
- for subdir in attr_subdirs:
- try:
- logger.info(f"Processing attribute directory: {subdir}")
- attr_files = list_jsonl_files(subdir, s3_client, recursive)
- total_attr_files += len(attr_files)
- logger.info(f"Found {len(attr_files)} attribute files in {subdir}")
- # Create a mapping from document basename to attribute file path
- attr_file_map = {}
- for attr_path in attr_files:
- if attr_path.startswith("s3://"):
- _, attr_key = parse_s3_path(attr_path)
- basename = os.path.basename(attr_key)
- else:
- basename = os.path.basename(attr_path)
- attr_file_map[basename] = attr_path
- # Go through the document files again to find matching attributes
- for doc_path in doc_files:
- try:
- if doc_path.startswith("s3://"):
- _, doc_key = parse_s3_path(doc_path)
- basename = os.path.basename(doc_key)
- else:
- basename = os.path.basename(doc_path)
- # Find matching attribute file
- if basename in attr_file_map:
- attr_path = attr_file_map[basename]
- attrs = load_jsonl_file(attr_path, s3_client)
- if not attrs:
- logger.warning(f"No attributes loaded from {os.path.basename(attr_path)} (path: {attr_path})")
- continue
- logger.info(f"Loaded {len(attrs)} attributes from {os.path.basename(attr_path)}")
- # Create a mapping from document ID to attributes
- attr_by_id = {attr["id"]: attr for attr in attrs if "id" in attr}
- # Count documents with matched attributes
- docs_matched_in_file = 0
- # Merge attributes into documents by ID
- for doc_id, doc in docs_by_id.items():
- if doc_id in attr_by_id:
- docs_matched_in_file += 1
- # If attributes document has attributes field, merge them
- if "attributes" in attr_by_id[doc_id]:
- doc["attributes"].update(attr_by_id[doc_id]["attributes"])
- logger.info(f"Matched attributes for {docs_matched_in_file} documents from {basename} in {subdir}")
- except Exception as e:
- logger.error(f"Error processing attribute file {attr_path}: {e}")
- continue
- except Exception as e:
- logger.error(f"Error processing attribute subdirectory {subdir}: {e}")
- continue
- # Convert the dictionary to a list for return
- merged_docs.extend(docs_by_id.values())
- logger.info(f"Total documents processed: {len(merged_docs)}")
- logger.info(f"Total attribute files processed: {total_attr_files}")
- logger.info(f"Total attribute subdirectories processed: {len(attr_subdirs)}")
- return merged_docs
- except Exception as e:
- logger.error(f"Error in load_documents_and_attributes: {e}")
- raise
- def apply_rule(doc, rule):
- """
- Apply a rule to determine if a document meets the PII criteria.
- Args:
- doc: The document JSON object
- rule: Either a tuple (attribute_name, rule_type) for simple rules,
- or an ExpressionNode for complex boolean expressions
- Returns:
- True if the document matches the rule, False otherwise
- """
- # Handle simple rule
- if not is_complex_expression(rule):
- return apply_simple_rule(doc, rule[0], rule[1])
- # Handle complex expression
- return evaluate_expression(doc, rule)
- def calculate_attribute_aggregate(doc, attribute_name, operation_type):
- """
- Calculate an aggregate value for a numeric attribute.
- Args:
- doc: The document JSON object
- attribute_name: The attribute field to aggregate (e.g., "pii_tagging_ratio")
- operation_type: The type of aggregation to perform (e.g., "avg")
- Returns:
- The aggregated value, or None if calculation is not possible
- """
- # Check if document has attributes
- if "attributes" not in doc or not doc["attributes"]:
- logger.debug(f"Document {doc.get('id', 'unknown')} has no attributes")
- return None
- attributes = doc["attributes"]
- # Check if the specific attribute exists
- if attribute_name not in attributes:
- logger.debug(f"Document {doc.get('id', 'unknown')} doesn't have attribute: {attribute_name}")
- return None
- if not attributes[attribute_name]:
- logger.debug(f"Document {doc.get('id', 'unknown')} has empty attribute: {attribute_name}")
- return None
- # Extract the numeric values from the attribute spans
- # Each span is formatted as [start_pos, end_pos, value]
- values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]
- if not values:
- logger.debug(f"Document {doc.get('id', 'unknown')} has no valid values for attribute: {attribute_name}")
- return None
- # Convert all values to float
- try:
- numeric_values = [float(value) for value in values]
- except (ValueError, TypeError):
- logger.debug(f"Document {doc.get('id', 'unknown')} has non-numeric values for attribute: {attribute_name}")
- return None
- # Perform the aggregation
- if operation_type == "avg":
- if not numeric_values:
- return None
- return sum(numeric_values) / len(numeric_values)
- # Add more aggregation types here as needed
- else:
- raise ValueError(f"Unknown operation type: {operation_type}")
- def apply_simple_rule(doc, attribute_name, rule_type):
- """
- Apply a simple rule to determine if a document meets the PII criteria.
- Args:
- doc: The document JSON object
- attribute_name: The attribute field to check (e.g., "gpt_4_1_contains_pii")
- rule_type: 'any' for any true value, 'all' for all true values,
- or a string containing an operation and comparison (e.g., 'avg>0.3')
- Returns:
- True if the document matches the rule, False otherwise
- """
- # Check if document has attributes
- if "attributes" not in doc or not doc["attributes"]:
- logger.debug(f"Document {doc.get('id', 'unknown')} has no attributes")
- return False
- attributes = doc["attributes"]
- # Check if the specific attribute exists
- if attribute_name not in attributes:
- logger.debug(f"Document {doc.get('id', 'unknown')} doesn't have attribute: {attribute_name}")
- return False
- if not attributes[attribute_name]:
- logger.debug(f"Document {doc.get('id', 'unknown')} has empty attribute: {attribute_name}")
- return False
- # Handle numeric comparison rules (e.g., 'avg>0.3')
- if any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
- # Parse the rule type into an aggregation operation, comparison operator and threshold.
- # Two-character operators are checked first so that e.g. 'avg>=0.3' is not
- # mis-parsed as '>' with an invalid threshold of '=0.3'.
- comparison_op = None
- for op in (">=", "<=", "==", ">", "<"):
- operation_parts = rule_type.split(op)
- if len(operation_parts) == 2:
- operation_type, threshold = operation_parts
- comparison_op = op
- break
- if comparison_op is None:
- raise ValueError(f"Invalid rule type: {rule_type}")
- # Convert threshold to float
- try:
- threshold = float(threshold)
- except ValueError:
- raise ValueError(f"Invalid threshold value: {threshold}")
- # Calculate the aggregate value
- aggregate_value = calculate_attribute_aggregate(doc, attribute_name, operation_type)
- if aggregate_value is None:
- logger.debug(f"Document {doc.get('id', 'unknown')} has no valid aggregate value for attribute: {attribute_name}")
- return False
- # Apply the comparison
- if comparison_op == ">":
- result = aggregate_value > threshold
- elif comparison_op == "<":
- result = aggregate_value < threshold
- elif comparison_op == ">=":
- result = aggregate_value >= threshold
- elif comparison_op == "<=":
- result = aggregate_value <= threshold
- elif comparison_op == "==":
- result = aggregate_value == threshold
- else:
- raise ValueError(f"Invalid comparison operator: {comparison_op}")
- if result:
- logger.debug(f"Document {doc.get('id', 'unknown')} matched numeric rule '{attribute_name}:{rule_type}' with value {aggregate_value}")
- return result
- # Handle boolean rules (any/all)
- if rule_type in ["any", "all"]:
- # Extract the boolean values from the attribute spans
- # Each span is formatted as [start_pos, end_pos, value]
- values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]
- if not values:
- logger.debug(f"Document {doc.get('id', 'unknown')} has no valid values for attribute: {attribute_name}")
- return False
- # Apply the rule
- if rule_type == "any":
- result = any(values)
- if result:
- logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (found True in {len(values)} values)")
- return result
- elif rule_type == "all":
- result = all(values)
- if result:
- logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (all {len(values)} values are True)")
- return result
- raise ValueError(f"Unknown rule type: {rule_type}")
- def evaluate_expression(doc, expr):
- """
- Evaluate a boolean expression on a document.
- Args:
- doc: The document JSON object
- expr: An ExpressionNode representing a boolean expression
- Returns:
- True if the document matches the expression, False otherwise
- """
- if isinstance(expr, RuleNode):
- # Base case: evaluate a leaf rule node
- return apply_simple_rule(doc, expr.attribute_name, expr.rule_type)
- elif isinstance(expr, NotNode):
- # NOT operator
- return not evaluate_expression(doc, expr.operand)
- elif isinstance(expr, BinaryNode):
- # Binary operators (AND/OR)
- if expr.operator == "AND":
- # Short-circuit AND evaluation
- return evaluate_expression(doc, expr.left) and evaluate_expression(doc, expr.right)
- elif expr.operator == "OR":
- # Short-circuit OR evaluation
- return evaluate_expression(doc, expr.left) or evaluate_expression(doc, expr.right)
- # Should not reach here if the expression tree is well-formed
- raise ValueError(f"Invalid expression node type: {type(expr)}")
- def tokenize_expression(expression):
- """
- Tokenize a rule expression string into a list of tokens.
- Args:
- expression: A string containing a boolean rule expression
- (e.g., "not rule1:any and rule2:all")
- Returns:
- A list of Token objects
- """
- tokens = []
- i = 0
- expression = expression.strip()
- while i < len(expression):
- char = expression[i]
- # Skip whitespace
- if char.isspace():
- i += 1
- continue
- # Handle parentheses
- elif char == "(":
- tokens.append(Token(TokenType.LPAREN))
- i += 1
- elif char == ")":
- tokens.append(Token(TokenType.RPAREN))
- i += 1
- # Handle operators
- elif i + 2 < len(expression) and expression[i : i + 3].lower() == "and":
- # Check if it's a standalone 'and' and not part of a word
- if (i == 0 or expression[i - 1].isspace() or expression[i - 1] in "()") and (
- i + 3 >= len(expression) or expression[i + 3].isspace() or expression[i + 3] in "()"
- ):
- tokens.append(Token(TokenType.AND))
- i += 3
- else:
- # It's part of an attribute name
- rule_start = i
- while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
- if i + 1 < len(expression) and expression[i] == ":":
- break
- i += 1
- # Process rule if we found a colon
- if i < len(expression) and expression[i] == ":":
- rule_end = i
- i += 1 # Skip the colon
- # Find the rule type
- type_start = i
- while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
- i += 1
- rule_name = expression[rule_start:rule_end]
- rule_type = expression[type_start:i]
- tokens.append(Token(TokenType.RULE, (rule_name, rule_type)))
- else:
- raise ValueError(f"Invalid rule format at position {rule_start}")
- elif i + 1 < len(expression) and expression[i : i + 2].lower() == "or":
- # Check if it's a standalone 'or' and not part of a word
- if (i == 0 or expression[i - 1].isspace() or expression[i - 1] in "()") and (
- i + 2 >= len(expression) or expression[i + 2].isspace() or expression[i + 2] in "()"
- ):
- tokens.append(Token(TokenType.OR))
- i += 2
- else:
- # Part of an attribute name
- rule_start = i
- while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
- if i + 1 < len(expression) and expression[i] == ":":
- break
- i += 1
- # Process rule if we found a colon
- if i < len(expression) and expression[i] == ":":
- rule_end = i
- i += 1 # Skip the colon
- # Find the rule type
- type_start = i
- while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
- i += 1
- rule_name = expression[rule_start:rule_end]
- rule_type = expression[type_start:i]
- tokens.append(Token(TokenType.RULE, (rule_name, rule_type)))
- else:
- raise ValueError(f"Invalid rule format at position {rule_start}")
- elif i + 2 < len(expression) and expression[i : i + 3].lower() == "not":
- # Check if it's a standalone 'not' and not part of a word
- if (i == 0 or expression[i - 1].isspace() or expression[i - 1] in "()") and (
- i + 3 >= len(expression) or expression[i + 3].isspace() or expression[i + 3] in "()"
- ):
- tokens.append(Token(TokenType.NOT))
- i += 3
- else:
- # Part of an attribute name
- rule_start = i
- while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
- if i + 1 < len(expression) and expression[i] == ":":
- break
- i += 1
- # Process rule if we found a colon
- if i < len(expression) and expression[i] == ":":
- rule_end = i
- i += 1 # Skip the colon
- # Find the rule type
- type_start = i
- while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
- i += 1
- rule_name = expression[rule_start:rule_end]
- rule_type = expression[type_start:i]
- tokens.append(Token(TokenType.RULE, (rule_name, rule_type)))
- else:
- raise ValueError(f"Invalid rule format at position {rule_start}")
- # Handle rule (attribute:type)
- else:
- rule_start = i
- while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
- if i + 1 < len(expression) and expression[i] == ":":
- break
- i += 1
- # Process rule if we found a colon
- if i < len(expression) and expression[i] == ":":
- rule_end = i
- i += 1 # Skip the colon
- # Find the rule type
- type_start = i
- while i < len(expression) and not expression[i].isspace() and expression[i] not in "()":
- i += 1
- rule_name = expression[rule_start:rule_end]
- rule_type = expression[type_start:i]
- tokens.append(Token(TokenType.RULE, (rule_name, rule_type)))
- else:
- raise ValueError(f"Invalid rule format at position {rule_start}")
- tokens.append(Token(TokenType.EOF))
- return tokens
- class Parser:
- """
- Parser for boolean rule expressions.
- Implements a recursive descent parser for expressions with the following grammar:
- expression → or_expr
- or_expr → and_expr ("or" and_expr)*
- and_expr → unary_expr ("and" unary_expr)*
- unary_expr → "not" unary_expr | primary
- primary → rule | "(" expression ")"
- rule → ATTRIBUTE ":" RULE_TYPE
- """
- def __init__(self, tokens):
- self.tokens = tokens
- self.current = 0
- def parse(self):
- """Parse the tokens into an expression tree."""
- return self.expression()
- def expression(self):
- """Parse an expression (top level)."""
- return self.or_expr()
- def or_expr(self):
- """Parse an OR expression."""
- expr = self.and_expr()
- while self.match(TokenType.OR):
- right = self.and_expr()
- expr = BinaryNode(expr, right, "OR")
- return expr
- def and_expr(self):
- """Parse an AND expression."""
- expr = self.unary_expr()
- while self.match(TokenType.AND):
- right = self.unary_expr()
- expr = BinaryNode(expr, right, "AND")
- return expr
- def unary_expr(self):
- """Parse a unary expression (NOT)."""
- if self.match(TokenType.NOT):
- operand = self.unary_expr()
- return NotNode(operand)
- return self.primary()
- def primary(self):
- """Parse a primary expression (rule or parenthesized expression)."""
- if self.match(TokenType.RULE):
- rule_tuple = self.previous().value
- attribute_name, rule_type = rule_tuple
- # Validate rule type
- if rule_type not in ["any", "all"] and not any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
- raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all', or numeric comparison (e.g., 'avg>0.3')")
- return RuleNode(attribute_name, rule_type)
- if self.match(TokenType.LPAREN):
- expr = self.expression()
- self.consume(TokenType.RPAREN, "Expected ')' after expression.")
- return expr
- raise ValueError(f"Expected rule or '(' at position {self.current}")
- def match(self, *types):
- """Check if the current token matches any of the given types."""
- for type in types:
- if self.check(type):
- self.advance()
- return True
- return False
- def check(self, type):
- """Check if the current token is of the given type without advancing."""
- if self.is_at_end():
- return False
- return self.peek().type == type
- def advance(self):
- """Advance to the next token and return the previous one."""
- if not self.is_at_end():
- self.current += 1
- return self.previous()
- def consume(self, type, message):
- """Consume the current token if it matches the expected type."""
- if self.check(type):
- return self.advance()
- raise ValueError(f"{message} at position {self.current}")
- def is_at_end(self):
- """Check if we've reached the end of the tokens."""
- return self.peek().type == TokenType.EOF
- def peek(self):
- """Return the current token without advancing."""
- return self.tokens[self.current]
- def previous(self):
- """Return the previous token."""
- return self.tokens[self.current - 1]
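- # Operator precedence follows the grammar above: 'not' binds tighter than 'and',
- # which binds tighter than 'or', so "a:any or b:any and not c:all" parses as
- # OR(Rule(a:any), AND(Rule(b:any), NOT(Rule(c:all)))).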
- def parse_rule(rule_string):
- """
- Parse a rule string into an expression tree or a simple attribute-rule_type tuple.
- Args:
- rule_string: A string containing a rule or boolean expression of rules
- Returns:
- Either a tuple (attribute_name, rule_type) for simple rules,
- or an ExpressionNode for complex boolean expressions
- """
- # Check if this is a simple rule
- if (
- "and" not in rule_string.lower()
- and "or" not in rule_string.lower()
- and "not" not in rule_string.lower()
- and "(" not in rule_string
- and ")" not in rule_string
- ):
- # Simple rule format: attribute_name:rule_type
- parts = rule_string.split(":", 1)
- if len(parts) != 2:
- raise ValueError(f"Invalid rule format: {rule_string}. Expected format: 'attribute_name:rule_type'")
- attribute_name, rule_type = parts
- # Check for numeric comparison rule_type
- if any(op in rule_type for op in [">", "<", ">=", "<=", "=="]):
- # This is a numeric comparison rule - we'll validate it in apply_simple_rule
- return attribute_name, rule_type
- elif rule_type not in ["any", "all"]:
- raise ValueError(f"Invalid rule type: {rule_type}. Supported types: 'any', 'all', or numeric comparison (e.g., 'avg>0.3')")
- return attribute_name, rule_type
- else:
- # Complex rule expression
- try:
- tokens = tokenize_expression(rule_string)
- parser = Parser(tokens)
- return parser.parse()
- except Exception as e:
- raise ValueError(f"Error parsing expression '{rule_string}': {e}")
- def is_complex_expression(rule):
- """Check if the rule is a complex boolean expression."""
- return isinstance(rule, ExpressionNode)
- def calculate_iou(ref_ids, hyp_ids):
- """Calculate Intersection over Union of two sets of document IDs."""
- ref_set = set(ref_ids)
- hyp_set = set(hyp_ids)
- intersection = ref_set.intersection(hyp_set)
- union = ref_set.union(hyp_set)
- if not union:
- return 0.0
- return len(intersection) / len(union)
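- # Example: calculate_iou({"d1", "d2", "d3"}, {"d2", "d3", "d4"}) == 2 / 4 == 0.5
- # (intersection {"d2", "d3"}, union {"d1", "d2", "d3", "d4"}).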
- def collect_rule_stats(expression, doc):
- """
- Collect statistics for all rules within a complex expression.
- Args:
- expression: A rule expression (either a tuple or ExpressionNode)
- doc: The document to analyze
- Returns:
- A dictionary with rule statistics
- """
- rule_stats = defaultdict(int)
- # Normalize a simple (attribute_name, rule_type) tuple into a RuleNode so the
- # traversal below handles simple and complex rule expressions uniformly
- if not is_complex_expression(expression):
- expression = RuleNode(expression[0], expression[1])
- # Traverse the expression tree and collect per-rule statistics
- if isinstance(expression, RuleNode):
- # Base case: leaf node is a simple rule
- attribute_name, rule_type = expression.attribute_name, expression.rule_type
- if "attributes" in doc and doc["attributes"] and attribute_name in doc["attributes"] and doc["attributes"][attribute_name]:
- # The rule name will be the key for the statistics
- rule_name = f"{attribute_name}:{rule_type}"
- # Count entries in the attribute
- entries = doc["attributes"][attribute_name]
- rule_stats[f"{rule_name}_total_entries"] += len(entries)
- # Count positive values
- for span in entries:
- if len(span) >= 3 and span[2] is True:
- rule_stats[f"{rule_name}_positive_entries"] += 1
- # Check if document matches the rule
- if apply_simple_rule(doc, attribute_name, rule_type):
- rule_stats[f"{rule_name}_matched_docs"] += 1
- elif isinstance(expression, NotNode):
- # Get stats from the operand
- operand_stats = collect_rule_stats(expression.operand, doc)
- # Merge with current stats
- for key, value in operand_stats.items():
- rule_stats[key] += value
- elif isinstance(expression, BinaryNode):
- # Get stats from both sides
- left_stats = collect_rule_stats(expression.left, doc)
- right_stats = collect_rule_stats(expression.right, doc)
- # Merge with current stats
- for key, value in left_stats.items():
- rule_stats[key] += value
- for key, value in right_stats.items():
- rule_stats[key] += value
- return rule_stats
- def get_expression_summary(expression):
- """
- Generate a string representation of a rule expression.
- Args:
- expression: A rule expression (either a tuple or ExpressionNode)
- Returns:
- A string representation of the expression
- """
- if not is_complex_expression(expression):
- return f"{expression[0]}:{expression[1]}"
- if isinstance(expression, RuleNode):
- return f"{expression.attribute_name}:{expression.rule_type}"
- elif isinstance(expression, NotNode):
- return f"not {get_expression_summary(expression.operand)}"
- elif isinstance(expression, BinaryNode):
- left_summary = get_expression_summary(expression.left)
- right_summary = get_expression_summary(expression.right)
- return f"({left_summary} {expression.operator.lower()} {right_summary})"
- return str(expression)
- def compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule):
- """
- Compare two sets of documents using the specified rules and calculate IoU.
- Args:
- ref_docs: List of reference documents
- hyp_docs: List of hypothesis documents
- ref_rule: Rule expression for reference (tuple or ExpressionNode)
- hyp_rule: Rule expression for hypothesis (tuple or ExpressionNode)
- Returns:
- Dictionary with comparison results
- """
- # Extract document IDs and create ID-to-document maps
- ref_id_to_doc = {doc["id"]: doc for doc in ref_docs if "id" in doc}
- hyp_id_to_doc = {doc["id"]: doc for doc in hyp_docs if "id" in doc}
- # Get common document IDs
- common_ids = set(ref_id_to_doc.keys()).intersection(set(hyp_id_to_doc.keys()))
- # Apply rules to each document
- ref_matches = set()
- hyp_matches = set()
- # Track rule statistics
- ref_rule_stats = defaultdict(int)
- hyp_rule_stats = defaultdict(int)
- for doc_id in common_ids:
- ref_doc = ref_id_to_doc[doc_id]
- hyp_doc = hyp_id_to_doc[doc_id]
- # Collect statistics for all rules in the expressions
- doc_ref_rule_stats = collect_rule_stats(ref_rule, ref_doc)
- doc_hyp_rule_stats = collect_rule_stats(hyp_rule, hyp_doc)
- # Merge with overall stats
- for key, value in doc_ref_rule_stats.items():
- ref_rule_stats[key] += value
- for key, value in doc_hyp_rule_stats.items():
- hyp_rule_stats[key] += value
- # Check if document matches the rule expressions
- if apply_rule(ref_doc, ref_rule):
- ref_matches.add(doc_id)
- ref_rule_stats["expression_matched_docs"] += 1
- if apply_rule(hyp_doc, hyp_rule):
- hyp_matches.add(doc_id)
- hyp_rule_stats["expression_matched_docs"] += 1
- # Calculate IoU
- iou = calculate_iou(ref_matches, hyp_matches)
- # Collect detailed statistics
- tp = len(ref_matches.intersection(hyp_matches))
- fp = len(hyp_matches - ref_matches)
- fn = len(ref_matches - hyp_matches)
- precision = tp / (tp + fp) if (tp + fp) > 0 else 0
- recall = tp / (tp + fn) if (tp + fn) > 0 else 0
- f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
- # Generate string representations of the expressions
- ref_rule_str = get_expression_summary(ref_rule)
- hyp_rule_str = get_expression_summary(hyp_rule)
- return {
- "total_docs": len(common_ids),
- "ref_rule": ref_rule_str,
- "hyp_rule": hyp_rule_str,
- "ref_matches": len(ref_matches),
- "hyp_matches": len(hyp_matches),
- "intersection": tp,
- "union": tp + fp + fn,
- "true_positives": tp,
- "false_positives": fp,
- "false_negatives": fn,
- "precision": precision,
- "recall": recall,
- "f1": f1,
- "iou": iou,
- "ref_rule_stats": dict(ref_rule_stats),
- "hyp_rule_stats": dict(hyp_rule_stats),
- }
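- # Worked example (hypothetical counts): with ref_matches = {1, 2, 3} and hyp_matches = {2, 3, 4},
- # tp = 2, fp = 1, fn = 1, so precision = recall = f1 = 2/3 and iou = 2 / 4 = 0.5.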
- def format_rule_stats(rule_stats):
- """Format rule statistics for display."""
- # Group the statistics by rule name
- grouped_stats = defaultdict(dict)
- # Process regular rule stats (format: "{rule_name}_{stat_type}")
- for key, value in rule_stats.items():
- if key == "expression_matched_docs":
- # Special case for the overall expression match count
- continue
- # Extract rule name and stat type
- if "_total_entries" in key:
- rule_name = key.replace("_total_entries", "")
- grouped_stats[rule_name]["total_entries"] = value
- elif "_positive_entries" in key:
- rule_name = key.replace("_positive_entries", "")
- grouped_stats[rule_name]["positive_entries"] = value
- elif "_matched_docs" in key:
- rule_name = key.replace("_matched_docs", "")
- grouped_stats[rule_name]["matched_docs"] = value
- # Format the grouped statistics as a list of strings
- formatted_stats = []
- for rule_name, stats in grouped_stats.items():
- formatted_stats.append(
- f" {rule_name}:\n"
- f" - Total Entries: {stats.get('total_entries', 0)}\n"
- f" - Positive Entries: {stats.get('positive_entries', 0)}\n"
- f" - Matched Documents: {stats.get('matched_docs', 0)}"
- )
- # Add the expression matched count if available
- if "expression_matched_docs" in rule_stats:
- formatted_stats.append(f" Overall Expression Matched Documents: {rule_stats['expression_matched_docs']}")
- return "\n".join(formatted_stats)
- def collect_numeric_attributes(documents):
- """
- Collect all numeric attribute values from documents.
- Args:
- documents: List of documents with attributes
- Returns:
- Dictionary mapping attribute names to lists of numeric values
- """
- numeric_attributes = defaultdict(list)
- for doc in documents:
- if "attributes" not in doc or not doc["attributes"]:
- continue
- for attr_name, attr_values in doc["attributes"].items():
- if not attr_values:
- continue
- # Try to extract numeric values from the attribute spans
- # Each span is formatted as [start_pos, end_pos, value]
- for span in attr_values:
- if len(span) >= 3 and span[2] is not None:
- try:
- # Convert to float if it's a numeric value
- value = float(span[2])
- numeric_attributes[attr_name].append(value)
- except (ValueError, TypeError):
- # Not a numeric value, skip
- pass
- # Filter out attributes with no or too few numeric values
- return {k: v for k, v in numeric_attributes.items() if len(v) > 5}
- def generate_cdf_plot(values, attribute_name):
- """
- Generate a CDF plot for the given numeric values.
- Args:
- values: List of numeric values
- attribute_name: Name of the attribute (for plot title)
- Returns:
- Base64-encoded PNG image of the plot or None if there's an error
- """
- try:
- # Ensure we have enough data points
- if len(values) < 5:
- logger.warning(f"Not enough data points to generate CDF for {attribute_name}")
- return None
- # Remove any NaN or infinite values
- values = np.array([v for v in values if np.isfinite(v)])
- if len(values) < 5:
- logger.warning(f"Not enough finite data points to generate CDF for {attribute_name}")
- return None
- # Handle extreme values by removing outliers (optional)
- # if len(values) > 30: # Only apply if we have enough data points
- # q1, q3 = np.percentile(values, [25, 75])
- # iqr = q3 - q1
- # lower_bound = q1 - 3 * iqr
- # upper_bound = q3 + 3 * iqr
- # values = values[(values >= lower_bound) & (values <= upper_bound)]
- # Sort values for CDF calculation
- values = np.sort(values)
- # Create a Figure object (no interactive display)
- fig = Figure(figsize=(10, 6))
- ax = fig.add_subplot(1, 1, 1)
- # Calculate CDF (y-values are 0 to 1 for cumulative probability)
- y = np.arange(1, len(values) + 1) / len(values)
- # Plot the CDF
- ax.plot(values, y, "b-", linewidth=2)
- ax.grid(True, linestyle="--", alpha=0.7)
- # Add labels and title
- ax.set_xlabel("Value", fontsize=12)
- ax.set_ylabel("Cumulative Probability", fontsize=12)
- ax.set_title(f"CDF of {attribute_name}", fontsize=14)
- # Ensure the y-axis goes from 0 to 1 for probability
- ax.set_ylim(0, 1.05)
- # Add some statistics to the plot
- if len(values) > 0:
- mean_val = np.mean(values)
- median_val = np.median(values)
- min_val = np.min(values)
- max_val = np.max(values)
- stats_text = f"n={len(values)}\nmin={min_val:.2f}\nmax={max_val:.2f}\nmean={mean_val:.2f}\nmedian={median_val:.2f}"
- ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, verticalalignment="top", bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
- # Make layout tight
- fig.tight_layout()
- # Convert to base64 for embedding in HTML
- buf = BytesIO()
- fig.savefig(buf, format="png", dpi=100)
- buf.seek(0)
- img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
- return img_base64
- except Exception as e:
- logger.error(f"Error generating CDF plot for {attribute_name}: {e}")
- return None
- def generate_attribute_plots_html(numeric_attributes, max_plots=20):
- """
- Generate HTML section with CDF plots for all numeric attributes.
- Args:
- numeric_attributes: Dictionary mapping attribute names to lists of numeric values
- max_plots: Maximum number of plots to generate
- Returns:
- HTML string with embedded CDF plots
- """
- if not numeric_attributes:
- return ""
- html = """
- <h2>Numeric Attribute Distributions</h2>
- <div class="attribute-plots">
- """
- plot_count = 0
- # Sort attributes by number of values (most values first)
- sorted_attrs = sorted(numeric_attributes.items(), key=lambda x: len(x[1]), reverse=True)
- for attr_name, values in sorted_attrs:
- if len(values) < 10: # Skip attributes with too few values for meaningful plots
- continue
- if plot_count >= max_plots:
- logger.info(f"Limiting CDF plots to {max_plots} attributes to avoid performance issues")
- break
- # Generate the CDF plot
- img_base64 = generate_cdf_plot(values, attr_name)
- # Only add to HTML if plot generation was successful
- if img_base64:
- html += f"""
- <div class="plot-container">
- <h3>{attr_name}</h3>
- <img src="data:image/png;base64,{img_base64}" alt="CDF plot for {attr_name}" class="cdf-plot">
- <p>Number of values: {len(values)}</p>
- </div>
- """
- plot_count += 1
- if plot_count == 0:
- return "" # Don't add the section if no plots were generated
- html += """
- </div>
- """
- return html
- def generate_html_report(docs, title, summary, output_path):
- """
- Generate an HTML report file with document texts
- Args:
- docs: List of documents to include in the report
- title: Title of the report
- summary: Summary statistics to include at the top
- output_path: Path to save the HTML file
- Returns:
- None
- """
- # Create header with CSS styling
- html = f"""<!DOCTYPE html>
- <html>
- <head>
- <meta charset="UTF-8">
- <title>{title}</title>
- <style>
- body {{
- font-family: Arial, sans-serif;
- line-height: 1.6;
- margin: 0;
- padding: 0;
- scroll-behavior: smooth;
- }}
-
- /* Header bar styles */
- .header {{
- background-color: #f8f9fa;
- padding: 8px 20px;
- box-shadow: 0 2px 5px rgba(0,0,0,0.1);
- position: fixed;
- top: 0;
- left: 0;
- right: 0;
- z-index: 100;
- display: flex;
- justify-content: space-between;
- align-items: center;
- height: 40px;
- }}
-
- .title {{
- font-size: 1.2em;
- font-weight: bold;
- color: #333;
- white-space: nowrap;
- overflow: hidden;
- text-overflow: ellipsis;
- max-width: 60%;
- }}
-
- .controls {{
- display: flex;
- align-items: center;
- }}
-
- .keyboard-controls {{
- font-size: 0.85em;
- margin-right: 15px;
- }}
-
- .toggle-summary {{
- background-color: #e9ecef;
- border: 1px solid #ced4da;
- padding: 4px 10px;
- border-radius: 4px;
- cursor: pointer;
- font-size: 0.85em;
- }}
-
- /* Summary panel styles */
- #summary-panel {{
- position: fixed;
- top: 57px;
- left: 0;
- right: 0;
- background-color: #f8f9fa;
- border-bottom: 1px solid #ddd;
- padding: 15px 20px;
- z-index: 90;
- display: none;
- max-height: 300px;
- overflow-y: auto;
- box-shadow: 0 2px 5px rgba(0,0,0,0.1);
- }}
-
- /* Main content styles */
- .container {{
- max-width: 1200px;
- margin: 0 auto;
- padding: 60px 20px 20px 20px;
- }}
-
- /* Document styles */
- .document {{
- background-color: #fff;
- padding: 15px;
- margin-bottom: 15px;
- border: 1px solid #ddd;
- border-radius: 5px;
- box-shadow: 0 1px 3px rgba(0,0,0,0.1);
- transition: all 0.2s ease-in-out;
- scroll-margin-top: 60px;
- }}
-
- .document:hover {{
- box-shadow: 0 2px 5px rgba(0,0,0,0.2);
- }}
-
- .document.selected {{
- border: 2px solid #007bff;
- box-shadow: 0 0 8px rgba(0, 123, 255, 0.5);
- background-color: #f8f9fa;
- }}
-
- .document-id {{
- color: #007bff;
- font-weight: bold;
- margin-bottom: 5px;
- font-size: 0.9em;
- }}
-
- .document-text {{
- white-space: pre-wrap;
- overflow-wrap: break-word;
- }}
-
- /* Helper styles */
- h2 {{
- margin-top: 0;
- font-size: 1.2em;
- color: #333;
- }}
-
- pre {{
- font-size: 0.9em;
- white-space: pre-wrap;
- }}
-
- .stats {{
- color: #666;
- font-size: 0.8em;
- font-weight: normal;
- }}
-
- .keyboard-shortcut {{
- display: inline-block;
- padding: 1px 4px;
- margin: 0 1px;
- border-radius: 3px;
- background-color: #f1f3f5;
- border: 1px solid #ced4da;
- font-family: monospace;
- font-size: 0.9em;
- }}
- </style>
- </head>
- <body>
- <!-- Fixed header -->
- <div class="header">
- <div class="title">{title} <span class="stats">({len(docs)} documents)</span></div>
- <div class="controls">
- <div class="keyboard-controls">
- <span class="keyboard-shortcut">↑</span>/<span class="keyboard-shortcut">↓</span> to navigate
- <span class="keyboard-shortcut">Home</span>/<span class="keyboard-shortcut">End</span>
- </div>
- <button class="toggle-summary" onclick="toggleSummary()">Show Summary</button>
- </div>
- </div>
-
- <!-- Summary panel (initially hidden) -->
- <div id="summary-panel">
- <h2>Summary</h2>
- <pre>{pyhtml.escape(summary)}</pre>
- </div>
-
- <!-- Main content -->
- <div class="container">
- <div id="document-container">
- """
- # Add each document with a unique ID
- for i, doc in enumerate(docs, 1):
- doc_id = doc.get("id", f"unknown_{i}")
- # Get document text, falling back to JSON representation if not available
- doc_text = doc.get("text", json.dumps(doc, indent=2))
- # The first document gets the "selected" class
- selected_class = " selected" if i == 1 else ""
- html += f"""
- <div id="doc-{i}" class="document{selected_class}" tabindex="0">
- <div class="document-id">Document ID: {doc_id}</div>
- <pre class="document-text">{pyhtml.escape(doc_text)}</pre>
- </div>
- """
- # Add JavaScript for keyboard navigation and summary toggle
- html += """
- </div>
- </div>
- <script>
- // Get all documents
- const documents = document.querySelectorAll('.document');
- let selectedIndex = 0; // First document is selected by default
- let summaryVisible = false;
-
- // Function to toggle summary panel
- function toggleSummary() {
- const panel = document.getElementById('summary-panel');
- const button = document.querySelector('.toggle-summary');
-
- if (summaryVisible) {
- panel.style.display = 'none';
- button.textContent = 'Show Summary';
- } else {
- panel.style.display = 'block';
- button.textContent = 'Hide Summary';
- }
-
- summaryVisible = !summaryVisible;
- }
-
- // Function to select a document
- function selectDocument(index) {
- // Validate index
- if (index < 0) index = 0;
- if (index >= documents.length) index = documents.length - 1;
-
- // Store the clamped target index
- const targetIndex = index;
-
- // Remove selected class from all documents
- documents.forEach(doc => doc.classList.remove('selected'));
-
- // Add selected class to the current document
- documents[targetIndex].classList.add('selected');
-
- // Update selected index
- selectedIndex = targetIndex;
-
- // Scroll based on the element's offset from the top of the document,
- // leaving room for the fixed header so it doesn't cover the selection
- const headerHeight = 60; // Fixed header height
- const element = documents[targetIndex];
- const elementPosition = element.offsetTop;
-
- // Scroll the element to the top of the viewport, accounting for header
- window.scrollTo({
- top: elementPosition - headerHeight,
- behavior: 'smooth'
- });
-
- // Focus the selected document for accessibility
- documents[targetIndex].focus();
- }
-
- // Add keyboard event listener to the document
- document.addEventListener('keydown', function(event) {
- // Arrow up
- if (event.key === 'ArrowUp') {
- event.preventDefault();
- selectDocument(selectedIndex - 1);
- }
- // Arrow down
- else if (event.key === 'ArrowDown') {
- event.preventDefault();
- selectDocument(selectedIndex + 1);
- }
- // Home key - go to first document
- else if (event.key === 'Home') {
- event.preventDefault();
- selectDocument(0);
- }
- // End key - go to last document
- else if (event.key === 'End') {
- event.preventDefault();
- selectDocument(documents.length - 1);
- }
- // Escape key - hide summary if visible
- else if (event.key === 'Escape' && summaryVisible) {
- toggleSummary();
- }
- // S key - toggle summary
- else if (event.key === 's' || event.key === 'S') {
- toggleSummary();
- }
- });
-
- // Make documents clickable to select them
- documents.forEach((doc, index) => {
- doc.addEventListener('click', () => {
- selectDocument(index);
- });
- });
-
- // Select the first document when the page loads
- window.addEventListener('load', () => {
- // If there are documents, select the first one
- if (documents.length > 0) {
- selectDocument(0);
- }
- });
- </script>
- </body>
- </html>
- """
- # Create directory if it doesn't exist
- os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
- # Write HTML to file
- with open(output_path, "w", encoding="utf-8") as f:
- f.write(html)
- logger.info(f"Generated HTML report: {output_path}")
- def main():
- global args
- args = parse_args()
- # Set up logging based on arguments
- if args.debug:
- logger.setLevel(logging.DEBUG)
- logger.debug("Debug logging enabled")
- # Set up S3 client if needed
- s3_client = None
- if args.docs_folder.startswith("s3://") or (args.attr_folder and args.attr_folder.startswith("s3://")):
- session = boto3.Session(profile_name=args.aws_profile) if args.aws_profile else boto3.Session()
- s3_client = session.client("s3")
- # Parse the rules
- logger.info(f"Parsing reference rule expression: {args.ref_rule}")
- ref_rule = parse_rule(args.ref_rule)
- logger.info(f"Parsing hypothesis rule expression: {args.hyp_rule}")
- hyp_rule = parse_rule(args.hyp_rule)
- # Generate string representations of the expressions
- ref_rule_str = get_expression_summary(ref_rule)
- hyp_rule_str = get_expression_summary(hyp_rule)
- logger.info(f"Reference rule parsed as: {ref_rule_str}")
- logger.info(f"Hypothesis rule parsed as: {hyp_rule_str}")
- # Determine attributes folder
- attr_folder = get_attributes_folder(args.docs_folder, args.attr_folder)
- logger.info(f"Using attributes folder: {attr_folder}")
- # Load documents and merge with attributes from all subdirectories
- logger.info("Loading documents and merging with all attributes...")
- all_docs = load_documents_and_attributes(args.docs_folder, attr_folder, s3_client, args.recursive)
- # Create output directory if it doesn't exist
- os.makedirs(args.output_dir, exist_ok=True)
- # Use the same documents for both reference and hypothesis evaluation
- # since we've loaded all attributes into each document
- ref_docs = all_docs
- hyp_docs = all_docs
- # Compare the documents
- logger.info("Comparing documents using reference and hypothesis rules...")
- comparison_result = compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule)
- # Get document IDs for each category
- ref_matches = set()
- hyp_matches = set()
- # Create mappings from document IDs to documents
- doc_map = {doc["id"]: doc for doc in all_docs if "id" in doc}
- # Find documents that match the reference and hypothesis rules
- for doc_id, doc in doc_map.items():
- if apply_rule(doc, ref_rule):
- ref_matches.add(doc_id)
- if apply_rule(doc, hyp_rule):
- hyp_matches.add(doc_id)
- # Calculate document sets for each category
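- # (the reference rule acts as ground truth and the hypothesis rule as the prediction,
- # so TP/TN/FP/FN follow the usual confusion-matrix definitions)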
- true_positives_ids = ref_matches.intersection(hyp_matches)
- true_negatives_ids = set(doc_map.keys()) - ref_matches - hyp_matches
- false_positives_ids = hyp_matches - ref_matches
- false_negatives_ids = ref_matches - hyp_matches
- # Create document lists for each category
- true_positives = [doc_map[doc_id] for doc_id in true_positives_ids]
- true_negatives = [doc_map[doc_id] for doc_id in true_negatives_ids]
- false_positives = [doc_map[doc_id] for doc_id in false_positives_ids]
- false_negatives = [doc_map[doc_id] for doc_id in false_negatives_ids]
- # Calculate metrics
- tp = len(true_positives)
- tn = len(true_negatives)
- fp = len(false_positives)
- fn = len(false_negatives)
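- # Precision, recall, and F1 follow their standard definitions; IoU is the
- # Jaccard index of the two match sets, |ref ∩ hyp| / |ref ∪ hyp| = TP / (TP + FP + FN).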
- precision = tp / (tp + fp) if (tp + fp) > 0 else 0
- recall = tp / (tp + fn) if (tp + fn) > 0 else 0
- f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
- iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
- # Prepare overall statistics
- overall_stats = {
- "total_docs": len(doc_map),
- "ref_matches": len(ref_matches),
- "hyp_matches": len(hyp_matches),
- "true_positives": tp,
- "true_negatives": tn,
- "false_positives": fp,
- "false_negatives": fn,
- "precision": precision,
- "recall": recall,
- "f1": f1,
- "iou": iou,
- "ref_rule_stats": comparison_result["ref_rule_stats"],
- "hyp_rule_stats": comparison_result["hyp_rule_stats"],
- }
- # Prepare summary
- summary = f"""Reference Rule: {args.ref_rule}
- Hypothesis Rule: {args.hyp_rule}
- Total Documents: {overall_stats['total_docs']}
- Reference Matches: {overall_stats['ref_matches']}
- Hypothesis Matches: {overall_stats['hyp_matches']}
- True Positives: {tp}
- True Negatives: {tn}
- False Positives: {fp}
- False Negatives: {fn}
- Precision: {precision:.4f}
- Recall: {recall:.4f}
- F1 Score: {f1:.4f}
- IoU: {iou:.4f}
- """
- # Generate HTML reports for each category
- logger.info("Generating HTML reports...")
- # True Positives
- generate_html_report(
- true_positives[:1000],
- "True Positives - Documents matching both Reference and Hypothesis Rules",
- summary,
- os.path.join(args.output_dir, "true_positives.html"),
- )
- # True Negatives
- generate_html_report(
- true_negatives[:1000],
- "True Negatives - Documents not matching either Rule",
- summary,
- os.path.join(args.output_dir, "true_negatives.html"),
- )
- # False Positives
- generate_html_report(
- false_positives[:1000],
- "False Positives - Documents matching Hypothesis but not Reference Rule",
- summary,
- os.path.join(args.output_dir, "false_positives.html"),
- )
- # False Negatives
- generate_html_report(
- false_negatives[:1000],
- "False Negatives - Documents matching Reference but not Hypothesis Rule",
- summary,
- os.path.join(args.output_dir, "false_negatives.html"),
- )
- # Collect numeric attributes and generate CDF plots if not disabled
- attribute_plots_html = ""
- if not args.disable_plots:
- logger.info("Collecting numeric attributes for CDF plots...")
- numeric_attributes = collect_numeric_attributes(all_docs)
- if numeric_attributes:
- logger.info(f"Found {len(numeric_attributes)} numeric attributes suitable for CDF plots")
- # Generate CDF plots HTML with the specified maximum number of plots
- attribute_plots_html = generate_attribute_plots_html(numeric_attributes, args.max_plots)
- else:
- logger.info("No numeric attributes found for CDF plots")
- else:
- logger.info("CDF plot generation disabled by --disable-plots flag")
- # Generate index.html file that links to all reports
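- # (the index embeds the text summary, links to the four category reports, and appends any CDF plots)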
- index_html = f"""<!DOCTYPE html>
- <html>
- <head>
- <meta charset="UTF-8">
- <title>PII Rule Comparison Results</title>
- <style>
- body {{
- font-family: Arial, sans-serif;
- line-height: 1.6;
- padding: 20px;
- max-width: 1000px;
- margin: 0 auto;
- }}
- .summary {{
- background-color: #f8f9fa;
- padding: 15px;
- border-radius: 5px;
- margin-bottom: 20px;
- border-left: 5px solid #007bff;
- }}
- .category {{
- margin-bottom: 20px;
- padding: 15px;
- border-radius: 5px;
- }}
- .true-positives {{
- background-color: #d4edda;
- border-left: 5px solid #28a745;
- }}
- .true-negatives {{
- background-color: #e2e3e5;
- border-left: 5px solid #6c757d;
- }}
- .false-positives {{
- background-color: #f8d7da;
- border-left: 5px solid #dc3545;
- }}
- .false-negatives {{
- background-color: #fff3cd;
- border-left: 5px solid #ffc107;
- }}
- h1 {{
- border-bottom: 2px solid #007bff;
- padding-bottom: 10px;
- color: #333;
- }}
- a {{
- color: #007bff;
- text-decoration: none;
- font-weight: bold;
- }}
- a:hover {{
- text-decoration: underline;
- }}
- .attribute-plots {{
- margin-top: 30px;
- }}
- .plot-container {{
- margin-bottom: 30px;
- padding: 15px;
- background-color: #fff;
- border-radius: 5px;
- box-shadow: 0 2px 5px rgba(0,0,0,0.1);
- }}
- .cdf-plot {{
- max-width: 100%;
- height: auto;
- }}
- h2 {{
- color: #333;
- border-bottom: 1px solid #eee;
- padding-bottom: 10px;
- margin-top: 30px;
- }}
- h3 {{
- color: #007bff;
- }}
- </style>
- </head>
- <body>
- <h1>PII Rule Comparison Results</h1>
- <div class="summary">
- <h2>Summary</h2>
- <pre>{summary}</pre>
- </div>
- <h2>Result Categories</h2>
- <div class="category true-positives">
- <h3>True Positives: {tp}</h3>
- <p>Documents that match both the reference and hypothesis rules.</p>
- <a href="true_positives.html">View True Positives</a>
- </div>
- <div class="category true-negatives">
- <h3>True Negatives: {tn}</h3>
- <p>Documents that don't match either the reference or hypothesis rules.</p>
- <a href="true_negatives.html">View True Negatives</a>
- </div>
- <div class="category false-positives">
- <h3>False Positives: {fp}</h3>
- <p>Documents that match the hypothesis rule but not the reference rule.</p>
- <a href="false_positives.html">View False Positives</a>
- </div>
- <div class="category false-negatives">
- <h3>False Negatives: {fn}</h3>
- <p>Documents that match the reference rule but not the hypothesis rule.</p>
- <a href="false_negatives.html">View False Negatives</a>
- </div>
-
- {attribute_plots_html}
- </body>
- </html>
- """
- with open(os.path.join(args.output_dir, "index.html"), "w", encoding="utf-8") as f:
- f.write(index_html)
- # Print summary
- logger.info("\n--- COMPARISON SUMMARY ---")
- logger.info(f"Documents Folder: {args.docs_folder}")
- logger.info(f"Attributes Folder: {attr_folder}")
- logger.info(f"Reference Rule Expression: {args.ref_rule}")
- logger.info(f" Parsed as: {ref_rule_str}")
- logger.info(f"Hypothesis Rule Expression: {args.hyp_rule}")
- logger.info(f" Parsed as: {hyp_rule_str}")
- logger.info(f"Total Documents: {overall_stats['total_docs']}")
- # Print rule statistics
- logger.info("\n--- RULE MATCH STATISTICS ---")
- logger.info("\nReference Rules:")
- logger.info(format_rule_stats(overall_stats["ref_rule_stats"]))
- logger.info("\nHypothesis Rules:")
- logger.info(format_rule_stats(overall_stats["hyp_rule_stats"]))
- # Print comparison metrics
- logger.info("\n--- COMPARISON METRICS ---")
- logger.info(f"True Positives: {tp}")
- logger.info(f"True Negatives: {tn}")
- logger.info(f"False Positives: {fp}")
- logger.info(f"False Negatives: {fn}")
- logger.info(f"Precision: {precision:.4f}")
- logger.info(f"Recall: {recall:.4f}")
- logger.info(f"F1 Score: {f1:.4f}")
- logger.info(f"IoU: {iou:.4f}")
- # Output all available attributes that have been loaded
- logger.info("\n--- AVAILABLE ATTRIBUTES ---")
- all_attributes = set()
- for doc in all_docs:
- if "attributes" in doc and doc["attributes"]:
- all_attributes.update(doc["attributes"].keys())
- if all_attributes:
- logger.info(f"Found {len(all_attributes)} unique attributes:")
- for attr in sorted(all_attributes):
- logger.info(f" - {attr}")
- else:
- logger.info("No attributes found in any documents.")
- logger.info(f"\nResults saved to: {args.output_dir}/index.html")
- if __name__ == "__main__":
- main()
- # Example commands with actual S3 paths:
- """
- # Example for AI2 OE data with resume detection:
- python scripts/pii_rule_comparison.py \
- --docs-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/documents/ \
- --ref-rule "gpt_4_1_contains_pii:any and not gpt_4_1_is_public_document:all" \
- --hyp-rule "google_gemma-3-4b-it_is_resume_cv:any" \
- --output-dir results/resume_detection \
- --recursive \
- --debug
- # Example for PII detection comparison:
- python scripts/pii_rule_comparison.py \
- --docs-folder s3://allenai-dolma/documents/v1.5 \
- --ref-rule "contains_pii:any" \
- --hyp-rule "(contains_email_addresses:any or contains_phone_numbers:any) and not false_positive:any" \
- --output-dir results/pii_detection \
- --recursive \
- --aws-profile dolma
- # Example with custom attributes folder:
- python scripts/pii_rule_comparison.py \
- --docs-folder s3://bucket/path/documents \
- --attr-folder s3://bucket/custom/location/attributes \
- --ref-rule "gpt_4_1_contains_pii:any" \
- --hyp-rule "custom_model_pii_detection:any" \
- --output-dir results/custom_comparison \
- --recursive
- """