# infinigram_count.py
  1. #!/usr/bin/env python3
  2. import argparse
  3. import json
  4. import random
  5. import re
  6. import time
  7. import boto3
  8. import requests
  9. from tqdm import tqdm
  10. from transformers import AutoTokenizer
# Allowed characters: alphanumeric, space, and basic punctuation ".,!?()".
# Sampled n-grams containing anything else (quotes, unicode, newlines, ...)
# are rejected before being sent to the infini-gram API.
ALLOWED_RE = re.compile(r"^[A-Za-z0-9\.,!?() ]+$")
  13. def get_random_line_from_s3(bucket, key):
  14. """
  15. Reads an S3 object line-by-line and returns a random line using reservoir sampling.
  16. """
  17. s3 = boto3.client("s3")
  18. response = s3.get_object(Bucket=bucket, Key=key)
  19. random_line = None
  20. count = 0
  21. for line in response["Body"].iter_lines():
  22. if not line:
  23. continue
  24. line_str = line.decode("utf-8")
  25. count += 1
  26. if random.randint(1, count) == 1:
  27. random_line = line_str
  28. return random_line
  29. def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3):
  30. """
  31. Sends a count query to the infini-gram API for the given n-gram.
  32. Retries a few times in case of network issues.
  33. """
  34. url = "https://api.infini-gram.io/"
  35. payload = {
  36. "index": index,
  37. "query_type": "count",
  38. "query": ngram,
  39. }
  40. for i in range(retries):
  41. try:
  42. response = requests.post(url, json=payload, timeout=10)
  43. if response.status_code == 200:
  44. result = response.json()
  45. if "count" in result:
  46. return result["count"]
  47. except Exception: # type: ignore
  48. time.sleep(1)
  49. return 0
  50. def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"):
  51. """
  52. Tokenizes the document using the Llama2 tokenizer and samples random n-grams.
  53. Each n-gram is chosen such that:
  54. 1. It starts on a word-split boundary (using the offset mapping and a check on the preceding character).
  55. 2. Its decoded string contains only alphanumeric characters, spaces, and the punctuation marks ".,!?()".
  56. Each valid n-gram is then queried using the infini-gram API.
  57. The function returns the document id, the number of matching n-grams (i.e. API count > 0),
  58. the total number of valid n-grams sampled, and a list of tuples (flag, ngram_string).
  59. """
  60. text = doc.get("text", "")
  61. doc_id = doc.get("id", "Unknown")
  62. # Get tokenized representation with offset mapping to determine word boundaries.
  63. tokenized = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
  64. token_ids = tokenized["input_ids"]
  65. # offsets = tokenized["offset_mapping"]
  66. if len(token_ids) < ngram_size:
  67. return doc_id, 0, 0, []
  68. # Determine valid starting indices based on word-split boundaries.
  69. valid_positions = []
  70. # for i in range(len(token_ids) - ngram_size + 1):
  71. # start_offset = offsets[i][0]
  72. # if start_offset == 0 or (start_offset > 0 and text[start_offset - 1] == " "):
  73. # valid_positions.append(i)
  74. if not valid_positions:
  75. # Fallback: if no valid positions are found, use all possible positions.
  76. valid_positions = list(range(len(token_ids) - ngram_size + 1))
  77. valid_ngram_details = []
  78. attempts = 0
  79. max_attempts = num_samples * 10 # Limit to prevent infinite loops.
  80. while len(valid_ngram_details) < num_samples and attempts < max_attempts:
  81. idx = random.choice(valid_positions)
  82. ngram_token_ids = token_ids[idx : idx + ngram_size]
  83. ngram_str = tokenizer.decode(ngram_token_ids, clean_up_tokenization_spaces=True)
  84. # Only accept n-grams that contain only allowed characters.
  85. if ALLOWED_RE.fullmatch(ngram_str) and len(ngram_str.strip()) > ngram_size * 3:
  86. count = query_infinigram(ngram_str, index=index)
  87. flag = "YES" if count > 0 else "NO"
  88. valid_ngram_details.append((flag, ngram_str))
  89. attempts += 1
  90. match_count = sum(1 for flag, _ in valid_ngram_details if flag == "YES")
  91. sample_count = len(valid_ngram_details)
  92. return doc_id, match_count, sample_count, valid_ngram_details
  93. def main():
  94. parser = argparse.ArgumentParser(description="Infini-gram n-gram matching script with Llama2 tokenization.")
  95. parser.add_argument("N", type=int, help="Number of random .jsonl files to process")
  96. parser.add_argument("s3_path", type=str, help="S3 path to a prefix containing .jsonl files (e.g., s3://my-bucket/my-prefix/)")
  97. parser.add_argument("--index", type=str, default="v4_dolma-v1_7_llama", help="Infini-gram index to use (default: v4_rpj_llama_s4)")
  98. parser.add_argument("--ngram_size", type=int, default=10, help="Size of the n-gram to sample (default: 10)")
  99. parser.add_argument("--num_ngrams", type=int, default=100, help="Number of random n-grams to sample from each document (default: 100)")
  100. args = parser.parse_args()
  101. if not args.s3_path.startswith("s3://"):
  102. print("Error: s3_path must start with 's3://'")
  103. return
  104. path_without_scheme = args.s3_path[5:]
  105. parts = path_without_scheme.split("/", 1)
  106. bucket = parts[0]
  107. prefix = parts[1] if len(parts) > 1 else ""
  108. print("Listing .jsonl files from S3...")
  109. s3 = boto3.client("s3")
  110. response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
  111. files = [obj["Key"] for obj in response.get("Contents", []) if obj["Key"].endswith(".jsonl")]
  112. if not files:
  113. print("No .jsonl files found in the given prefix.")
  114. return
  115. if args.N > len(files):
  116. print(f"Requested {args.N} files, but only found {len(files)}. Processing all available files.")
  117. args.N = len(files)
  118. random_files = random.sample(files, args.N)
  119. print("Loading Llama2 tokenizer...")
  120. tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
  121. total_matches = 0
  122. total_ngrams_sampled = 0
  123. for key in tqdm(random_files, desc="Processing files"):
  124. line = get_random_line_from_s3(bucket, key)
  125. if not line:
  126. print(f"Skipping {key}: No valid lines found.")
  127. continue
  128. try:
  129. doc = json.loads(line)
  130. except Exception as e:
  131. print(f"Error parsing JSON in {key}: {e}")
  132. continue
  133. doc_id, match_count, sample_count, details = process_document(doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index)
  134. # Print per-document n-gram summary
  135. print(f"\nDocument ID: {doc_id}")
  136. for flag, ngram in details:
  137. # Print the flag in a fixed-width field (4 characters) followed by the n-gram representation.
  138. print(f"{flag:4} {repr(ngram)}")
  139. percentage = (match_count / sample_count * 100) if sample_count else 0
  140. print(f"Matched n-grams: {match_count}/{sample_count} ({percentage:.2f}%)")
  141. total_matches += match_count
  142. total_ngrams_sampled += sample_count
  143. overall_percentage = (total_matches / total_ngrams_sampled * 100) if total_ngrams_sampled else 0
  144. print(f"\nTotal matched n-grams: {total_matches}/{total_ngrams_sampled} ({overall_percentage:.2f}%)")
# Standard script entry guard: run main() only when executed directly.
if __name__ == "__main__":
    main()