diff --git a/adalflow/adalflow/optim/text_grad/tgd_optimizer.py b/adalflow/adalflow/optim/text_grad/tgd_optimizer.py
index b052b739e..a77749da6 100644
--- a/adalflow/adalflow/optim/text_grad/tgd_optimizer.py
+++ b/adalflow/adalflow/optim/text_grad/tgd_optimizer.py
@@ -285,8 +285,62 @@ def get_output_format_str(self) -> str:
Make sure to include all three fields and properly close all XML tags."""
+ def _sanitize_xml(self, xml_str: str) -> str:
+ """Sanitize XML string by removing invalid characters and fixing common issues."""
+ # Remove control characters except for tab, newline, and carriage return
+ xml_str = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', xml_str)
+
+ # Handle CDATA sections - extract content from CDATA
+ xml_str = re.sub(r'<!\[CDATA\[(.*?)\]\]>', lambda m: m.group(1), xml_str, flags=re.DOTALL)
+
+ return xml_str
+
+ def _extract_with_regex(self, input: str) -> TGDData:
+ """Fallback extraction using regex when XML parsing fails."""
+ log.info("Using regex fallback for XML parsing")
+
+ # Sanitize input to handle CDATA in regex path too
+ sanitized = self._sanitize_xml(input)
+
+ # Try to extract content between tags using regex
+ reasoning_pattern = r'<reasoning>(.*?)</reasoning>'
+ method_pattern = r'<method>(.*?)</method>'
+ proposed_variable_pattern = r'<proposed_variable>(.*?)</proposed_variable>'
+
+ reasoning_match = re.search(reasoning_pattern, sanitized, re.DOTALL | re.IGNORECASE)
+ method_match = re.search(method_pattern, sanitized, re.DOTALL | re.IGNORECASE)
+ proposed_variable_match = re.search(proposed_variable_pattern, sanitized, re.DOTALL | re.IGNORECASE)
+
+ reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
+ method = method_match.group(1).strip() if method_match else ""
+ proposed_variable = proposed_variable_match.group(1).strip() if proposed_variable_match else ""
+
+ # If we didn't find proposed_variable, try to extract anything that looks like content
+ if not proposed_variable and proposed_variable_match is None:
+ # Look for the last substantial text block as a fallback
+ lines = [line.strip() for line in sanitized.split('\n') if line.strip() and not line.strip().startswith('<')]
+ if lines:
+ proposed_variable = lines[-1]
+
+ log.info(f"Regex extraction - reasoning: {bool(reasoning)}, method: {bool(method)}, proposed_variable: {bool(proposed_variable)}")
+
+ return TGDData(
+ reasoning=reasoning if reasoning else "Extracted via regex fallback",
+ method=method if method else "regex extraction",
+ proposed_variable=proposed_variable if proposed_variable else input
+ )
+
def call(self, input: str) -> TGDData:
- """Parse the XML response and extract the three fields, returning TGDData directly."""
+ """Parse the XML response and extract the three fields, returning TGDData directly.
+
+ This method implements robust XML parsing with the following features:
+ 1. XML sanitization to remove invalid characters
+ 2. CDATA section handling
+ 3. Regex-based fallback for malformed XML
+ 4. Graceful error recovery
+
+ Fixes issue #455: Make XML parsing more robust with regex fallback
+ """
try:
# Clean the input and extract XML content
input = input.strip()
@@ -304,6 +358,9 @@ def call(self, input: str) -> TGDData:
else:
xml_content = input[start_idx:end_idx + len(end_tag)]
+ # Sanitize XML before parsing
+ xml_content = self._sanitize_xml(xml_content)
+
# Parse XML
root = ET.fromstring(xml_content)
@@ -312,9 +369,20 @@ def call(self, input: str) -> TGDData:
method_elem = root.find('method')
proposed_variable_elem = root.find('proposed_variable')
- reasoning = reasoning_elem.text.strip() if reasoning_elem is not None and reasoning_elem.text else ""
- method = method_elem.text.strip() if method_elem is not None and method_elem.text else ""
- proposed_variable = proposed_variable_elem.text.strip() if proposed_variable_elem is not None and proposed_variable_elem.text else ""
+ # Handle text content including nested elements and CDATA
+ def get_element_text(elem):
+ if elem is None:
+ return ""
+ # Get all text content including from nested elements
+ text_parts = [elem.text or ""]
+ for child in elem:
+ text_parts.append(get_element_text(child))
+ text_parts.append(child.tail or "")
+ return "".join(text_parts).strip()
+
+ reasoning = get_element_text(reasoning_elem)
+ method = get_element_text(method_elem)
+ proposed_variable = get_element_text(proposed_variable_elem)
# Create and return TGDData object directly
return TGDData(
@@ -324,19 +392,14 @@ def call(self, input: str) -> TGDData:
)
except ET.ParseError as e:
- log.error(f"XML parsing error: {e}")
- return TGDData(
- reasoning="XML parsing failed",
- method="Error",
- proposed_variable=input
- )
+ log.warning(f"XML parsing error: {e}, attempting regex fallback")
+ # Use regex-based extraction as fallback
+ return self._extract_with_regex(input)
+
except Exception as e:
- log.error(f"Error parsing XML output: {e}")
- return TGDData(
- reasoning="Parsing failed",
- method="Error",
- proposed_variable=input
- )
+ log.error(f"Error parsing XML output: {e}, attempting regex fallback")
+ # Use regex-based extraction as fallback
+ return self._extract_with_regex(input)
new_variable_tags = ["", ""]