diff --git a/adalflow/adalflow/optim/text_grad/tgd_optimizer.py b/adalflow/adalflow/optim/text_grad/tgd_optimizer.py index b052b739e..a77749da6 100644 --- a/adalflow/adalflow/optim/text_grad/tgd_optimizer.py +++ b/adalflow/adalflow/optim/text_grad/tgd_optimizer.py @@ -285,8 +285,62 @@ def get_output_format_str(self) -> str: Make sure to include all three fields and properly close all XML tags.""" + def _sanitize_xml(self, xml_str: str) -> str: + """Sanitize XML string by removing invalid characters and fixing common issues.""" + # Remove control characters except for tab, newline, and carriage return + xml_str = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', xml_str) + + # Handle CDATA sections - extract content from CDATA + xml_str = re.sub(r'', lambda m: m.group(1), xml_str, flags=re.DOTALL) + + return xml_str + + def _extract_with_regex(self, input: str) -> TGDData: + """Fallback extraction using regex when XML parsing fails.""" + log.info("Using regex fallback for XML parsing") + + # Sanitize input to handle CDATA in regex path too + sanitized = self._sanitize_xml(input) + + # Try to extract content between tags using regex + reasoning_pattern = r'(.*?)' + method_pattern = r'(.*?)' + proposed_variable_pattern = r'(.*?)' + + reasoning_match = re.search(reasoning_pattern, sanitized, re.DOTALL | re.IGNORECASE) + method_match = re.search(method_pattern, sanitized, re.DOTALL | re.IGNORECASE) + proposed_variable_match = re.search(proposed_variable_pattern, sanitized, re.DOTALL | re.IGNORECASE) + + reasoning = reasoning_match.group(1).strip() if reasoning_match else "" + method = method_match.group(1).strip() if method_match else "" + proposed_variable = proposed_variable_match.group(1).strip() if proposed_variable_match else "" + + # If we didn't find proposed_variable, try to extract anything that looks like content + if not proposed_variable and proposed_variable_match is None: + # Look for the last substantial text block as a fallback + lines = [line.strip() for line in sanitized.split('\n') if line.strip() and not line.strip().startswith('<')] + if lines: + proposed_variable = lines[-1] + + log.info(f"Regex extraction - reasoning: {bool(reasoning)}, method: {bool(method)}, proposed_variable: {bool(proposed_variable)}") + + return TGDData( + reasoning=reasoning if reasoning else "Extracted via regex fallback", + method=method if method else "regex extraction", + proposed_variable=proposed_variable if proposed_variable else input + ) + def call(self, input: str) -> TGDData: - """Parse the XML response and extract the three fields, returning TGDData directly.""" + """Parse the XML response and extract the three fields, returning TGDData directly. + + This method implements robust XML parsing with the following features: + 1. XML sanitization to remove invalid characters + 2. CDATA section handling + 3. Regex-based fallback for malformed XML + 4. Graceful error recovery + + Fixes issue #455: Make XML parsing more robust with regex fallback + """ try: # Clean the input and extract XML content input = input.strip() @@ -304,6 +358,9 @@ def call(self, input: str) -> TGDData: else: xml_content = input[start_idx:end_idx + len(end_tag)] + # Sanitize XML before parsing + xml_content = self._sanitize_xml(xml_content) + # Parse XML root = ET.fromstring(xml_content) @@ -312,9 +369,20 @@ def call(self, input: str) -> TGDData: method_elem = root.find('method') proposed_variable_elem = root.find('proposed_variable') - reasoning = reasoning_elem.text.strip() if reasoning_elem is not None and reasoning_elem.text else "" - method = method_elem.text.strip() if method_elem is not None and method_elem.text else "" - proposed_variable = proposed_variable_elem.text.strip() if proposed_variable_elem is not None and proposed_variable_elem.text else "" + # Handle text content including nested elements and CDATA + def get_element_text(elem): + if elem is None: + return "" + # Get all text content including from nested elements + text_parts = [elem.text or ""] + for child in elem: + text_parts.append(get_element_text(child)) + text_parts.append(child.tail or "") + return "".join(text_parts).strip() + + reasoning = get_element_text(reasoning_elem) + method = get_element_text(method_elem) + proposed_variable = get_element_text(proposed_variable_elem) # Create and return TGDData object directly return TGDData( @@ -324,19 +392,14 @@ def call(self, input: str) -> TGDData: ) except ET.ParseError as e: - log.error(f"XML parsing error: {e}") - return TGDData( - reasoning="XML parsing failed", - method="Error", - proposed_variable=input - ) + log.warning(f"XML parsing error: {e}, attempting regex fallback") + # Use regex-based extraction as fallback + return self._extract_with_regex(input) + except Exception as e: - log.error(f"Error parsing XML output: {e}") - return TGDData( - reasoning="Parsing failed", - method="Error", - proposed_variable=input - ) + log.error(f"Error parsing XML output: {e}, attempting regex fallback") + # Use regex-based extraction as fallback + return self._extract_with_regex(input) new_variable_tags = ["", ""]