diff --git a/config.yaml b/config.yaml new file mode 100644 index 000000000..527a2bef2 --- /dev/null +++ b/config.yaml @@ -0,0 +1,8 @@ +model: "gpt-4o-2024-11-20" +toc_check_page_num: 20 +max_page_num_each_node: 10 +max_token_num_each_node: 20000 +if_add_node_id: true +if_add_node_summary: true +if_add_doc_description: false +if_add_node_text: false diff --git a/pageindex/config.py b/pageindex/config.py new file mode 100644 index 000000000..99db3348c --- /dev/null +++ b/pageindex/config.py @@ -0,0 +1,91 @@ +import os +import yaml +from pathlib import Path +from typing import Any, Dict, Optional, Union +from pydantic import BaseModel, Field, ValidationError + +class PageIndexConfig(BaseModel): + """ + Configuration schema for PageIndex. + """ + model: str = Field(default="gpt-4o", description="LLM model to use") + + # PDF Processing + toc_check_page_num: int = Field(default=3, description="Number of pages to check for TOC") + max_page_num_each_node: int = Field(default=5, description="Maximum pages per leaf node") + max_token_num_each_node: int = Field(default=4000, description="Max tokens per node") # Approx + + # Enrichment + if_add_node_id: bool = Field(default=True, description="Add unique ID to nodes") + if_add_node_summary: bool = Field(default=True, description="Generate summary for nodes") + if_add_doc_description: bool = Field(default=True, description="Generate doc-level description") + if_add_node_text: bool = Field(default=True, description="Keep raw text in nodes") + + # Tree Optimization + if_thinning: bool = Field(default=True, description="Merge small adjacent nodes") + thinning_threshold: int = Field(default=500, description="Token threshold for merging") + summary_token_threshold: int = Field(default=200, description="Min tokens required to trigger summary generation") + + # Additional + api_key: Optional[str] = Field(default=None, description="OpenAI API Key (optional, prefers env var)") + + class Config: + arbitrary_types_allowed = True + extra = "forbid" + + +class ConfigLoader: + def __init__(self, default_path: Optional[Union[str, Path]] = None): + if default_path is None: + env_path = os.getenv("PAGEINDEX_CONFIG") + if env_path: + default_path = Path(env_path) + else: + cwd_path = Path.cwd() / "config.yaml" + repo_path = Path(__file__).resolve().parents[1] / "config.yaml" + default_path = cwd_path if cwd_path.exists() else repo_path + + self.default_path = default_path + self._default_dict = self._load_yaml(default_path) if default_path else {} + + @staticmethod + def _load_yaml(path: Optional[Path]) -> Dict[str, Any]: + if not path or not path.exists(): + return {} + try: + with open(path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) or {} + except Exception as e: + print(f"Warning: Failed to load config from {path}: {e}") + return {} + + def load(self, user_opt: Optional[Union[Dict[str, Any], Any]] = None) -> PageIndexConfig: + """ + Load configuration, merging defaults with user overrides and validating via Pydantic. + + Args: + user_opt: Dictionary or object with overrides. + + Returns: + PageIndexConfig: Validated configuration object. + """ + user_dict: Dict[str, Any] = {} + if user_opt is None: + pass + elif hasattr(user_opt, '__dict__'): + # Handle SimpleNamespace or other objects + user_dict = {k: v for k, v in vars(user_opt).items() if v is not None} + elif isinstance(user_opt, dict): + user_dict = {k: v for k, v in user_opt.items() if v is not None} + else: + raise TypeError(f"user_opt must be dict or object, got {type(user_opt)}") + + # Merge defaults and user overrides + # Pydantic accepts kwargs, efficiently merging + merged_data = {**self._default_dict, **user_dict} + + try: + return PageIndexConfig(**merged_data) + except ValidationError as e: + # Re-raise nicely or log + raise ValueError(f"Configuration validation failed: {e}") diff --git a/pageindex/core/__init__.py b/pageindex/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pageindex/core/llm.py b/pageindex/core/llm.py new file mode 100644 index 000000000..264788c76 --- /dev/null +++ b/pageindex/core/llm.py @@ -0,0 +1,245 @@ +import tiktoken +import openai +import logging +import os +import time +import json +import asyncio +from typing import Optional, List, Dict, Any, Union, Tuple +from dotenv import load_dotenv + +load_dotenv() + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("CHATGPT_API_KEY") + +def count_tokens(text: Optional[str], model: str = "gpt-4o") -> int: + """ + Count the number of tokens in a text string using the specified model's encoding. + + Args: + text (Optional[str]): The text to encode. If None, returns 0. + model (str): The model name to use for encoding. Defaults to "gpt-4o". + + Returns: + int: The number of tokens. + """ + if not text: + return 0 + try: + enc = tiktoken.encoding_for_model(model) + except KeyError: + # Fallback for newer or unknown models + enc = tiktoken.get_encoding("cl100k_base") + tokens = enc.encode(text) + return len(tokens) + +def ChatGPT_API_with_finish_reason( + model: str, + prompt: str, + api_key: Optional[str] = OPENAI_API_KEY, + chat_history: Optional[List[Dict[str, str]]] = None +) -> Tuple[str, str]: + """ + Call OpenAI Chat Completion API and return content along with finish reason. + + Args: + model (str): The model name (e.g., "gpt-4o"). + prompt (str): The user prompt. + api_key (Optional[str]): OpenAI API key. Defaults to env var. + chat_history (Optional[List[Dict[str, str]]]): Previous messages for context. + + Returns: + Tuple[str, str]: A tuple containing (content, finish_reason). + Returns ("Error", "error") if max retries reached. + """ + max_retries = 10 + if not api_key: + logging.error("No API key provided.") + return "Error", "missing_api_key" + + client = openai.OpenAI(api_key=api_key) + for i in range(max_retries): + try: + if chat_history: + messages = chat_history.copy() # Avoid modifying original list if passed by ref (shallow copy enough for append) + messages.append({"role": "user", "content": prompt}) + else: + messages = [{"role": "user", "content": prompt}] + + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=0, + ) + + content = response.choices[0].message.content or "" + finish_reason = response.choices[0].finish_reason + + if finish_reason == "length": + return content, "max_output_reached" + else: + return content, "finished" + + except Exception as e: + print('************* Retrying *************') + logging.error(f"Error: {e}") + if i < max_retries - 1: + time.sleep(1) + else: + logging.error('Max retries reached for prompt: ' + prompt[:50] + '...') + return "Error", "error" + return "Error", "max_retries" + +def ChatGPT_API( + model: str, + prompt: str, + api_key: Optional[str] = OPENAI_API_KEY, + chat_history: Optional[List[Dict[str, str]]] = None +) -> str: + """ + Call OpenAI Chat Completion API and return the content string. + + Args: + model (str): The model name. + prompt (str): The user prompt. + api_key (Optional[str]): OpenAI API key. + chat_history (Optional[List[Dict[str, str]]]): Previous messages. + + Returns: + str: The response content, or "Error" if failed. + """ + max_retries = 10 + if not api_key: + logging.error("No API key provided.") + return "Error" + + client = openai.OpenAI(api_key=api_key) + for i in range(max_retries): + try: + if chat_history: + messages = chat_history.copy() + messages.append({"role": "user", "content": prompt}) + else: + messages = [{"role": "user", "content": prompt}] + + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=0, + ) + + return response.choices[0].message.content or "" + except Exception as e: + print('************* Retrying *************') + logging.error(f"Error: {e}") + if i < max_retries - 1: + time.sleep(1) + else: + logging.error('Max retries reached for prompt: ' + prompt[:50] + '...') + return "Error" + return "Error" + +async def ChatGPT_API_async( + model: str, + prompt: str, + api_key: Optional[str] = OPENAI_API_KEY +) -> str: + """ + Asynchronously call OpenAI Chat Completion API. + + Args: + model (str): The model name. + prompt (str): The user prompt. + api_key (Optional[str]): OpenAI API key. + + Returns: + str: The response content, or "Error" if failed. + """ + max_retries = 10 + if not api_key: + logging.error("No API key provided.") + return "Error" + + messages = [{"role": "user", "content": prompt}] + for i in range(max_retries): + try: + async with openai.AsyncOpenAI(api_key=api_key) as client: + response = await client.chat.completions.create( + model=model, + messages=messages, + temperature=0, + ) + return response.choices[0].message.content or "" + except Exception as e: + print('************* Retrying *************') + logging.error(f"Error: {e}") + if i < max_retries - 1: + await asyncio.sleep(1) + else: + logging.error('Max retries reached for prompt: ' + prompt[:50] + '...') + return "Error" + return "Error" + +def get_json_content(response: str) -> str: + """ + Extract content inside markdown JSON code blocks. + + Args: + response (str): The full raw response string. + + Returns: + str: The extracted JSON string stripped of markers. + """ + start_idx = response.find("```json") + if start_idx != -1: + start_idx += 7 + response = response[start_idx:] + + end_idx = response.rfind("```") + if end_idx != -1: + response = response[:end_idx] + + json_content = response.strip() + return json_content + +def extract_json(content: str) -> Union[Dict[str, Any], List[Any]]: + """ + Robustly extract and parse JSON from a string, handling common LLM formatting issues. + + Args: + content (str): The text containing JSON. + + Returns: + Union[Dict, List]: The parsed JSON object or empty dict/list on failure. + """ + try: + # First, try to extract JSON enclosed within ```json and ``` + start_idx = content.find("```json") + if start_idx != -1: + start_idx += 7 # Adjust index to start after the delimiter + end_idx = content.rfind("```") + json_content = content[start_idx:end_idx].strip() + else: + # If no delimiters, assume entire content could be JSON + json_content = content.strip() + + # Clean up common issues that might cause parsing errors + json_content = json_content.replace('None', 'null') # Replace Python None with JSON null + json_content = json_content.replace('\n', ' ').replace('\r', ' ') # Remove newlines + json_content = ' '.join(json_content.split()) # Normalize whitespace + + # Attempt to parse and return the JSON object + return json.loads(json_content) + except json.JSONDecodeError as e: + logging.error(f"Failed to extract JSON: {e}") + # Try to clean up the content further if initial parsing fails + try: + # Remove any trailing commas before closing brackets/braces + json_content = json_content.replace(',]', ']').replace(',}', '}') + return json.loads(json_content) + except: + logging.error("Failed to parse JSON even after cleanup") + return {} + except Exception as e: + logging.error(f"Unexpected error while extracting JSON: {e}") + return {} diff --git a/pageindex/core/logging.py b/pageindex/core/logging.py new file mode 100644 index 000000000..9e7cd0be3 --- /dev/null +++ b/pageindex/core/logging.py @@ -0,0 +1,65 @@ +import os +import json +from datetime import datetime +from typing import Any, Dict, Optional, Union, List +from .pdf import get_pdf_name + +class JsonLogger: + """ + A simple JSON-based logger that writes distinct log files for each run session. + """ + def __init__(self, file_path: Union[str, Any]): + """ + Initialize the logger. + + Args: + file_path (Union[str, Any]): The source file path (usually PDF) to derive the log filename from. + """ + # Extract PDF name for logger name + pdf_name = get_pdf_name(file_path) + + current_time = datetime.now().strftime("%Y%m%d_%H%M%S") + self.filename = f"{pdf_name}_{current_time}.json" + os.makedirs("./logs", exist_ok=True) + # Initialize empty list to store all messages + self.log_data: List[Dict[str, Any]] = [] + + def log(self, level: str, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None: + """ + Log a message. + + Args: + level (str): Log level (INFO, ERROR, etc.) + message (Union[str, Dict]): The message content. + """ + entry: Dict[str, Any] = {} + if isinstance(message, dict): + entry = message + else: + entry = {'message': message} + + entry['level'] = level + entry['timestamp'] = datetime.now().isoformat() + entry.update(kwargs) + + self.log_data.append(entry) + + # Write entire log data to file (inefficient for large logs, but simple for now) + with open(self._filepath(), "w", encoding='utf-8') as f: + json.dump(self.log_data, f, indent=2, ensure_ascii=False) + + def info(self, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None: + self.log("INFO", message, **kwargs) + + def error(self, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None: + self.log("ERROR", message, **kwargs) + + def debug(self, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None: + self.log("DEBUG", message, **kwargs) + + def exception(self, message: Union[str, Dict[str, Any]], **kwargs: Any) -> None: + kwargs["exception"] = True + self.log("ERROR", message, **kwargs) + + def _filepath(self) -> str: + return os.path.join("logs", self.filename) diff --git a/pageindex/core/pdf.py b/pageindex/core/pdf.py new file mode 100644 index 000000000..855c08566 --- /dev/null +++ b/pageindex/core/pdf.py @@ -0,0 +1,207 @@ +import PyPDF2 +import pymupdf +import re +import os +import tiktoken +from io import BytesIO +from typing import List, Tuple, Union, Optional +from .llm import count_tokens + +def extract_text_from_pdf(pdf_path: str) -> str: + """ + Extract all text from a PDF file using PyPDF2. + + Args: + pdf_path (str): Path to the PDF file. + + Returns: + str: Concatenated text from all pages. + """ + pdf_reader = PyPDF2.PdfReader(pdf_path) + text = "" + for page_num in range(len(pdf_reader.pages)): + page = pdf_reader.pages[page_num] + text += page.extract_text() + return text + +def get_pdf_title(pdf_path: Union[str, BytesIO]) -> str: + """ + Extract the title from PDF metadata. + + Args: + pdf_path (Union[str, BytesIO]): Path to PDF or BytesIO object. + + Returns: + str: Title of the PDF or 'Untitled'. + """ + pdf_reader = PyPDF2.PdfReader(pdf_path) + meta = pdf_reader.metadata + title = meta.title if meta and meta.title else 'Untitled' + return title + +def get_text_of_pages(pdf_path: str, start_page: int, end_page: int, tag: bool = True) -> str: + """ + Get text from a specific range of pages in a PDF. + + Args: + pdf_path (str): Path to the PDF file. + start_page (int): Start page number (1-based). + end_page (int): End page number (1-based). + tag (bool): If True, wraps page text in ... tags. + + Returns: + str: Extracted text. + """ + pdf_reader = PyPDF2.PdfReader(pdf_path) + text = "" + for page_num in range(start_page-1, end_page): + if page_num < len(pdf_reader.pages): + page = pdf_reader.pages[page_num] + page_text = page.extract_text() + if tag: + text += f"\n{page_text}\n\n" + else: + text += page_text + return text + +def get_first_start_page_from_text(text: str) -> int: + """ + Extract the first page index tag found in text. + + Args: + text (str): Text containing tags. + + Returns: + int: Page number or -1 if not found. + """ + start_page = -1 + start_page_match = re.search(r'', text) + if start_page_match: + start_page = int(start_page_match.group(1)) + return start_page + +def get_last_start_page_from_text(text: str) -> int: + """ + Extract the last page index tag found in text. + + Args: + text (str): Text containing tags. + + Returns: + int: Page number or -1 if not found. + """ + start_page = -1 + start_page_matches = re.finditer(r'', text) + matches_list = list(start_page_matches) + if matches_list: + start_page = int(matches_list[-1].group(1)) + return start_page + + +def sanitize_filename(filename: str, replacement: str = '-') -> str: + """Replace illegal characters in filename.""" + return filename.replace('/', replacement) + +def get_pdf_name(pdf_path: Union[str, BytesIO]) -> str: + """ + Get a sanitized name for the PDF file. + + Args: + pdf_path (Union[str, BytesIO]): Path or file object. + + Returns: + str: Filename or logical title. + """ + pdf_name = "Untitled.pdf" + if isinstance(pdf_path, str): + pdf_name = os.path.basename(pdf_path) + elif isinstance(pdf_path, BytesIO): + pdf_reader = PyPDF2.PdfReader(pdf_path) + meta = pdf_reader.metadata + if meta and meta.title: + pdf_name = meta.title + pdf_name = sanitize_filename(pdf_name) + return pdf_name + + +def get_page_tokens( + pdf_path: Union[str, BytesIO], + model: str = "gpt-4o-2024-11-20", + pdf_parser: str = "PyPDF2" +) -> List[Tuple[str, int]]: + """ + Extract text and token counts for each page. + + Args: + pdf_path (Union[str, BytesIO]): Path to PDF. + model (str): Model name for token counting. + pdf_parser (str): "PyPDF2" or "PyMuPDF". + + Returns: + List[Tuple[str, int]]: List of (page_text, token_count). + """ + enc = tiktoken.encoding_for_model(model) + if pdf_parser == "PyPDF2": + pdf_reader = PyPDF2.PdfReader(pdf_path) + page_list = [] + for page_num in range(len(pdf_reader.pages)): + page = pdf_reader.pages[page_num] + page_text = page.extract_text() + token_length = len(enc.encode(page_text)) + page_list.append((page_text, token_length)) + return page_list + elif pdf_parser == "PyMuPDF": + if isinstance(pdf_path, BytesIO): + pdf_stream = pdf_path + doc = pymupdf.open(stream=pdf_stream, filetype="pdf") + elif isinstance(pdf_path, str) and os.path.isfile(pdf_path) and pdf_path.lower().endswith(".pdf"): + doc = pymupdf.open(pdf_path) + else: + raise ValueError(f"Invalid pdf path for PyMuPDF: {pdf_path}") + + page_list = [] + for page in doc: + page_text = page.get_text() + token_length = len(enc.encode(page_text)) + page_list.append((page_text, token_length)) + return page_list + else: + raise ValueError(f"Unsupported PDF parser: {pdf_parser}") + + + +def get_text_of_pdf_pages(pdf_pages: List[Tuple[str, int]], start_page: int, end_page: int) -> str: + """ + Combine text from a list of page tuples [1-based range]. + + Args: + pdf_pages (List[Tuple[str, int]]): Output from get_page_tokens. + start_page (int): Start page (1-based). + end_page (int): End page (1-based, inclusive). + + Returns: + str: Combined text. + """ + text = "" + # Safe indexing + total_pages = len(pdf_pages) + for page_num in range(start_page-1, end_page): + if 0 <= page_num < total_pages: + text += pdf_pages[page_num][0] + return text + +def get_text_of_pdf_pages_with_labels(pdf_pages: List[Tuple[str, int]], start_page: int, end_page: int) -> str: + """ + Combine text from pages with tags. + """ + text = "" + total_pages = len(pdf_pages) + for page_num in range(start_page-1, end_page): + if 0 <= page_num < total_pages: + text += f"\n{pdf_pages[page_num][0]}\n\n" + return text + +def get_number_of_pages(pdf_path: Union[str, BytesIO]) -> int: + """Get total page count of a PDF.""" + pdf_reader = PyPDF2.PdfReader(pdf_path) + return len(pdf_reader.pages) diff --git a/pageindex/core/tree.py b/pageindex/core/tree.py new file mode 100644 index 000000000..762319900 --- /dev/null +++ b/pageindex/core/tree.py @@ -0,0 +1,545 @@ +import copy +import json +import asyncio +from typing import List, Dict, Any, Optional, Union +from .llm import count_tokens, ChatGPT_API, ChatGPT_API_async + +# Type aliases for tree structures +Node = Dict[str, Any] +Tree = List[Node] +Structure = Union[Node, List[Any]] # Recursive definition limitation in MyPy, using Any for nested + +def write_node_id(data: Structure, node_id: int = 0) -> int: + """ + Recursively assign sequential node_ids to a tree structure. + + Args: + data (Structure): The tree or node to process. + node_id (int): The starting ID. + + Returns: + int: The next available node_id. + """ + if isinstance(data, dict): + data['node_id'] = str(node_id).zfill(4) + node_id += 1 + for key in list(data.keys()): + if 'nodes' in key: + node_id = write_node_id(data[key], node_id) + elif isinstance(data, list): + for index in range(len(data)): + node_id = write_node_id(data[index], node_id) + return node_id + +def get_nodes(structure: Structure) -> List[Node]: + """ + Flatten the tree into a list of nodes, excluding their children 'nodes' list from the copy. + + Args: + structure (Structure): The tree structure. + + Returns: + List[Node]: A flat list of node dictionaries (without 'nodes' key). + """ + if isinstance(structure, dict): + structure_node = copy.deepcopy(structure) + structure_node.pop('nodes', None) + nodes = [structure_node] + for key in list(structure.keys()): + if 'nodes' in key: + nodes.extend(get_nodes(structure[key])) + return nodes + elif isinstance(structure, list): + nodes = [] + for item in structure: + nodes.extend(get_nodes(item)) + return nodes + return [] + +def structure_to_list(structure: Structure) -> List[Node]: + """ + Flatten the tree into a list of references to all nodes (including containers). + + Args: + structure (Structure): The tree structure. + + Returns: + List[Node]: Flat list of all nodes. + """ + if isinstance(structure, dict): + nodes = [] + nodes.append(structure) + if 'nodes' in structure: + nodes.extend(structure_to_list(structure['nodes'])) + return nodes + elif isinstance(structure, list): + nodes = [] + for item in structure: + nodes.extend(structure_to_list(item)) + return nodes + return [] + + +def get_leaf_nodes(structure: Structure) -> List[Node]: + """ + Get all leaf nodes (nodes with no children). + + Args: + structure (Structure): The tree structure. + + Returns: + List[Node]: List of leaf node copies (without 'nodes' key). + """ + if isinstance(structure, dict): + if not structure.get('nodes'): + structure_node = copy.deepcopy(structure) + structure_node.pop('nodes', None) + return [structure_node] + else: + leaf_nodes = [] + for key in list(structure.keys()): + if 'nodes' in key: + leaf_nodes.extend(get_leaf_nodes(structure[key])) + return leaf_nodes + elif isinstance(structure, list): + leaf_nodes = [] + for item in structure: + leaf_nodes.extend(get_leaf_nodes(item)) + return leaf_nodes + return [] + +def is_leaf_node(data: Structure, node_id: str) -> bool: + """ + Check if a node with specific ID is a leaf node. + + Args: + data (Structure): The tree structure. + node_id (str): The ID to check. + + Returns: + bool: True if node exists and has no children. + """ + # Helper function to find the node by its node_id + def find_node(data: Structure, node_id: str) -> Optional[Node]: + if isinstance(data, dict): + if data.get('node_id') == node_id: + return data + for key in data.keys(): + if 'nodes' in key: + result = find_node(data[key], node_id) + if result: + return result + elif isinstance(data, list): + for item in data: + result = find_node(item, node_id) + if result: + return result + return None + + # Find the node with the given node_id + node = find_node(data, node_id) + + # Check if the node is a leaf node + if node and not node.get('nodes'): + return True + return False + +def get_last_node(structure: List[Any]) -> Any: + """Get the last element of a list structure.""" + return structure[-1] + +def list_to_tree(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Convert a flat list of nodes with dot-notation 'structure' keys (e.g., '1.1') + into a nested tree. + + Args: + data (List[Dict[str, Any]]): List of node dictionaries. + + Returns: + List[Dict[str, Any]]: The nested tree structure. + """ + def get_parent_structure(structure: Optional[str]) -> Optional[str]: + """Helper function to get the parent structure code""" + if not structure: + return None + parts = str(structure).split('.') + return '.'.join(parts[:-1]) if len(parts) > 1 else None + + # First pass: Create nodes and track parent-child relationships + nodes: Dict[str, Dict[str, Any]] = {} + root_nodes: List[Dict[str, Any]] = [] + + for item in data: + structure = str(item.get('structure', '')) + node = { + 'title': item.get('title'), + 'start_index': item.get('start_index'), + 'end_index': item.get('end_index'), + 'nodes': [] + } + + nodes[structure] = node + + # Find parent + parent_structure = get_parent_structure(structure) + + if parent_structure: + # Add as child to parent if parent exists + if parent_structure in nodes: + nodes[parent_structure]['nodes'].append(node) + else: + root_nodes.append(node) + else: + # No parent, this is a root node + root_nodes.append(node) + + # Helper function to clean empty children arrays + def clean_node(node: Dict[str, Any]) -> Dict[str, Any]: + if not node['nodes']: + del node['nodes'] + else: + for child in node['nodes']: + clean_node(child) + return node + + # Clean and return the tree + return [clean_node(node) for node in root_nodes] + +def add_preface_if_needed(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Inject a Preface node if the first node starts after page 1. + """ + if not isinstance(data, list) or not data: + return data + + if data[0].get('physical_index') is not None and data[0]['physical_index'] > 1: + preface_node = { + "structure": "0", + "title": "Preface", + "physical_index": 1, + } + data.insert(0, preface_node) + return data + + +def post_processing(structure: List[Dict[str, Any]], end_physical_index: int) -> Union[List[Dict[str, Any]], List[Any]]: + """ + Calculate start/end indices based on 'physical_index' and convert to tree if possible. + + Args: + structure: List of flat nodes. + end_physical_index: Total pages or end index. + + Returns: + Tree or List. + """ + # First convert page_number to start_index in flat list + for i, item in enumerate(structure): + item['start_index'] = item.get('physical_index') + if i < len(structure) - 1: + if structure[i + 1].get('appear_start') == 'yes': + item['end_index'] = structure[i + 1]['physical_index']-1 + else: + item['end_index'] = structure[i + 1]['physical_index'] + else: + item['end_index'] = end_physical_index + tree = list_to_tree(structure) + if len(tree)!=0: + return tree + else: + ### remove appear_start + for node in structure: + node.pop('appear_start', None) + node.pop('physical_index', None) + return structure + +def clean_structure_post(data: Structure) -> Structure: + """Recursively clean internal processing fields from structure.""" + if isinstance(data, dict): + data.pop('page_number', None) + data.pop('start_index', None) + data.pop('end_index', None) + if 'nodes' in data: + clean_structure_post(data['nodes']) + elif isinstance(data, list): + for section in data: + clean_structure_post(section) + return data + +def remove_fields(data: Structure, fields: List[str] = ['text']) -> Structure: + """Recursively remove specified fields from the structure.""" + if isinstance(data, dict): + return {k: remove_fields(v, fields) + for k, v in data.items() if k not in fields} + elif isinstance(data, list): + return [remove_fields(item, fields) for item in data] + return data + +def print_toc(tree: List[Dict[str, Any]], indent: int = 0) -> None: + """Print Table of Contents to stdout.""" + for node in tree: + print(' ' * indent + str(node.get('title', ''))) + if node.get('nodes'): + print_toc(node['nodes'], indent + 1) + +def print_json(data: Any, max_len: int = 40, indent: int = 2) -> None: + """Pretty print JSON with truncated strings.""" + def simplify_data(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: simplify_data(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [simplify_data(item) for item in obj] + elif isinstance(obj, str) and len(obj) > max_len: + return obj[:max_len] + '...' + else: + return obj + + simplified = simplify_data(data) + print(json.dumps(simplified, indent=indent, ensure_ascii=False)) + + +def print_wrapped(text: Any, width: int = 100) -> None: + """Print text wrapped to specified width.""" + import textwrap + + if text is None: + return + for line in str(text).splitlines(): + if not line.strip(): + print() + continue + for wrapped in textwrap.wrap(line, width=width): + print(wrapped) + + +def print_tree(tree: List[Dict[str, Any]], exclude_fields: Optional[List[str]] = None, indent: int = 0, max_summary_len: int = 120) -> None: + """Print tree structure with node IDs and summaries.""" + if exclude_fields: + # Cast to Any to satisfy mypy since remove_fields returns Structure + tree = remove_fields(tree, fields=exclude_fields) # type: ignore + + for node in tree: + node_id = node.get('node_id', '') + title = node.get('title', '') + start = node.get('start_index') + end = node.get('end_index') + summary = node.get('summary') or node.get('prefix_summary') + page_range = None + if start is not None and end is not None: + page_range = start if start == end else f"{start}-{end}" + line = f"{node_id}\t{page_range}\t{title}" if page_range else f"{node_id}\t{title}" + if summary: + short_summary = summary if len(summary) <= max_summary_len else summary[:max_summary_len] + '...' + line = f"{line} — {short_summary}" + print(' ' * indent + line) + if node.get('nodes'): + print_tree(node['nodes'], exclude_fields=exclude_fields, indent=indent + 1, max_summary_len=max_summary_len) + + +def create_node_mapping(tree: List[Dict[str, Any]], include_page_ranges: bool = False, max_page: Optional[int] = None) -> Dict[str, Any]: + """Create a dictionary mapping node_ids to nodes.""" + mapping = {} + + def clamp_page(value: Optional[int]) -> Optional[int]: + if value is None or max_page is None: + return value + return max(1, min(value, max_page)) + + def visit(node: Dict[str, Any]) -> None: + node_id = node.get('node_id') + if node_id: + if include_page_ranges: + start = clamp_page(node.get('start_index')) + end = clamp_page(node.get('end_index')) + mapping[node_id] = { + 'node': node, + 'start_index': start, + 'end_index': end, + } + else: + mapping[node_id] = node + for child in node.get('nodes') or []: + visit(child) + + for root in tree: + visit(root) + + return mapping + + +def remove_structure_text(data: Structure) -> Structure: + """Recursively remove 'text' field.""" + if isinstance(data, dict): + data.pop('text', None) + if 'nodes' in data: + remove_structure_text(data['nodes']) + elif isinstance(data, list): + for item in data: + remove_structure_text(item) + return data + + +def check_token_limit(structure: Structure, limit: int = 110000) -> None: + """Check if any node exceeds the token limit.""" + flat_list = structure_to_list(structure) + for node in flat_list: + text = node.get('text', '') + num_tokens = count_tokens(text, model='gpt-4o') + if num_tokens > limit: + print(f"Node ID: {node.get('node_id')} has {num_tokens} tokens") + print("Start Index:", node.get('start_index')) + print("End Index:", node.get('end_index')) + print("Title:", node.get('title')) + print("\n") + + +def convert_physical_index_to_int(data: Any) -> Any: + """Convert physical_index strings (e.g., '') to integers inplace.""" + if isinstance(data, list): + for i in range(len(data)): + # Check if item is a dictionary and has 'physical_index' key + if isinstance(data[i], dict) and 'physical_index' in data[i]: + if isinstance(data[i]['physical_index'], str): + if data[i]['physical_index'].startswith('').strip()) + elif data[i]['physical_index'].startswith('physical_index_'): + data[i]['physical_index'] = int(data[i]['physical_index'].split('_')[-1].strip()) + elif isinstance(data, str): + if data.startswith('').strip()) + elif data.startswith('physical_index_'): + data = int(data.split('_')[-1].strip()) + # Check data is int + if isinstance(data, int): + return data + else: + return None + return data + + +def convert_page_to_int(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert 'page' field to int if possible.""" + for item in data: + if 'page' in item and isinstance(item['page'], str): + try: + item['page'] = int(item['page']) + except ValueError: + # Keep original value if conversion fails + pass + return data + +from .pdf import get_text_of_pdf_pages, get_text_of_pdf_pages_with_labels + +def add_node_text(node: Structure, pdf_pages: List[Any]) -> None: + """Recursively add text to nodes from pdf_pages list based on page range.""" + if isinstance(node, dict): + start_page = node.get('start_index') + end_page = node.get('end_index') + if start_page is not None and end_page is not None: + node['text'] = get_text_of_pdf_pages(pdf_pages, start_page, end_page) + if 'nodes' in node: + add_node_text(node['nodes'], pdf_pages) + elif isinstance(node, list): + for index in range(len(node)): + add_node_text(node[index], pdf_pages) + return + + +def add_node_text_with_labels(node: Structure, pdf_pages: List[Any]) -> None: + """Recursively add text with physical index labels.""" + if isinstance(node, dict): + start_page = node.get('start_index') + end_page = node.get('end_index') + if start_page is not None and end_page is not None: + node['text'] = get_text_of_pdf_pages_with_labels(pdf_pages, start_page, end_page) + if 'nodes' in node: + add_node_text_with_labels(node['nodes'], pdf_pages) + elif isinstance(node, list): + for index in range(len(node)): + add_node_text_with_labels(node[index], pdf_pages) + return + + +async def generate_node_summary(node: Dict[str, Any], model: Optional[str] = None) -> str: + """Generate summary for a node using LLM.""" + # Ensure text exists + text = node.get('text', '') + prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document. + + Partial Document Text: {text} + + Directly return the description, do not include any other text. + """ + # Note: model name should ideally be passed, default handled in API + response = await ChatGPT_API_async(model or "gpt-4o", prompt) + return response + + +async def generate_summaries_for_structure(structure: Structure, model: Optional[str] = None) -> Structure: + """Generate summaries for all nodes in the structure.""" + nodes = structure_to_list(structure) + tasks = [generate_node_summary(node, model=model) for node in nodes] + summaries = await asyncio.gather(*tasks) + + for node, summary in zip(nodes, summaries): + node['summary'] = summary + return structure + + +def create_clean_structure_for_description(structure: Structure) -> Structure: + """ + Create a clean structure for document description generation, + excluding unnecessary fields like 'text'. + """ + if isinstance(structure, dict): + clean_node: Dict[str, Any] = {} + # Only include essential fields for description + for key in ['title', 'node_id', 'summary', 'prefix_summary']: + if key in structure: + clean_node[key] = structure[key] + + # Recursively process child nodes + if 'nodes' in structure and structure['nodes']: + clean_node['nodes'] = create_clean_structure_for_description(structure['nodes']) + + return clean_node + elif isinstance(structure, list): + return [create_clean_structure_for_description(item) for item in structure] # type: ignore + else: + return structure + + +def generate_doc_description(structure: Structure, model: str = "gpt-4o") -> str: + """Generate a one-sentence description for the entire document structure.""" + prompt = f"""Your are an expert in generating descriptions for a document. + You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents. + + Document Structure: {structure} + + Directly return the description, do not include any other text. + """ + response = ChatGPT_API(model, prompt) + return response + + +def reorder_dict(data: Dict[str, Any], key_order: List[str]) -> Dict[str, Any]: + """Reorder dictionary keys.""" + if not key_order: + return data + return {key: data[key] for key in key_order if key in data} + + +def format_structure(structure: Structure, order: Optional[List[str]] = None) -> Structure: + """Recursively format and reorder keys in the structure.""" + if not order: + return structure + if isinstance(structure, dict): + if 'nodes' in structure: + structure['nodes'] = format_structure(structure['nodes'], order) + if not structure.get('nodes'): + structure.pop('nodes', None) + structure = reorder_dict(structure, order) + elif isinstance(structure, list): + structure = [format_structure(item, order) for item in structure] # type: ignore + return structure diff --git a/pageindex/page_index.py b/pageindex/page_index.py index 719255463..a79e6b33c 100644 --- a/pageindex/page_index.py +++ b/pageindex/page_index.py @@ -4,7 +4,11 @@ import math import random import re -from .utils import * +from .core.llm import ChatGPT_API, ChatGPT_API_with_finish_reason, ChatGPT_API_async, extract_json, count_tokens, get_json_content +from .core.tree import convert_page_to_int, convert_physical_index_to_int, add_node_text, add_node_text_with_labels +from .core.pdf import get_number_of_pages, get_pdf_title, get_page_tokens, get_text_of_pages, get_first_start_page_from_text, get_last_start_page_from_text +from .core.logging import JsonLogger +from pageindex.config import ConfigLoader import os from concurrent.futures import ThreadPoolExecutor, as_completed @@ -36,7 +40,7 @@ async def check_title_appearance(item, page_list, start_index=1, model=None): }} Directly return the final JSON structure. Do not output anything else.""" - response = await llm_acompletion(model=model, prompt=prompt) + response = await ChatGPT_API_async(model=model, prompt=prompt) response = extract_json(response) if 'answer' in response: answer = response['answer'] @@ -64,7 +68,7 @@ async def check_title_appearance_in_start(title, page_text, model=None, logger=N }} Directly return the final JSON structure. Do not output anything else.""" - response = await llm_acompletion(model=model, prompt=prompt) + response = await ChatGPT_API_async(model=model, prompt=prompt) response = extract_json(response) if logger: logger.info(f"Response: {response}") @@ -116,7 +120,7 @@ def toc_detector_single_page(content, model=None): Directly return the final JSON structure. Do not output anything else. Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents.""" - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) # print('response', response) json_content = extract_json(response) return json_content['toc_detected'] @@ -135,7 +139,7 @@ def check_if_toc_extraction_is_complete(content, toc, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) json_content = extract_json(response) return json_content['completed'] @@ -153,7 +157,7 @@ def check_if_toc_transformation_is_complete(content, toc, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) json_content = extract_json(response) return json_content['completed'] @@ -165,7 +169,7 @@ def extract_toc_content(content, model=None): Directly return the full table of contents content. Do not output anything else.""" - response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if_complete = check_if_toc_transformation_is_complete(content, response, model) if if_complete == "yes" and finish_reason == "finished": @@ -176,24 +180,23 @@ def extract_toc_content(content, model=None): {"role": "assistant", "content": response}, ] prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure""" - new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) + new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history) response = response + new_response if_complete = check_if_toc_transformation_is_complete(content, response, model) - attempt = 0 - max_attempts = 5 - + attempts = 0 + max_attempts = 10 while not (if_complete == "yes" and finish_reason == "finished"): - attempt += 1 - if attempt > max_attempts: + attempts += 1 + if attempts > max_attempts: raise Exception('Failed to complete table of contents after maximum retries') - + chat_history = [ - {"role": "user", "content": prompt}, - {"role": "assistant", "content": response}, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": response}, ] prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure""" - new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) + new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history) response = response + new_response if_complete = check_if_toc_transformation_is_complete(content, response, model) @@ -215,7 +218,7 @@ def detect_page_index(toc_content, model=None): }} Directly return the final JSON structure. Do not output anything else.""" - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) json_content = extract_json(response) return json_content['page_index_given_in_toc'] @@ -264,7 +267,7 @@ def toc_index_extractor(toc, content, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = toc_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content - response = llm_completion(model=model, prompt=prompt) + response = ChatGPT_API(model=model, prompt=prompt) json_content = extract_json(response) return json_content @@ -292,7 +295,7 @@ def toc_transformer(toc_content, model=None): Directly return the final JSON structure, do not output anything else. """ prompt = init_prompt + '\n Given table of contents\n:' + toc_content - last_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) if if_complete == "yes" and finish_reason == "finished": last_complete = extract_json(last_complete) @@ -300,12 +303,13 @@ def toc_transformer(toc_content, model=None): return cleaned_response last_complete = get_json_content(last_complete) - attempt = 0 - max_attempts = 5 + attempts = 0 + max_attempts = 10 while not (if_complete == "yes" and finish_reason == "finished"): - attempt += 1 - if attempt > max_attempts: - raise Exception('Failed to complete toc transformation after maximum retries') + attempts += 1 + if attempts > max_attempts: + raise Exception('Failed to complete table of contents after maximum retries') + position = last_complete.rfind('}') if position != -1: last_complete = last_complete[:position+2] @@ -321,11 +325,17 @@ def toc_transformer(toc_content, model=None): Please continue the json structure, directly output the remaining part of the json structure.""" - new_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + new_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) - if new_complete.startswith('```json'): - new_complete = get_json_content(new_complete) - last_complete = last_complete+new_complete + new_complete_cleaned = new_complete.strip() + if new_complete_cleaned.startswith("```json"): + new_complete_cleaned = new_complete_cleaned[7:] + if new_complete_cleaned.startswith("```"): + new_complete_cleaned = new_complete_cleaned[3:] + if new_complete_cleaned.endswith("```"): + new_complete_cleaned = new_complete_cleaned[:-3] + + last_complete = last_complete + new_complete_cleaned if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) @@ -482,7 +492,7 @@ def add_page_number_to_toc(part, structure, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = fill_prompt_seq + f"\n\nCurrent Partial Document:\n{part}\n\nGiven Structure\n{json.dumps(structure, indent=2)}\n" - current_json_raw = llm_completion(model=model, prompt=prompt) + current_json_raw = ChatGPT_API(model=model, prompt=prompt) json_result = extract_json(current_json_raw) for item in json_result: @@ -504,7 +514,7 @@ def remove_first_physical_index_section(text): return text ### add verify completeness -def generate_toc_continue(toc_content, part, model=None): +def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"): print('start generate_toc_continue') prompt = """ You are an expert in extracting hierarchical tree structure. @@ -532,7 +542,7 @@ def generate_toc_continue(toc_content, part, model=None): Directly return the additional part of the final JSON structure. Do not output anything else.""" prompt = prompt + '\nGiven text\n:' + part + '\nPrevious tree structure\n:' + json.dumps(toc_content, indent=2) - response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if finish_reason == 'finished': return extract_json(response) else: @@ -566,7 +576,7 @@ def generate_toc_init(part, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = prompt + '\nGiven text\n:' + part - response, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) if finish_reason == 'finished': return extract_json(response) @@ -674,9 +684,9 @@ def process_none_page_numbers(toc_items, page_list, start_index=1, model=None): page_contents = [] for page_index in range(prev_physical_index, next_physical_index+1): # Add bounds checking to prevent IndexError - list_index = page_index - start_index - if list_index >= 0 and list_index < len(page_list): - page_text = f"\n{page_list[list_index][0]}\n\n\n" + page_list_idx = page_index - start_index + if page_list_idx >= 0 and page_list_idx < len(page_list): + page_text = f"\n{page_list[page_list_idx][0]}\n\n\n" page_contents.append(page_text) else: continue @@ -737,8 +747,27 @@ def check_toc(page_list, opt=None): ################### fix incorrect toc ######################################################### -async def single_toc_item_index_fixer(section_title, content, model=None): - toc_extractor_prompt = """ +def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20"): + tob_extractor_prompt = """ + You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document. + + The provided pages contains tags like and to indicate the physical location of the page X. + + Reply in a JSON format: + { + "thinking": , contains the start of this section>, + "physical_index": "" (keep the format) + } + Directly return the final JSON structure. Do not output anything else.""" + + prompt = tob_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content + response = ChatGPT_API(model=model, prompt=prompt) + json_content = extract_json(response) + return convert_physical_index_to_int(json_content['physical_index']) + + +async def single_toc_item_index_fixer_async(section_title, content, model="gpt-4o-2024-11-20"): + tob_extractor_prompt = """ You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document. The provided pages contains tags like and to indicate the physical location of the page X. @@ -750,8 +779,8 @@ async def single_toc_item_index_fixer(section_title, content, model=None): } Directly return the final JSON structure. Do not output anything else.""" - prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content - response = await llm_acompletion(model=model, prompt=prompt) + prompt = tob_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content + response = await ChatGPT_API_async(model=model, prompt=prompt) json_content = extract_json(response) return convert_physical_index_to_int(json_content['physical_index']) @@ -820,7 +849,7 @@ async def process_and_check_item(incorrect_item): continue content_range = ''.join(page_contents) - physical_index_int = await single_toc_item_index_fixer(incorrect_item['title'], content_range, model) + physical_index_int = await single_toc_item_index_fixer_async(incorrect_item['title'], content_range, model) # Check if the result is correct check_item = incorrect_item.copy() @@ -1074,24 +1103,24 @@ def page_index_main(doc, opt=None): raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.") print('Parsing PDF...') - page_list = get_page_tokens(doc, model=opt.model) + page_list = get_page_tokens(doc, model=opt.model if opt is not None else None) logger.info({'total_page_number': len(page_list)}) logger.info({'total_token': sum([page[1] for page in page_list])}) async def page_index_builder(): structure = await tree_parser(page_list, opt, doc=doc, logger=logger) - if opt.if_add_node_id == 'yes': + if opt.if_add_node_id: write_node_id(structure) - if opt.if_add_node_text == 'yes': + if opt.if_add_node_text: add_node_text(structure, page_list) - if opt.if_add_node_summary == 'yes': - if opt.if_add_node_text == 'no': + if opt.if_add_node_summary: + if not opt.if_add_node_text: add_node_text(structure, page_list) await generate_summaries_for_structure(structure, model=opt.model) - if opt.if_add_node_text == 'no': + if not opt.if_add_node_text: remove_structure_text(structure) - if opt.if_add_doc_description == 'yes': + if opt.if_add_doc_description: # Create a clean structure without unnecessary fields for description generation clean_structure = create_clean_structure_for_description(structure) doc_description = generate_doc_description(clean_structure, model=opt.model) diff --git a/pageindex/utils.py b/pageindex/utils.py index 57b69c5b5..855830964 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -1,680 +1,4 @@ -import litellm -import logging -import os -from datetime import datetime -import time -import json -import PyPDF2 -import copy -import asyncio -import pymupdf -from io import BytesIO -from dotenv import load_dotenv -load_dotenv() -import logging -import yaml -from pathlib import Path -from types import SimpleNamespace as config - -# Backward compatibility: support CHATGPT_API_KEY as alias for OPENAI_API_KEY -if not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"): - os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY") - -litellm.drop_params = True - -def count_tokens(text, model=None): - if not text: - return 0 - return litellm.token_counter(model=model, text=text) - - -def llm_completion(model, prompt, chat_history=None, return_finish_reason=False): - max_retries = 10 - messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}] - for i in range(max_retries): - try: - response = litellm.completion( - model=model, - messages=messages, - temperature=0, - ) - content = response.choices[0].message.content - if return_finish_reason: - finish_reason = "max_output_reached" if response.choices[0].finish_reason == "length" else "finished" - return content, finish_reason - return content - except Exception as e: - print('************* Retrying *************') - logging.error(f"Error: {e}") - if i < max_retries - 1: - time.sleep(1) - else: - logging.error('Max retries reached for prompt: ' + prompt) - if return_finish_reason: - return "", "error" - return "" - - - -async def llm_acompletion(model, prompt): - max_retries = 10 - messages = [{"role": "user", "content": prompt}] - for i in range(max_retries): - try: - response = await litellm.acompletion( - model=model, - messages=messages, - temperature=0, - ) - return response.choices[0].message.content - except Exception as e: - print('************* Retrying *************') - logging.error(f"Error: {e}") - if i < max_retries - 1: - await asyncio.sleep(1) - else: - logging.error('Max retries reached for prompt: ' + prompt) - return "" - - -def get_json_content(response): - start_idx = response.find("```json") - if start_idx != -1: - start_idx += 7 - response = response[start_idx:] - - end_idx = response.rfind("```") - if end_idx != -1: - response = response[:end_idx] - - json_content = response.strip() - return json_content - - -def extract_json(content): - try: - # First, try to extract JSON enclosed within ```json and ``` - start_idx = content.find("```json") - if start_idx != -1: - start_idx += 7 # Adjust index to start after the delimiter - end_idx = content.rfind("```") - json_content = content[start_idx:end_idx].strip() - else: - # If no delimiters, assume entire content could be JSON - json_content = content.strip() - - # Clean up common issues that might cause parsing errors - json_content = json_content.replace('None', 'null') # Replace Python None with JSON null - json_content = json_content.replace('\n', ' ').replace('\r', ' ') # Remove newlines - json_content = ' '.join(json_content.split()) # Normalize whitespace - - # Attempt to parse and return the JSON object - return json.loads(json_content) - except json.JSONDecodeError as e: - logging.error(f"Failed to extract JSON: {e}") - # Try to clean up the content further if initial parsing fails - try: - # Remove any trailing commas before closing brackets/braces - json_content = json_content.replace(',]', ']').replace(',}', '}') - return json.loads(json_content) - except: - logging.error("Failed to parse JSON even after cleanup") - return {} - except Exception as e: - logging.error(f"Unexpected error while extracting JSON: {e}") - return {} - -def write_node_id(data, node_id=0): - if isinstance(data, dict): - data['node_id'] = str(node_id).zfill(4) - node_id += 1 - for key in list(data.keys()): - if 'nodes' in key: - node_id = write_node_id(data[key], node_id) - elif isinstance(data, list): - for index in range(len(data)): - node_id = write_node_id(data[index], node_id) - return node_id - -def get_nodes(structure): - if isinstance(structure, dict): - structure_node = copy.deepcopy(structure) - structure_node.pop('nodes', None) - nodes = [structure_node] - for key in list(structure.keys()): - if 'nodes' in key: - nodes.extend(get_nodes(structure[key])) - return nodes - elif isinstance(structure, list): - nodes = [] - for item in structure: - nodes.extend(get_nodes(item)) - return nodes - -def structure_to_list(structure): - if isinstance(structure, dict): - nodes = [] - nodes.append(structure) - if 'nodes' in structure: - nodes.extend(structure_to_list(structure['nodes'])) - return nodes - elif isinstance(structure, list): - nodes = [] - for item in structure: - nodes.extend(structure_to_list(item)) - return nodes - - -def get_leaf_nodes(structure): - if isinstance(structure, dict): - if not structure['nodes']: - structure_node = copy.deepcopy(structure) - structure_node.pop('nodes', None) - return [structure_node] - else: - leaf_nodes = [] - for key in list(structure.keys()): - if 'nodes' in key: - leaf_nodes.extend(get_leaf_nodes(structure[key])) - return leaf_nodes - elif isinstance(structure, list): - leaf_nodes = [] - for item in structure: - leaf_nodes.extend(get_leaf_nodes(item)) - return leaf_nodes - -def is_leaf_node(data, node_id): - # Helper function to find the node by its node_id - def find_node(data, node_id): - if isinstance(data, dict): - if data.get('node_id') == node_id: - return data - for key in data.keys(): - if 'nodes' in key: - result = find_node(data[key], node_id) - if result: - return result - elif isinstance(data, list): - for item in data: - result = find_node(item, node_id) - if result: - return result - return None - - # Find the node with the given node_id - node = find_node(data, node_id) - - # Check if the node is a leaf node - if node and not node.get('nodes'): - return True - return False - -def get_last_node(structure): - return structure[-1] - - -def extract_text_from_pdf(pdf_path): - pdf_reader = PyPDF2.PdfReader(pdf_path) - ###return text not list - text="" - for page_num in range(len(pdf_reader.pages)): - page = pdf_reader.pages[page_num] - text+=page.extract_text() - return text - -def get_pdf_title(pdf_path): - pdf_reader = PyPDF2.PdfReader(pdf_path) - meta = pdf_reader.metadata - title = meta.title if meta and meta.title else 'Untitled' - return title - -def get_text_of_pages(pdf_path, start_page, end_page, tag=True): - pdf_reader = PyPDF2.PdfReader(pdf_path) - text = "" - for page_num in range(start_page-1, end_page): - page = pdf_reader.pages[page_num] - page_text = page.extract_text() - if tag: - text += f"\n{page_text}\n\n" - else: - text += page_text - return text - -def get_first_start_page_from_text(text): - start_page = -1 - start_page_match = re.search(r'', text) - if start_page_match: - start_page = int(start_page_match.group(1)) - return start_page - -def get_last_start_page_from_text(text): - start_page = -1 - # Find all matches of start_index tags - start_page_matches = re.finditer(r'', text) - # Convert iterator to list and get the last match if any exist - matches_list = list(start_page_matches) - if matches_list: - start_page = int(matches_list[-1].group(1)) - return start_page - - -def sanitize_filename(filename, replacement='-'): - # In Linux, only '/' and '\0' (null) are invalid in filenames. - # Null can't be represented in strings, so we only handle '/'. - return filename.replace('/', replacement) - -def get_pdf_name(pdf_path): - # Extract PDF name - if isinstance(pdf_path, str): - pdf_name = os.path.basename(pdf_path) - elif isinstance(pdf_path, BytesIO): - pdf_reader = PyPDF2.PdfReader(pdf_path) - meta = pdf_reader.metadata - pdf_name = meta.title if meta and meta.title else 'Untitled' - pdf_name = sanitize_filename(pdf_name) - return pdf_name - - -class JsonLogger: - def __init__(self, file_path): - # Extract PDF name for logger name - pdf_name = get_pdf_name(file_path) - - current_time = datetime.now().strftime("%Y%m%d_%H%M%S") - self.filename = f"{pdf_name}_{current_time}.json" - os.makedirs("./logs", exist_ok=True) - # Initialize empty list to store all messages - self.log_data = [] - - def log(self, level, message, **kwargs): - if isinstance(message, dict): - self.log_data.append(message) - else: - self.log_data.append({'message': message}) - # Add new message to the log data - - # Write entire log data to file - with open(self._filepath(), "w") as f: - json.dump(self.log_data, f, indent=2) - - def info(self, message, **kwargs): - self.log("INFO", message, **kwargs) - - def error(self, message, **kwargs): - self.log("ERROR", message, **kwargs) - - def debug(self, message, **kwargs): - self.log("DEBUG", message, **kwargs) - - def exception(self, message, **kwargs): - kwargs["exception"] = True - self.log("ERROR", message, **kwargs) - - def _filepath(self): - return os.path.join("logs", self.filename) - - - - -def list_to_tree(data): - def get_parent_structure(structure): - """Helper function to get the parent structure code""" - if not structure: - return None - parts = str(structure).split('.') - return '.'.join(parts[:-1]) if len(parts) > 1 else None - - # First pass: Create nodes and track parent-child relationships - nodes = {} - root_nodes = [] - - for item in data: - structure = item.get('structure') - node = { - 'title': item.get('title'), - 'start_index': item.get('start_index'), - 'end_index': item.get('end_index'), - 'nodes': [] - } - - nodes[structure] = node - - # Find parent - parent_structure = get_parent_structure(structure) - - if parent_structure: - # Add as child to parent if parent exists - if parent_structure in nodes: - nodes[parent_structure]['nodes'].append(node) - else: - root_nodes.append(node) - else: - # No parent, this is a root node - root_nodes.append(node) - - # Helper function to clean empty children arrays - def clean_node(node): - if not node['nodes']: - del node['nodes'] - else: - for child in node['nodes']: - clean_node(child) - return node - - # Clean and return the tree - return [clean_node(node) for node in root_nodes] - -def add_preface_if_needed(data): - if not isinstance(data, list) or not data: - return data - - if data[0]['physical_index'] is not None and data[0]['physical_index'] > 1: - preface_node = { - "structure": "0", - "title": "Preface", - "physical_index": 1, - } - data.insert(0, preface_node) - return data - - - -def get_page_tokens(pdf_path, model=None, pdf_parser="PyPDF2"): - if pdf_parser == "PyPDF2": - pdf_reader = PyPDF2.PdfReader(pdf_path) - page_list = [] - for page_num in range(len(pdf_reader.pages)): - page = pdf_reader.pages[page_num] - page_text = page.extract_text() - token_length = litellm.token_counter(model=model, text=page_text) - page_list.append((page_text, token_length)) - return page_list - elif pdf_parser == "PyMuPDF": - if isinstance(pdf_path, BytesIO): - pdf_stream = pdf_path - doc = pymupdf.open(stream=pdf_stream, filetype="pdf") - elif isinstance(pdf_path, str) and os.path.isfile(pdf_path) and pdf_path.lower().endswith(".pdf"): - doc = pymupdf.open(pdf_path) - page_list = [] - for page in doc: - page_text = page.get_text() - token_length = litellm.token_counter(model=model, text=page_text) - page_list.append((page_text, token_length)) - return page_list - else: - raise ValueError(f"Unsupported PDF parser: {pdf_parser}") - - - -def get_text_of_pdf_pages(pdf_pages, start_page, end_page): - text = "" - for page_num in range(start_page-1, end_page): - text += pdf_pages[page_num][0] - return text - -def get_text_of_pdf_pages_with_labels(pdf_pages, start_page, end_page): - text = "" - for page_num in range(start_page-1, end_page): - text += f"\n{pdf_pages[page_num][0]}\n\n" - return text - -def get_number_of_pages(pdf_path): - pdf_reader = PyPDF2.PdfReader(pdf_path) - num = len(pdf_reader.pages) - return num - - - -def post_processing(structure, end_physical_index): - # First convert page_number to start_index in flat list - for i, item in enumerate(structure): - item['start_index'] = item.get('physical_index') - if i < len(structure) - 1: - if structure[i + 1].get('appear_start') == 'yes': - item['end_index'] = structure[i + 1]['physical_index']-1 - else: - item['end_index'] = structure[i + 1]['physical_index'] - else: - item['end_index'] = end_physical_index - tree = list_to_tree(structure) - if len(tree)!=0: - return tree - else: - ### remove appear_start - for node in structure: - node.pop('appear_start', None) - node.pop('physical_index', None) - return structure - -def clean_structure_post(data): - if isinstance(data, dict): - data.pop('page_number', None) - data.pop('start_index', None) - data.pop('end_index', None) - if 'nodes' in data: - clean_structure_post(data['nodes']) - elif isinstance(data, list): - for section in data: - clean_structure_post(section) - return data - -def remove_fields(data, fields=['text']): - if isinstance(data, dict): - return {k: remove_fields(v, fields) - for k, v in data.items() if k not in fields} - elif isinstance(data, list): - return [remove_fields(item, fields) for item in data] - return data - -def print_toc(tree, indent=0): - for node in tree: - print(' ' * indent + node['title']) - if node.get('nodes'): - print_toc(node['nodes'], indent + 1) - -def print_json(data, max_len=40, indent=2): - def simplify_data(obj): - if isinstance(obj, dict): - return {k: simplify_data(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [simplify_data(item) for item in obj] - elif isinstance(obj, str) and len(obj) > max_len: - return obj[:max_len] + '...' - else: - return obj - - simplified = simplify_data(data) - print(json.dumps(simplified, indent=indent, ensure_ascii=False)) - - -def remove_structure_text(data): - if isinstance(data, dict): - data.pop('text', None) - if 'nodes' in data: - remove_structure_text(data['nodes']) - elif isinstance(data, list): - for item in data: - remove_structure_text(item) - return data - - -def check_token_limit(structure, limit=110000): - list = structure_to_list(structure) - for node in list: - num_tokens = count_tokens(node['text'], model=None) - if num_tokens > limit: - print(f"Node ID: {node['node_id']} has {num_tokens} tokens") - print("Start Index:", node['start_index']) - print("End Index:", node['end_index']) - print("Title:", node['title']) - print("\n") - - -def convert_physical_index_to_int(data): - if isinstance(data, list): - for i in range(len(data)): - # Check if item is a dictionary and has 'physical_index' key - if isinstance(data[i], dict) and 'physical_index' in data[i]: - if isinstance(data[i]['physical_index'], str): - if data[i]['physical_index'].startswith('').strip()) - elif data[i]['physical_index'].startswith('physical_index_'): - data[i]['physical_index'] = int(data[i]['physical_index'].split('_')[-1].strip()) - elif isinstance(data, str): - if data.startswith('').strip()) - elif data.startswith('physical_index_'): - data = int(data.split('_')[-1].strip()) - # Check data is int - if isinstance(data, int): - return data - else: - return None - return data - - -def convert_page_to_int(data): - for item in data: - if 'page' in item and isinstance(item['page'], str): - try: - item['page'] = int(item['page']) - except ValueError: - # Keep original value if conversion fails - pass - return data - - -def add_node_text(node, pdf_pages): - if isinstance(node, dict): - start_page = node.get('start_index') - end_page = node.get('end_index') - node['text'] = get_text_of_pdf_pages(pdf_pages, start_page, end_page) - if 'nodes' in node: - add_node_text(node['nodes'], pdf_pages) - elif isinstance(node, list): - for index in range(len(node)): - add_node_text(node[index], pdf_pages) - return - - -def add_node_text_with_labels(node, pdf_pages): - if isinstance(node, dict): - start_page = node.get('start_index') - end_page = node.get('end_index') - node['text'] = get_text_of_pdf_pages_with_labels(pdf_pages, start_page, end_page) - if 'nodes' in node: - add_node_text_with_labels(node['nodes'], pdf_pages) - elif isinstance(node, list): - for index in range(len(node)): - add_node_text_with_labels(node[index], pdf_pages) - return - - -async def generate_node_summary(node, model=None): - prompt = f"""You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document. - - Partial Document Text: {node['text']} - - Directly return the description, do not include any other text. - """ - response = await llm_acompletion(model, prompt) - return response - - -async def generate_summaries_for_structure(structure, model=None): - nodes = structure_to_list(structure) - tasks = [generate_node_summary(node, model=model) for node in nodes] - summaries = await asyncio.gather(*tasks) - - for node, summary in zip(nodes, summaries): - node['summary'] = summary - return structure - - -def create_clean_structure_for_description(structure): - """ - Create a clean structure for document description generation, - excluding unnecessary fields like 'text'. - """ - if isinstance(structure, dict): - clean_node = {} - # Only include essential fields for description - for key in ['title', 'node_id', 'summary', 'prefix_summary']: - if key in structure: - clean_node[key] = structure[key] - - # Recursively process child nodes - if 'nodes' in structure and structure['nodes']: - clean_node['nodes'] = create_clean_structure_for_description(structure['nodes']) - - return clean_node - elif isinstance(structure, list): - return [create_clean_structure_for_description(item) for item in structure] - else: - return structure - - -def generate_doc_description(structure, model=None): - prompt = f"""Your are an expert in generating descriptions for a document. - You are given a structure of a document. Your task is to generate a one-sentence description for the document, which makes it easy to distinguish the document from other documents. - - Document Structure: {structure} - - Directly return the description, do not include any other text. - """ - response = llm_completion(model, prompt) - return response - - -def reorder_dict(data, key_order): - if not key_order: - return data - return {key: data[key] for key in key_order if key in data} - - -def format_structure(structure, order=None): - if not order: - return structure - if isinstance(structure, dict): - if 'nodes' in structure: - structure['nodes'] = format_structure(structure['nodes'], order) - if not structure.get('nodes'): - structure.pop('nodes', None) - structure = reorder_dict(structure, order) - elif isinstance(structure, list): - structure = [format_structure(item, order) for item in structure] - return structure - - -class ConfigLoader: - def __init__(self, default_path: str = None): - if default_path is None: - default_path = Path(__file__).parent / "config.yaml" - self._default_dict = self._load_yaml(default_path) - - @staticmethod - def _load_yaml(path): - with open(path, "r", encoding="utf-8") as f: - return yaml.safe_load(f) or {} - - def _validate_keys(self, user_dict): - unknown_keys = set(user_dict) - set(self._default_dict) - if unknown_keys: - raise ValueError(f"Unknown config keys: {unknown_keys}") - - def load(self, user_opt=None) -> config: - """ - Load the configuration, merging user options with default values. - """ - if user_opt is None: - user_dict = {} - elif isinstance(user_opt, config): - user_dict = vars(user_opt) - elif isinstance(user_opt, dict): - user_dict = user_opt - else: - raise TypeError("user_opt must be dict, config(SimpleNamespace) or None") - - self._validate_keys(user_dict) - merged = {**self._default_dict, **user_dict} - return config(**merged) +from .core.llm import * +from .core.pdf import * +from .core.tree import * +from .core.logging import * diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..e9752821a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "pageindex" +version = "0.1.0" +description = "Vectorless, reasoning-based RAG indexer" +readme = "README.md" +requires-python = ">=3.9" +license = {text = "MIT"} +dependencies = [ + "openai==1.101.0", + "pymupdf==1.26.4", + "PyPDF2==3.0.1", + "python-dotenv==1.1.0", + "tiktoken==0.11.0", + "pyyaml==6.0.2", + "pydantic>=2.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.4.0", + "pytest-asyncio>=0.21.0", +] + +[project.scripts] +pageindex = "pageindex.cli:main" + +[tool.setuptools.packages.find] +where = ["."] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..94d322bfe --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,6 @@ +import pytest +import os +import sys + +# Add src to python path for testing +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..ffe4f23c8 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,39 @@ +import pytest +from types import SimpleNamespace +from pageindex.config import ConfigLoader, PageIndexConfig + +def test_config_loader_default(tmp_path): + # Mock config file + config_file = tmp_path / "config.yaml" + config_file.write_text('model: "gpt-4-test"\nmax_page_num_each_node: 10', encoding="utf-8") + + loader = ConfigLoader(default_path=config_file) + cfg = loader.load() + + assert isinstance(cfg, PageIndexConfig) + assert cfg.model == "gpt-4-test" + assert cfg.max_page_num_each_node == 10 + # Check default logic + assert cfg.toc_check_page_num == 3 + +def test_config_loader_override(): + loader = ConfigLoader(default_path=None) + override = {"model": "gpt-override", "if_add_node_id": False} + + cfg = loader.load(user_opt=override) + assert cfg.model == "gpt-override" + assert cfg.if_add_node_id is False + +def test_config_validation_error(): + loader = ConfigLoader(default_path=None) + # Pass invalid type for integer field + override = {"max_page_num_each_node": "not-an-int"} + + with pytest.raises(ValueError, match="Configuration validation failed"): + loader.load(user_opt=override) + +def test_partial_override_object(): + args = SimpleNamespace(model="cmd-model", other_arg=None) + loader = ConfigLoader(default_path=None) + cfg = loader.load(user_opt=args) + assert cfg.model == "cmd-model" diff --git a/tests/test_llm.py b/tests/test_llm.py new file mode 100644 index 000000000..c8feb3a35 --- /dev/null +++ b/tests/test_llm.py @@ -0,0 +1,21 @@ +import pytest +from pageindex.core.llm import extract_json, count_tokens + +def test_extract_json_basic(): + text = '{"key": "value"}' + assert extract_json(text) == {"key": "value"} + +def test_extract_json_with_markdown(): + text = 'Here is the json:\n```json\n{"key": "value"}\n```' + assert extract_json(text) == {"key": "value"} + +def test_extract_json_with_trailing_commas(): + # This might fail depending on implementation robustness, but let's see + text = '{"key": "value",}' + # Our implementation tries to fix this + assert extract_json(text) == {"key": "value"} + +def test_count_tokens(): + text = "Hello world" + # Basic check, exact number depends on encoding + assert count_tokens(text) > 0 diff --git a/tests/test_tree.py b/tests/test_tree.py new file mode 100644 index 000000000..defb67ceb --- /dev/null +++ b/tests/test_tree.py @@ -0,0 +1,36 @@ +import pytest +from pageindex.core.tree import list_to_tree, structure_to_list, get_nodes, write_node_id + +@pytest.fixture +def sample_structure(): + return [ + {"structure": "1", "title": "Chapter 1", "start_index": 1, "end_index": 5}, + {"structure": "1.1", "title": "Section 1.1", "start_index": 1, "end_index": 3}, + {"structure": "1.2", "title": "Section 1.2", "start_index": 4, "end_index": 5}, + {"structure": "2", "title": "Chapter 2", "start_index": 6, "end_index": 10} + ] + +def test_list_to_tree(sample_structure): + tree = list_to_tree(sample_structure) + assert len(tree) == 2 + assert tree[0]["title"] == "Chapter 1" + assert len(tree[0]["nodes"]) == 2 + assert tree[0]["nodes"][0]["title"] == "Section 1.1" + assert tree[1]["title"] == "Chapter 2" + assert "nodes" not in tree[1] or len(tree[1]["nodes"]) == 0 + +def test_structure_to_list(sample_structure): + tree = list_to_tree(sample_structure) + flat_list = structure_to_list(tree) + # Note: structure_to_list might not preserve original order exactly or might include container nodes + # But for our simple case it should be close. + assert len(flat_list) == 4 + titles = [item["title"] for item in flat_list] + assert "Chapter 1" in titles + assert "Section 1.1" in titles + +def test_write_node_id(sample_structure): + tree = list_to_tree(sample_structure) + write_node_id(tree) + assert tree[0]["node_id"] == "0000" + assert tree[0]["nodes"][0]["node_id"] == "0001"