Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 79 additions & 16 deletions adalflow/adalflow/optim/text_grad/tgd_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,62 @@ def get_output_format_str(self) -> str:

Make sure to include all three fields and properly close all XML tags."""

def _sanitize_xml(self, xml_str: str) -> str:
"""Sanitize XML string by removing invalid characters and fixing common issues."""
# Remove control characters except for tab, newline, and carriage return
xml_str = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', xml_str)

# Handle CDATA sections - extract content from CDATA
xml_str = re.sub(r'<!\[CDATA\[(.*?)\]\]>', lambda m: m.group(1), xml_str, flags=re.DOTALL)

return xml_str

def _extract_with_regex(self, input: str) -> TGDData:
    """Fallback extraction using regex when XML parsing fails.

    The input is sanitized with ``_sanitize_xml`` and the first
    ``<reasoning>``, ``<method>`` and ``<proposed_variable>`` tag pairs are
    searched for (case-insensitively, content may span lines).

    If no ``<proposed_variable>`` pair is present, the last non-empty line
    that does not start with ``<`` is used instead. Complete
    ``<reasoning>`` / ``<method>`` blocks are removed before that scan so
    text already attributed to those fields is never returned as the
    proposed variable.

    Args:
        input: Raw response text that could not be parsed as XML.

    Returns:
        A ``TGDData`` whose fields hold the extracted text. Placeholder
        strings are used for a missing reasoning or method, and the
        original ``input`` is used when no proposed variable is found.
    """
    log.info("Using regex fallback for XML parsing")

    # Sanitize input to handle CDATA in regex path too
    sanitized = self._sanitize_xml(input)

    # Try to extract content between tags using regex
    reasoning_pattern = r'<reasoning>(.*?)</reasoning>'
    method_pattern = r'<method>(.*?)</method>'
    proposed_variable_pattern = r'<proposed_variable>(.*?)</proposed_variable>'
    flags = re.DOTALL | re.IGNORECASE

    reasoning_match = re.search(reasoning_pattern, sanitized, flags)
    method_match = re.search(method_pattern, sanitized, flags)
    proposed_variable_match = re.search(proposed_variable_pattern, sanitized, flags)

    reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
    method = method_match.group(1).strip() if method_match else ""
    proposed_variable = proposed_variable_match.group(1).strip() if proposed_variable_match else ""

    # No <proposed_variable> pair at all: fall back to the last plain-text line.
    if proposed_variable_match is None:
        # Drop the blocks that belong to the other fields first, otherwise a
        # line of reasoning/method text could be picked as the new variable.
        # A newline is substituted so the text around a removed block stays
        # on separate lines.
        leftover = sanitized
        for pattern in (reasoning_pattern, method_pattern):
            leftover = re.sub(pattern, '\n', leftover, flags=flags)

        stripped_lines = (line.strip() for line in leftover.split('\n'))
        candidates = [line for line in stripped_lines if line and not line.startswith('<')]
        if candidates:
            proposed_variable = candidates[-1]

    log.info(f"Regex extraction - reasoning: {bool(reasoning)}, method: {bool(method)}, proposed_variable: {bool(proposed_variable)}")

    return TGDData(
        reasoning=reasoning if reasoning else "Extracted via regex fallback",
        method=method if method else "regex extraction",
        proposed_variable=proposed_variable if proposed_variable else input
    )

def call(self, input: str) -> TGDData:
"""Parse the XML response and extract the three fields, returning TGDData directly."""
"""Parse the XML response and extract the three fields, returning TGDData directly.

This method implements robust XML parsing with the following features:
1. XML sanitization to remove invalid characters
2. CDATA section handling
3. Regex-based fallback for malformed XML
4. Graceful error recovery

Fixes issue #455: Make XML parsing more robust with regex fallback
"""
try:
# Clean the input and extract XML content
input = input.strip()
Expand All @@ -304,6 +358,9 @@ def call(self, input: str) -> TGDData:
else:
xml_content = input[start_idx:end_idx + len(end_tag)]

# Sanitize XML before parsing
xml_content = self._sanitize_xml(xml_content)

# Parse XML
root = ET.fromstring(xml_content)

Expand All @@ -312,9 +369,20 @@ def call(self, input: str) -> TGDData:
method_elem = root.find('method')
proposed_variable_elem = root.find('proposed_variable')

reasoning = reasoning_elem.text.strip() if reasoning_elem is not None and reasoning_elem.text else ""
method = method_elem.text.strip() if method_elem is not None and method_elem.text else ""
proposed_variable = proposed_variable_elem.text.strip() if proposed_variable_elem is not None and proposed_variable_elem.text else ""
# Handle text content including nested elements and CDATA
def get_element_text(elem):
if elem is None:
return ""
# Get all text content including from nested elements
text_parts = [elem.text or ""]
for child in elem:
text_parts.append(get_element_text(child))
text_parts.append(child.tail or "")
return "".join(text_parts).strip()

reasoning = get_element_text(reasoning_elem)
method = get_element_text(method_elem)
proposed_variable = get_element_text(proposed_variable_elem)

# Create and return TGDData object directly
return TGDData(
Expand All @@ -324,19 +392,14 @@ def call(self, input: str) -> TGDData:
)

except ET.ParseError as e:
log.error(f"XML parsing error: {e}")
return TGDData(
reasoning="XML parsing failed",
method="Error",
proposed_variable=input
)
log.warning(f"XML parsing error: {e}, attempting regex fallback")
# Use regex-based extraction as fallback
return self._extract_with_regex(input)

except Exception as e:
log.error(f"Error parsing XML output: {e}")
return TGDData(
reasoning="Parsing failed",
method="Error",
proposed_variable=input
)
log.error(f"Error parsing XML output: {e}, attempting regex fallback")
# Use regex-based extraction as fallback
return self._extract_with_regex(input)


new_variable_tags = ["<VARIABLE>", "</VARIABLE>"]
Expand Down
Loading