diff --git a/examples/documents/test_issue_163.py b/examples/documents/test_issue_163.py new file mode 100644 index 000000000..517892a15 --- /dev/null +++ b/examples/documents/test_issue_163.py @@ -0,0 +1,135 @@ +import pytest +import sys +import os +from unittest.mock import patch, MagicMock + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from pageindex.page_index import ( + check_if_toc_extraction_is_complete, + check_if_toc_transformation_is_complete, + toc_detector_single_page, + detect_page_index, + extract_toc_content, + toc_transformer, +) + + +class TestRobustKeyAccess: + @patch("pageindex.page_index.llm_completion", return_value="") + def test_toc_detector_empty_response(self, mock_llm): + result = toc_detector_single_page("some content", model="test") + assert result == "no" + + @patch("pageindex.page_index.llm_completion", return_value='{"toc_detected": "yes"}') + def test_toc_detector_valid_response(self, mock_llm): + result = toc_detector_single_page("some content", model="test") + assert result == "yes" + + @patch("pageindex.page_index.llm_completion", return_value="not json at all") + def test_toc_detector_malformed_response(self, mock_llm): + result = toc_detector_single_page("some content", model="test") + assert result == "no" + + @patch("pageindex.page_index.llm_completion", return_value="") + def test_extraction_complete_empty_response(self, mock_llm): + result = check_if_toc_extraction_is_complete("doc", "toc", model="test") + assert result == "no" + + @patch("pageindex.page_index.llm_completion", return_value='{"completed": "yes"}') + def test_extraction_complete_valid_response(self, mock_llm): + result = check_if_toc_extraction_is_complete("doc", "toc", model="test") + assert result == "yes" + + @patch("pageindex.page_index.llm_completion", return_value="") + def test_transformation_complete_empty_response(self, mock_llm): + result = check_if_toc_transformation_is_complete("raw", "cleaned", model="test") + assert 
result == "no" + + @patch("pageindex.page_index.llm_completion", return_value='{"thinking": "looks fine", "completed": "yes"}') + def test_transformation_complete_valid_response(self, mock_llm): + result = check_if_toc_transformation_is_complete("raw", "cleaned", model="test") + assert result == "yes" + + @patch("pageindex.page_index.llm_completion", return_value="") + def test_detect_page_index_empty_response(self, mock_llm): + result = detect_page_index("toc text", model="test") + assert result == "no" + + +class TestExtractTocContentRetryLoop: + @patch("pageindex.page_index.check_if_toc_transformation_is_complete") + @patch("pageindex.page_index.llm_completion") + def test_completes_on_first_try(self, mock_llm, mock_check): + mock_llm.return_value = ("full toc content", "finished") + mock_check.return_value = "yes" + result = extract_toc_content("raw content", model="test") + assert result == "full toc content" + assert mock_llm.call_count == 1 + + @patch("pageindex.page_index.check_if_toc_transformation_is_complete") + @patch("pageindex.page_index.llm_completion") + def test_continues_on_incomplete(self, mock_llm, mock_check): + mock_llm.side_effect = [ + ("partial toc", "max_output_reached"), + (" continued toc", "finished"), + ] + mock_check.side_effect = ["no", "yes"] + result = extract_toc_content("raw content", model="test") + assert result == "partial toc continued toc" + assert mock_llm.call_count == 2 + + @patch("pageindex.page_index.check_if_toc_transformation_is_complete") + @patch("pageindex.page_index.llm_completion") + def test_max_retries_raises_exception(self, mock_llm, mock_check): + mock_llm.return_value = ("chunk", "max_output_reached") + mock_check.return_value = "no" + with pytest.raises(Exception, match="Failed to complete table of contents extraction"): + extract_toc_content("raw content", model="test") + assert mock_llm.call_count == 6 + + @patch("pageindex.page_index.check_if_toc_transformation_is_complete") + 
@patch("pageindex.page_index.llm_completion") + def test_chat_history_grows_incrementally(self, mock_llm, mock_check): + call_count = [0] + + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return ("initial", "max_output_reached") + if call_count[0] == 2: + history = kwargs.get("chat_history", []) + assert len(history) == 2 + return (" part2", "max_output_reached") + if call_count[0] == 3: + history = kwargs.get("chat_history", []) + assert len(history) == 4 + return (" part3", "finished") + return ("", "finished") + + mock_llm.side_effect = side_effect + mock_check.side_effect = ["no", "no", "yes"] + result = extract_toc_content("raw content", model="test") + assert result == "initial part2 part3" + + +class TestTocTransformerRetryLoop: + @patch("pageindex.page_index.check_if_toc_transformation_is_complete") + @patch("pageindex.page_index.llm_completion") + def test_completes_on_first_try(self, mock_llm, mock_check): + mock_llm.return_value = ( + '{"table_of_contents": [{"structure": "1", "title": "Intro", "page": 1}]}', + "finished", + ) + mock_check.return_value = "yes" + result = toc_transformer("raw toc", model="test") + assert len(result) == 1 + assert result[0]["title"] == "Intro" + + @patch("pageindex.page_index.check_if_toc_transformation_is_complete") + @patch("pageindex.page_index.llm_completion") + def test_handles_missing_table_of_contents_key(self, mock_llm, mock_check): + mock_llm.return_value = ('{"other_key": "value"}', "finished") + mock_check.return_value = "yes" + result = toc_transformer("raw toc", model="test") + assert result == [] diff --git a/examples/documents/test_multi_doc.py b/examples/documents/test_multi_doc.py new file mode 100644 index 000000000..5a30c8f6b --- /dev/null +++ b/examples/documents/test_multi_doc.py @@ -0,0 +1,78 @@ +import json +import pytest +from unittest.mock import patch, MagicMock +from pageindex.retrieve import get_document, get_document_structure, get_page_content + 
+@pytest.fixture +def mock_documents(): + return { + "doc1": { + "id": "doc1", + "path": "test1.pdf", + "type": "pdf", + "doc_name": "Document 1", + "doc_description": "First test doc", + "structure": [{"title": "Section 1", "page": 1, "text": "Content 1"}] + }, + "doc2": { + "id": "doc2", + "path": "test2.md", + "type": "md", + "doc_name": "Document 2", + "doc_description": "Second test doc", + "structure": [{"title": "Header 2", "line_num": 1, "text": "Content 2"}] + } + } + +def test_get_document_multi(mock_documents): + with patch("pageindex.retrieve._count_pages", return_value=5): + result_json = get_document(mock_documents, ["doc1", "doc2"]) + result = json.loads(result_json) + + assert "doc1" in result + assert "doc2" in result + assert result["doc1"]["doc_name"] == "Document 1" + assert result["doc1"]["page_count"] == 5 + assert result["doc2"]["line_count"] == 5 + +def test_get_document_structure_multi(mock_documents): + result_json = get_document_structure(mock_documents, ["doc1", "doc2"]) + result = json.loads(result_json) + + assert "doc1" in result + assert "doc2" in result + # Verify text field is removed + assert "text" not in result["doc1"][0] + assert result["doc1"][0]["title"] == "Section 1" + +def test_get_page_content_multi(mock_documents): + with patch("pageindex.retrieve._get_pdf_page_content", return_value=[{"page": 1, "content": "PDF Content"}]), \ + patch("pageindex.retrieve._get_md_page_content", return_value=[{"page": 1, "content": "MD Content"}]): + + result_json = get_page_content(mock_documents, ["doc1", "doc2"], "1") + result = json.loads(result_json) + + assert "doc1" in result + assert "doc2" in result + assert result["doc1"][0]["content"] == "PDF Content" + assert result["doc2"][0]["content"] == "MD Content" + +def test_get_document_multi_with_invalid_id(mock_documents): + with patch("pageindex.retrieve._count_pages", return_value=5): + result_json = get_document(mock_documents, ["doc1", "invalid-id"]) + result = 
json.loads(result_json) + + assert "doc1" in result + assert "invalid-id" in result + assert "error" in result["invalid-id"] + assert "not found" in result["invalid-id"]["error"] + +def test_backward_compatibility(mock_documents): + # Single doc_id as string should return flat result, not nested + with patch("pageindex.retrieve._count_pages", return_value=5): + result_json = get_document(mock_documents, "doc1") + result = json.loads(result_json) + + assert "doc_id" in result + assert result["doc_id"] == "doc1" + assert "doc1" not in result # Should not be nested diff --git a/pageindex/client.py b/pageindex/client.py index 894dab181..d79064809 100644 --- a/pageindex/client.py +++ b/pageindex/client.py @@ -217,18 +217,26 @@ def _ensure_doc_loaded(self, doc_id: str): if full.get('pages'): doc['pages'] = full['pages'] - def get_document(self, doc_id: str) -> str: - """Return document metadata JSON.""" + def get_document(self, doc_id) -> str: + """Return document metadata JSON. doc_id can be a string or a list of strings.""" return get_document(self.documents, doc_id) - def get_document_structure(self, doc_id: str) -> str: - """Return document tree structure JSON (without text fields).""" + def get_document_structure(self, doc_id) -> str: + """Return document tree structure JSON (without text fields). doc_id can be a string or a list of strings.""" if self.workspace: - self._ensure_doc_loaded(doc_id) + if isinstance(doc_id, list): + for d_id in doc_id: + self._ensure_doc_loaded(d_id) + else: + self._ensure_doc_loaded(doc_id) return get_document_structure(self.documents, doc_id) - def get_page_content(self, doc_id: str, pages: str) -> str: - """Return page content for the given pages string (e.g. '5-7', '3,8', '12').""" + def get_page_content(self, doc_id, pages: str) -> str: + """Return page content for the given pages string (e.g. '5-7', '3,8', '12'). 
doc_id can be a string or a list of strings.""" if self.workspace: - self._ensure_doc_loaded(doc_id) + if isinstance(doc_id, list): + for d_id in doc_id: + self._ensure_doc_loaded(d_id) + else: + self._ensure_doc_loaded(doc_id) return get_page_content(self.documents, doc_id, pages) diff --git a/pageindex/page_index.py b/pageindex/page_index.py index 9004309fb..4fe078690 100644 --- a/pageindex/page_index.py +++ b/pageindex/page_index.py @@ -117,9 +117,8 @@ def toc_detector_single_page(content, model=None): Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents.""" response = llm_completion(model=model, prompt=prompt) - # print('response', response) json_content = extract_json(response) - return json_content['toc_detected'] + return json_content.get('toc_detected', 'no') def check_if_toc_extraction_is_complete(content, toc, model=None): @@ -137,7 +136,7 @@ def check_if_toc_extraction_is_complete(content, toc, model=None): prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc response = llm_completion(model=model, prompt=prompt) json_content = extract_json(response) - return json_content['completed'] + return json_content.get('completed', 'no') def check_if_toc_transformation_is_complete(content, toc, model=None): @@ -155,7 +154,7 @@ def check_if_toc_transformation_is_complete(content, toc, model=None): prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc response = llm_completion(model=model, prompt=prompt) json_content = extract_json(response) - return json_content['completed'] + return json_content.get('completed', 'no') def extract_toc_content(content, model=None): prompt = f""" @@ -175,27 +174,19 @@ def extract_toc_content(content, model=None): {"role": "user", "content": prompt}, {"role": "assistant", "content": response}, ] - prompt = f"""please continue the generation of table of contents , directly output the remaining part of the 
structure""" - new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) - response = response + new_response - if_complete = check_if_toc_transformation_is_complete(content, response, model) + continue_prompt = "please continue the generation of table of contents, directly output the remaining part of the structure" - attempt = 0 max_attempts = 5 - - while not (if_complete == "yes" and finish_reason == "finished"): - attempt += 1 - if attempt > max_attempts: - raise Exception('Failed to complete table of contents after maximum retries') - - chat_history = [ - {"role": "user", "content": prompt}, - {"role": "assistant", "content": response}, - ] - prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure""" - new_response, finish_reason = llm_completion(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) + for attempt in range(max_attempts): + new_response, finish_reason = llm_completion(model=model, prompt=continue_prompt, chat_history=chat_history, return_finish_reason=True) response = response + new_response + chat_history.append({"role": "user", "content": continue_prompt}) + chat_history.append({"role": "assistant", "content": new_response}) if_complete = check_if_toc_transformation_is_complete(content, response, model) + if if_complete == "yes" and finish_reason == "finished": + break + else: + raise Exception('Failed to complete table of contents extraction after maximum retries') return response @@ -217,7 +208,7 @@ def detect_page_index(toc_content, model=None): response = llm_completion(model=model, prompt=prompt) json_content = extract_json(response) - return json_content['page_index_given_in_toc'] + return json_content.get('page_index_given_in_toc', 'no') def toc_extractor(page_list, toc_page_list, model): def transform_dots_to_colon(text): @@ -296,43 +287,41 @@ def toc_transformer(toc_content, 
model=None): if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) if if_complete == "yes" and finish_reason == "finished": last_complete = extract_json(last_complete) - cleaned_response=convert_page_to_int(last_complete['table_of_contents']) + cleaned_response = convert_page_to_int(last_complete.get('table_of_contents', [])) return cleaned_response last_complete = get_json_content(last_complete) - attempt = 0 - max_attempts = 5 - while not (if_complete == "yes" and finish_reason == "finished"): - attempt += 1 - if attempt > max_attempts: - raise Exception('Failed to complete toc transformation after maximum retries') - position = last_complete.rfind('}') - if position != -1: - last_complete = last_complete[:position+2] - prompt = f""" - Your task is to continue the table of contents json structure, directly output the remaining part of the json structure. - The response should be in the following JSON format: - - The raw table of contents json structure is: - {toc_content} + chat_history = [ + {"role": "user", "content": prompt}, + {"role": "assistant", "content": last_complete}, + ] + continue_prompt = "Please continue the table of contents JSON structure from where you left off. Directly output only the remaining part." 
- The incomplete transformed table of contents json structure is: - {last_complete} + position = last_complete.rfind('}') + if position != -1: + last_complete = last_complete[:position+2] - Please continue the json structure, directly output the remaining part of the json structure.""" + max_attempts = 5 + for attempt in range(max_attempts): - new_complete, finish_reason = llm_completion(model=model, prompt=prompt, return_finish_reason=True) + new_complete, finish_reason = llm_completion(model=model, prompt=continue_prompt, chat_history=chat_history, return_finish_reason=True) if new_complete.startswith('```json'): - new_complete = get_json_content(new_complete) - last_complete = last_complete+new_complete + new_complete = get_json_content(new_complete) + last_complete = last_complete + new_complete + + chat_history.append({"role": "user", "content": continue_prompt}) + chat_history.append({"role": "assistant", "content": new_complete}) if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) - + if if_complete == "yes" and finish_reason == "finished": + break + else: + raise Exception('Failed to complete TOC transformation after maximum retries') last_complete = extract_json(last_complete) - cleaned_response=convert_page_to_int(last_complete['table_of_contents']) + cleaned_response = convert_page_to_int(last_complete.get('table_of_contents', [])) return cleaned_response @@ -753,7 +742,10 @@ async def single_toc_item_index_fixer(section_title, content, model=None): prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content response = await llm_acompletion(model=model, prompt=prompt) json_content = extract_json(response) - return convert_physical_index_to_int(json_content['physical_index']) + physical_index = json_content.get('physical_index') + if physical_index is None: + return None + return convert_physical_index_to_int(physical_index) diff --git a/pageindex/retrieve.py 
b/pageindex/retrieve.py index 55c38509c..6057155e8 100644 --- a/pageindex/retrieve.py +++ b/pageindex/retrieve.py @@ -78,8 +78,29 @@ def _traverse(nodes): # ── Tool functions ──────────────────────────────────────────────────────────── -def get_document(documents: dict, doc_id: str) -> str: - """Return JSON with document metadata: doc_id, doc_name, doc_description, type, status, page_count (PDF) or line_count (Markdown).""" +def get_document(documents: dict, doc_id) -> str: + """Return JSON with document metadata. doc_id can be a string or a list of strings.""" + if isinstance(doc_id, list): + results = {} + for d_id in doc_id: + doc_info = documents.get(d_id) + if not doc_info: + results[d_id] = {'error': f'Document {d_id} not found'} + continue + res = { + 'doc_id': d_id, + 'doc_name': doc_info.get('doc_name', ''), + 'doc_description': doc_info.get('doc_description', ''), + 'type': doc_info.get('type', ''), + 'status': 'completed', + } + if doc_info.get('type') == 'pdf': + res['page_count'] = _count_pages(doc_info) + else: + res['line_count'] = _count_pages(doc_info) + results[d_id] = res + return json.dumps(results, ensure_ascii=False) + doc_info = documents.get(doc_id) if not doc_info: return json.dumps({'error': f'Document {doc_id} not found'}) @@ -97,8 +118,19 @@ def get_document(documents: dict, doc_id: str) -> str: return json.dumps(result) -def get_document_structure(documents: dict, doc_id: str) -> str: - """Return tree structure JSON with text fields removed (saves tokens).""" +def get_document_structure(documents: dict, doc_id) -> str: + """Return tree structure JSON with text fields removed. 
doc_id can be a string or a list of strings.""" + if isinstance(doc_id, list): + results = {} + for d_id in doc_id: + doc_info = documents.get(d_id) + if not doc_info: + results[d_id] = {'error': f'Document {d_id} not found'} + continue + structure = doc_info.get('structure', []) + results[d_id] = remove_fields(structure, fields=['text']) + return json.dumps(results, ensure_ascii=False) + doc_info = documents.get(doc_id) if not doc_info: return json.dumps({'error': f'Document {doc_id} not found'}) @@ -107,16 +139,29 @@ def get_document_structure(documents: dict, doc_id: str) -> str: return json.dumps(structure_no_text, ensure_ascii=False) -def get_page_content(documents: dict, doc_id: str, pages: str) -> str: +def get_page_content(documents: dict, doc_id, pages: str) -> str: """ - Retrieve page content for a document. - - pages format: '5-7', '3,8', or '12' - For PDF: pages are physical page numbers (1-indexed). - For Markdown: pages are line numbers corresponding to node headers. - - Returns JSON list of {'page': int, 'content': str}. + Retrieve page content for document(s). doc_id can be a string or a list of strings. + + pages format: '5-7', '3,8', or '12' + For PDF: pages are physical page numbers (1-indexed). + For Markdown: pages are line numbers corresponding to node headers. + + Returns a JSON list of {'page': int, 'content': str} for a single doc_id, or a JSON + object mapping each doc_id to such a list (or an error entry) when doc_id is a list. """ + if isinstance(doc_id, list): + results = {} + for d_id in doc_id: + doc_info = documents.get(d_id) + if not doc_info: + results[d_id] = {'error': f'Document {d_id} not found'} + continue + try: + page_nums = _parse_pages(pages) + if doc_info.get('type') == 'pdf': + content = _get_pdf_page_content(doc_info, page_nums) + else: + content = _get_md_page_content(doc_info, page_nums) + results[d_id] = content + except Exception as e: + results[d_id] = {'error': f'Failed to read page content for {d_id}: {e}'} + return json.dumps(results, ensure_ascii=False) + doc_info = documents.get(doc_id) if not doc_info: return json.dumps({'error': f'Document {doc_id} not found'})