mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-21 03:39:54 +00:00
3fbedd732e
Original commit 395e9dd9e by Teknium targeted module-level _is_mcp_tool_parallel_safe
and _should_parallelize_tool_batch helpers in pre-refactor run_agent.py. Both
helpers now live in agent/tool_dispatch_helpers.py — re-applied to that
module.
The tools/mcp_tool.py portion (the public is_mcp_tool_parallel_safe API
+ _parallel_safe_servers tracking) merged cleanly from main via the prior
merge commit.
Co-authored-by: Teknium <127238744+teknium1@users.noreply.github.com>
337 lines
11 KiB
Python
337 lines
11 KiB
Python
"""Tool-dispatch helpers — parallelism gating, multimodal envelopes, mutation tracking.
|
|
|
|
Pure module-level utilities extracted from ``run_agent.py``:
|
|
|
|
* ``_is_destructive_command`` — terminal-command heuristic used to gate
|
|
parallel batch dispatch.
|
|
* ``_should_parallelize_tool_batch`` / ``_extract_parallel_scope_path`` /
|
|
``_paths_overlap`` — the rules engine deciding when a multi-tool batch
|
|
can run concurrently.
|
|
* ``_is_multimodal_tool_result`` / ``_multimodal_text_summary`` /
|
|
``_append_subdir_hint_to_multimodal`` — envelope helpers for the
|
|
``{"_multimodal": True, "content": [...], "text_summary": ...}`` dict
|
|
shape returned by tools like ``computer_use``.
|
|
* ``_extract_file_mutation_targets`` / ``_extract_error_preview`` —
|
|
per-turn file-mutation verifier inputs.
|
|
* ``_trajectory_normalize_msg`` — strip image blobs from a message for
|
|
trajectory saving.
|
|
|
|
All helpers are stateless. ``run_agent`` re-exports each name so existing
|
|
``from run_agent import ...`` imports in tests and other modules keep
|
|
working unchanged.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from agent.tool_result_classification import (
|
|
FILE_MUTATING_TOOL_NAMES as _FILE_MUTATING_TOOLS,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Tools that must never run concurrently (interactive / user-facing).
|
|
# When any of these appear in a batch, we fall back to sequential execution.
|
|
_NEVER_PARALLEL_TOOLS = frozenset({"clarify"})
|
|
|
|
# Read-only tools with no shared mutable session state.
|
|
_PARALLEL_SAFE_TOOLS = frozenset({
|
|
"ha_get_state",
|
|
"ha_list_entities",
|
|
"ha_list_services",
|
|
"read_file",
|
|
"search_files",
|
|
"session_search",
|
|
"skill_view",
|
|
"skills_list",
|
|
"vision_analyze",
|
|
"web_extract",
|
|
"web_search",
|
|
})
|
|
|
|
# File tools can run concurrently when they target independent paths.
|
|
_PATH_SCOPED_TOOLS = frozenset({"read_file", "write_file", "patch"})
|
|
|
|
# Patterns that indicate a terminal command may modify/delete files.
|
|
_DESTRUCTIVE_PATTERNS = re.compile(
|
|
r"""(?:^|\s|&&|\|\||;|`)(?:
|
|
rm\s|rmdir\s|
|
|
cp\s|install\s|
|
|
mv\s|
|
|
sed\s+-i|
|
|
truncate\s|
|
|
dd\s|
|
|
shred\s|
|
|
git\s+(?:reset|clean|checkout)\s
|
|
)""",
|
|
re.VERBOSE,
|
|
)
|
|
# Output redirects that overwrite files (> but not >>)
|
|
_REDIRECT_OVERWRITE = re.compile(r'[^>]>[^>]|^>[^>]')
|
|
|
|
|
|
def _is_destructive_command(cmd: str) -> bool:
|
|
"""Heuristic: does this terminal command look like it modifies/deletes files?"""
|
|
if not cmd:
|
|
return False
|
|
if _DESTRUCTIVE_PATTERNS.search(cmd):
|
|
return True
|
|
if _REDIRECT_OVERWRITE.search(cmd):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _is_mcp_tool_parallel_safe(tool_name: str) -> bool:
|
|
"""Check if an MCP tool comes from a server with parallel tool calls enabled.
|
|
|
|
Lazy-imports from ``tools.mcp_tool`` to avoid circular dependencies.
|
|
Returns False if the MCP module is not available.
|
|
"""
|
|
try:
|
|
from tools.mcp_tool import is_mcp_tool_parallel_safe
|
|
return is_mcp_tool_parallel_safe(tool_name)
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def _should_parallelize_tool_batch(tool_calls) -> bool:
|
|
"""Return True when a tool-call batch is safe to run concurrently."""
|
|
if len(tool_calls) <= 1:
|
|
return False
|
|
|
|
tool_names = [tc.function.name for tc in tool_calls]
|
|
if any(name in _NEVER_PARALLEL_TOOLS for name in tool_names):
|
|
return False
|
|
|
|
reserved_paths: list[Path] = []
|
|
for tool_call in tool_calls:
|
|
tool_name = tool_call.function.name
|
|
try:
|
|
function_args = json.loads(tool_call.function.arguments)
|
|
except Exception:
|
|
logging.debug(
|
|
"Could not parse args for %s — defaulting to sequential; raw=%s",
|
|
tool_name,
|
|
tool_call.function.arguments[:200],
|
|
)
|
|
return False
|
|
if not isinstance(function_args, dict):
|
|
logging.debug(
|
|
"Non-dict args for %s (%s) — defaulting to sequential",
|
|
tool_name,
|
|
type(function_args).__name__,
|
|
)
|
|
return False
|
|
|
|
if tool_name in _PATH_SCOPED_TOOLS:
|
|
scoped_path = _extract_parallel_scope_path(tool_name, function_args)
|
|
if scoped_path is None:
|
|
return False
|
|
if any(_paths_overlap(scoped_path, existing) for existing in reserved_paths):
|
|
return False
|
|
reserved_paths.append(scoped_path)
|
|
continue
|
|
|
|
if tool_name not in _PARALLEL_SAFE_TOOLS:
|
|
# Check if it's an MCP tool from a server that opted into parallel calls.
|
|
if not _is_mcp_tool_parallel_safe(tool_name):
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def _extract_parallel_scope_path(tool_name: str, function_args: dict) -> Optional[Path]:
|
|
"""Return the normalized file target for path-scoped tools."""
|
|
if tool_name not in _PATH_SCOPED_TOOLS:
|
|
return None
|
|
|
|
raw_path = function_args.get("path")
|
|
if not isinstance(raw_path, str) or not raw_path.strip():
|
|
return None
|
|
|
|
expanded = Path(raw_path).expanduser()
|
|
if expanded.is_absolute():
|
|
return Path(os.path.abspath(str(expanded)))
|
|
|
|
# Avoid resolve(); the file may not exist yet.
|
|
return Path(os.path.abspath(str(Path.cwd() / expanded)))
|
|
|
|
|
|
def _paths_overlap(left: Path, right: Path) -> bool:
|
|
"""Return True when two paths may refer to the same subtree."""
|
|
left_parts = left.parts
|
|
right_parts = right.parts
|
|
if not left_parts or not right_parts:
|
|
# Empty paths shouldn't reach here (guarded upstream), but be safe.
|
|
return bool(left_parts) == bool(right_parts) and bool(left_parts)
|
|
common_len = min(len(left_parts), len(right_parts))
|
|
return left_parts[:common_len] == right_parts[:common_len]
|
|
|
|
|
|
def _is_multimodal_tool_result(value: Any) -> bool:
|
|
"""True if the value is a multimodal tool result envelope.
|
|
|
|
Multimodal handlers (e.g. tools/computer_use) return a dict with
|
|
`_multimodal=True`, a `content` key holding OpenAI-style content
|
|
parts, and an optional `text_summary` for string-only fallbacks.
|
|
"""
|
|
return (
|
|
isinstance(value, dict)
|
|
and value.get("_multimodal") is True
|
|
and isinstance(value.get("content"), list)
|
|
)
|
|
|
|
|
|
def _multimodal_text_summary(value: Any) -> str:
|
|
"""Extract a plain text view of a multimodal tool result.
|
|
|
|
Used wherever downstream code needs a string — logging, previews,
|
|
persistence size heuristics, fall-back content for providers that
|
|
don't support multipart tool messages.
|
|
"""
|
|
if _is_multimodal_tool_result(value):
|
|
if value.get("text_summary"):
|
|
return str(value["text_summary"])
|
|
parts = []
|
|
for p in value.get("content") or []:
|
|
if isinstance(p, dict) and p.get("type") == "text":
|
|
parts.append(str(p.get("text", "")))
|
|
if parts:
|
|
return "\n".join(parts)
|
|
return "[multimodal tool result]"
|
|
if isinstance(value, str):
|
|
return value
|
|
try:
|
|
return json.dumps(value, default=str)
|
|
except Exception:
|
|
return str(value)
|
|
|
|
|
|
def _append_subdir_hint_to_multimodal(value: Dict[str, Any], hint: str) -> None:
|
|
"""Mutate a multimodal tool-result envelope to append a subdir hint.
|
|
|
|
The hint is added to the first text part so the model sees it; image
|
|
parts are left untouched. `text_summary` is also updated for
|
|
string-fallback callers.
|
|
"""
|
|
if not _is_multimodal_tool_result(value):
|
|
return
|
|
parts = value.get("content") or []
|
|
for p in parts:
|
|
if isinstance(p, dict) and p.get("type") == "text":
|
|
p["text"] = str(p.get("text", "")) + hint
|
|
break
|
|
else:
|
|
parts.insert(0, {"type": "text", "text": hint})
|
|
value["content"] = parts
|
|
if isinstance(value.get("text_summary"), str):
|
|
value["text_summary"] = value["text_summary"] + hint
|
|
|
|
|
|
def _extract_file_mutation_targets(tool_name: str, args: Dict[str, Any]) -> List[str]:
|
|
"""Return the file paths a ``write_file`` or ``patch`` call is targeting.
|
|
|
|
For ``write_file`` and ``patch`` in replace mode this is just ``args["path"]``.
|
|
For ``patch`` in V4A patch mode we parse the patch content for
|
|
``*** Update File:`` / ``*** Add File:`` / ``*** Delete File:`` headers so
|
|
the verifier can track each file in a multi-file patch separately.
|
|
"""
|
|
if tool_name not in _FILE_MUTATING_TOOLS:
|
|
return []
|
|
if tool_name == "write_file":
|
|
p = args.get("path")
|
|
return [str(p)] if p else []
|
|
# tool_name == "patch"
|
|
mode = args.get("mode") or "replace"
|
|
if mode == "replace":
|
|
p = args.get("path")
|
|
return [str(p)] if p else []
|
|
if mode == "patch":
|
|
body = args.get("patch") or ""
|
|
if not isinstance(body, str) or not body:
|
|
return []
|
|
paths: List[str] = []
|
|
for _m in re.finditer(
|
|
r'^\*\*\*\s+(?:Update|Add|Delete)\s+File:\s*(.+)$',
|
|
body,
|
|
re.MULTILINE,
|
|
):
|
|
p = _m.group(1).strip()
|
|
if p:
|
|
paths.append(p)
|
|
return paths
|
|
return []
|
|
|
|
|
|
def _extract_error_preview(result: Any, max_len: int = 180) -> str:
|
|
"""Pull a one-line error summary out of a tool result for footer display."""
|
|
text = _multimodal_text_summary(result) if result is not None else ""
|
|
if not isinstance(text, str):
|
|
try:
|
|
text = str(text)
|
|
except Exception:
|
|
return ""
|
|
# Try to parse JSON and pull the ``error`` field — tool handlers return
|
|
# ``{"success": false, "error": "..."}``; raw string wins if parse fails.
|
|
stripped = text.strip()
|
|
if stripped.startswith("{"):
|
|
try:
|
|
data = json.loads(stripped)
|
|
if isinstance(data, dict) and isinstance(data.get("error"), str):
|
|
text = data["error"]
|
|
except Exception:
|
|
pass
|
|
# Collapse whitespace, trim to max_len.
|
|
text = " ".join(text.split())
|
|
if len(text) > max_len:
|
|
text = text[: max_len - 1] + "…"
|
|
return text
|
|
|
|
|
|
def _trajectory_normalize_msg(msg: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Strip image blobs from a message for trajectory saving.
|
|
|
|
Returns a shallow copy with multimodal tool results replaced by their
|
|
text_summary, and image parts in content lists replaced by
|
|
`[screenshot]` placeholders. Keeps the message schema otherwise intact.
|
|
"""
|
|
if not isinstance(msg, dict):
|
|
return msg
|
|
content = msg.get("content")
|
|
if _is_multimodal_tool_result(content):
|
|
return {**msg, "content": _multimodal_text_summary(content)}
|
|
if isinstance(content, list):
|
|
cleaned = []
|
|
for p in content:
|
|
if isinstance(p, dict) and p.get("type") in {"image", "image_url", "input_image"}:
|
|
cleaned.append({"type": "text", "text": "[screenshot]"})
|
|
else:
|
|
cleaned.append(p)
|
|
return {**msg, "content": cleaned}
|
|
return msg
|
|
|
|
|
|
__all__ = [
|
|
"_NEVER_PARALLEL_TOOLS",
|
|
"_PARALLEL_SAFE_TOOLS",
|
|
"_PATH_SCOPED_TOOLS",
|
|
"_DESTRUCTIVE_PATTERNS",
|
|
"_REDIRECT_OVERWRITE",
|
|
"_is_destructive_command",
|
|
"_should_parallelize_tool_batch",
|
|
"_extract_parallel_scope_path",
|
|
"_paths_overlap",
|
|
"_is_multimodal_tool_result",
|
|
"_multimodal_text_summary",
|
|
"_append_subdir_hint_to_multimodal",
|
|
"_extract_file_mutation_targets",
|
|
"_extract_error_preview",
|
|
"_trajectory_normalize_msg",
|
|
]
|