fix(mistral-parser): handle nested JSON in fallback extraction (#2335)
fix(mistral-parser): handle nested JSON in fallback extraction
This commit is contained in:
commit
fff7203049
2 changed files with 82 additions and 14 deletions
|
|
@ -10,7 +10,6 @@ The [TOOL_CALLS] token is the bot_token used by Mistral models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
@ -42,9 +41,6 @@ class MistralToolCallParser(ToolCallParser):
|
||||||
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
||||||
BOT_TOKEN = "[TOOL_CALLS]"
|
BOT_TOKEN = "[TOOL_CALLS]"
|
||||||
|
|
||||||
# Fallback regex for pre-v11 format when JSON parsing fails
|
|
||||||
TOOL_CALL_REGEX = re.compile(r"\[?\s*(\{.*?\})\s*\]?", re.DOTALL)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
def parse(self, text: str) -> ParseResult:
|
||||||
if self.BOT_TOKEN not in text:
|
if self.BOT_TOKEN not in text:
|
||||||
return text, None
|
return text, None
|
||||||
|
|
@ -71,6 +67,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||||
tool_name = raw[:brace_idx].strip()
|
tool_name = raw[:brace_idx].strip()
|
||||||
args_str = raw[brace_idx:]
|
args_str = raw[brace_idx:]
|
||||||
|
|
||||||
|
# Validate and clean the JSON arguments
|
||||||
|
try:
|
||||||
|
parsed_args = json.loads(args_str)
|
||||||
|
args_str = json.dumps(parsed_args, ensure_ascii=False)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass # Keep raw if parsing fails
|
||||||
|
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ChatCompletionMessageToolCall(
|
ChatCompletionMessageToolCall(
|
||||||
id=_generate_mistral_id(),
|
id=_generate_mistral_id(),
|
||||||
|
|
@ -100,13 +103,14 @@ class MistralToolCallParser(ToolCallParser):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# Fallback regex extraction
|
# Fallback: extract JSON objects using raw_decode
|
||||||
match = self.TOOL_CALL_REGEX.findall(first_raw)
|
decoder = json.JSONDecoder()
|
||||||
if match:
|
idx = 0
|
||||||
for raw_json in match:
|
while idx < len(first_raw):
|
||||||
try:
|
try:
|
||||||
tc = json.loads(raw_json)
|
obj, end_idx = decoder.raw_decode(first_raw, idx)
|
||||||
args = tc.get("arguments", {})
|
if isinstance(obj, dict) and "name" in obj:
|
||||||
|
args = obj.get("arguments", {})
|
||||||
if isinstance(args, dict):
|
if isinstance(args, dict):
|
||||||
args = json.dumps(args, ensure_ascii=False)
|
args = json.dumps(args, ensure_ascii=False)
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
|
|
@ -114,12 +118,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||||
id=_generate_mistral_id(),
|
id=_generate_mistral_id(),
|
||||||
type="function",
|
type="function",
|
||||||
function=Function(
|
function=Function(
|
||||||
name=tc["name"], arguments=args
|
name=obj["name"], arguments=args
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (json.JSONDecodeError, KeyError):
|
idx = end_idx
|
||||||
continue
|
except json.JSONDecodeError:
|
||||||
|
idx += 1
|
||||||
|
|
||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
return text, None
|
return text, None
|
||||||
|
|
|
||||||
|
|
@ -209,3 +209,66 @@ class TestDeepSeekV3Parser:
|
||||||
content, tool_calls = parser.parse(text)
|
content, tool_calls = parser.parse(text)
|
||||||
assert tool_calls is not None
|
assert tool_calls is not None
|
||||||
assert len(tool_calls) == 1
|
assert len(tool_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Mistral parser tests ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestMistralParser:
|
||||||
|
@pytest.fixture
|
||||||
|
def parser(self):
|
||||||
|
return get_parser("mistral")
|
||||||
|
|
||||||
|
def test_no_tool_call(self, parser):
|
||||||
|
text = "Hello, how can I help you?"
|
||||||
|
content, tool_calls = parser.parse(text)
|
||||||
|
assert content == text
|
||||||
|
assert tool_calls is None
|
||||||
|
|
||||||
|
def test_pre_v11_single_tool_call(self, parser):
|
||||||
|
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"key": "val"}}]'
|
||||||
|
content, tool_calls = parser.parse(text)
|
||||||
|
assert tool_calls is not None
|
||||||
|
assert len(tool_calls) == 1
|
||||||
|
assert tool_calls[0].function.name == "func"
|
||||||
|
args = json.loads(tool_calls[0].function.arguments)
|
||||||
|
assert args["key"] == "val"
|
||||||
|
|
||||||
|
def test_pre_v11_nested_json(self, parser):
|
||||||
|
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"nested": {"deep": true}}}]'
|
||||||
|
content, tool_calls = parser.parse(text)
|
||||||
|
assert tool_calls is not None
|
||||||
|
assert len(tool_calls) == 1
|
||||||
|
assert tool_calls[0].function.name == "func"
|
||||||
|
args = json.loads(tool_calls[0].function.arguments)
|
||||||
|
assert args["nested"]["deep"] is True
|
||||||
|
|
||||||
|
def test_v11_single_tool_call(self, parser):
|
||||||
|
text = '[TOOL_CALLS]get_weather{"city": "London"}'
|
||||||
|
content, tool_calls = parser.parse(text)
|
||||||
|
assert tool_calls is not None
|
||||||
|
assert len(tool_calls) == 1
|
||||||
|
assert tool_calls[0].function.name == "get_weather"
|
||||||
|
args = json.loads(tool_calls[0].function.arguments)
|
||||||
|
assert args["city"] == "London"
|
||||||
|
|
||||||
|
def test_v11_multiple_tool_calls(self, parser):
|
||||||
|
text = '[TOOL_CALLS]func1{"a": 1}[TOOL_CALLS]func2{"b": 2}'
|
||||||
|
content, tool_calls = parser.parse(text)
|
||||||
|
assert tool_calls is not None
|
||||||
|
assert len(tool_calls) == 2
|
||||||
|
names = [tc.function.name for tc in tool_calls]
|
||||||
|
assert "func1" in names
|
||||||
|
assert "func2" in names
|
||||||
|
|
||||||
|
def test_preceding_text_preserved(self, parser):
|
||||||
|
text = 'Hello[TOOL_CALLS]func{"a": 1}'
|
||||||
|
content, tool_calls = parser.parse(text)
|
||||||
|
assert content == "Hello"
|
||||||
|
assert tool_calls is not None
|
||||||
|
assert len(tool_calls) == 1
|
||||||
|
assert tool_calls[0].function.name == "func"
|
||||||
|
|
||||||
|
def test_malformed_json_fallback(self, parser):
|
||||||
|
text = "[TOOL_CALLS] not valid json"
|
||||||
|
content, tool_calls = parser.parse(text)
|
||||||
|
assert tool_calls is None
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue