from __future__ import annotations

import json
import sys
from collections import deque
from dataclasses import dataclass, field
from threading import Event
from typing import Callable, Deque

try:
    from litellm import completion
except ImportError:  # pragma: no cover
    completion = None  # type: ignore[assignment]

from agents import Agent

# Callback contracts supplied by the caller of IntelligentAgent.run():
# a ToolInvoker runs the named tool and returns True on success, while a
# CommandExecutor forwards a raw command string to the MUD connection.
ToolInvoker = Callable[[str], bool]
CommandExecutor = Callable[[str], None]


@dataclass
class IntelligentAgent(Agent):
    """LLM-driven agent that decides between tools and raw commands."""

    model: str = "mistral/mistral-large-2407"
    system_prompt: str = (
        "You are Mistle, a helpful MUD assistant. "
        "You can either call tools or send plain commands to the MUD."
    )
    temperature: float = 0.7
    max_output_tokens: int = 200
    instruction: str = ""  # optional task-specific system prompt
    turn_delay: float = 0.0  # seconds to pause between decision cycles
    allowed_tools: dict[str, str] = field(default_factory=dict)  # tool name -> description
    history: Deque[dict[str, str]] = field(default_factory=deque, init=False)  # rolling transcript

    def observe(self, message: str) -> None:
        content = message.strip()
        if not content:
            return
        self.history.append({"role": "user", "content": content})
        self._trim_history()

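    # One decision cycle in run(): build the prompt (system prompt, optional
    # instruction, conversation history, tool list, JSON schema), call the
    # model, then dispatch its reply. The model must answer with a single JSON
    # object, for example (illustrative values):
    #   {"type": "command", "value": "look", "notes": "survey the room"}
    #   {"type": "tool", "value": "read_map"}
    #   {"type": "end", "value": "done"}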
    def run(
        self,
        *,
        invoke_tool: ToolInvoker,
        send_command: CommandExecutor | None,
        stop_event: Event,
    ) -> None:
        if send_command is None:
            raise RuntimeError("IntelligentAgent requires send_command support")
        if completion is None:
            print("[Agent] litellm not available; intelligent agent disabled", file=sys.stderr)
            return

messages = [{"role": "system", "content": self.system_prompt}]
|
|
if self.instruction:
|
|
messages.append({"role": "system", "content": self.instruction})
|
|
messages.extend(self.history)
|
|
if self.allowed_tools:
|
|
tool_list = "\n".join(
|
|
f"- {name}: {desc}" for name, desc in self.allowed_tools.items()
|
|
)
|
|
messages.append(
|
|
{
|
|
"role": "system",
|
|
"content": (
|
|
"Available tools (only use these names when type=tool):\n"
|
|
f"{tool_list}"
|
|
),
|
|
}
|
|
)
|
|
messages.append(
|
|
{
|
|
"role": "system",
|
|
"content": (
|
|
"Respond with JSON only. Schema: {\n"
|
|
" \"type\": \"tool\" or \"command\",\n"
|
|
" \"value\": string (tool name or raw command),\n"
|
|
" \"notes\": optional string explanation\n}"""
|
|
),
|
|
}
|
|
)
|
|
if not self.history or self.history[-1]["role"] != "assistant":
|
|
messages.append(
|
|
{
|
|
"role": "assistant",
|
|
"content": "I will decide the next action now.",
|
|
}
|
|
)
|
|
messages.append(
|
|
{
|
|
"role": "user",
|
|
"content": "What is the next action you will take?",
|
|
}
|
|
)
|
|
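        # Decision loop: one LLM call per cycle until the model ends the
        # session, an error occurs, or stop_event is set from outside.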
        cycle = 0
        while not stop_event.is_set():
            cycle += 1
            print(f"[Agent] LLM cycle {cycle}...")
            try:
                response = completion(
                    model=self.model,
                    messages=messages,
                    temperature=self.temperature,
                    max_tokens=self.max_output_tokens,
                )
            except Exception as exc:  # pragma: no cover
                print(f"[Agent] LLM call failed: {exc}", file=sys.stderr)
                return

            # Pre-seed content so the error handler below can still print it
            # even if extracting the message text from the response fails.
            content = ""
            try:
                content = response["choices"][0]["message"]["content"].strip()
                print(f"[Agent] LLM raw output: {content}")
                json_payload = self._extract_json(content)
                payload = json.loads(json_payload)
            except (KeyError, IndexError, TypeError, json.JSONDecodeError) as exc:
                print(f"[Agent] Invalid LLM response: {exc}", file=sys.stderr)
                print(f"[Agent] Raw content: {content}", file=sys.stderr)
                return

            action_type = payload.get("type")
            value = (payload.get("value") or "").strip()
            notes = payload.get("notes")
            if notes:
                self.history.append({"role": "assistant", "content": f"NOTE: {notes}"})
                self._trim_history()
            if not value:
                print("[Agent] LLM returned empty action", file=sys.stderr)
                return

            # Map lowercased tool names back to their canonical spelling and description.
            allowed_map = {
                name.lower(): (name, desc)
                for name, desc in self.allowed_tools.items()
            }

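            # Dispatch the model's decision: run a whitelisted tool, forward a
            # raw command to the MUD, or end the session.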
            if action_type == "tool":
                lower = value.lower()
                if allowed_map and lower not in allowed_map:
                    print(
                        f"[Agent] Tool '{value}' not in allowed list {list(self.allowed_tools)}",
                        file=sys.stderr,
                    )
                    return
                canonical, _ = allowed_map.get(lower, (value, ""))
                success = invoke_tool(canonical)
                print(f"[Agent] Executed tool: {canonical} (success={success})")
                self.history.append({"role": "assistant", "content": f"TOOL {canonical}"})
                self._trim_history()
                if not success:
                    return
            elif action_type == "command":
                send_command(value)
                print(f"[Agent] Sent command: {value}")
                self.history.append({"role": "assistant", "content": f"COMMAND {value}"})
                self._trim_history()
            elif action_type == "end":
                print("[Agent] LLM requested to end the session.")
                self.history.append({"role": "assistant", "content": "END"})
                self._trim_history()
                stop_event.set()
                break
            else:
                print(f"[Agent] Unknown action type '{action_type}'", file=sys.stderr)
                return

            if self.turn_delay > 0 and stop_event.wait(self.turn_delay):
                break

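            # Rebuild the prompt for the next cycle so it reflects the updated
            # history; this mirrors the block above except that the schema now
            # also advertises the "end" action.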
messages = [{"role": "system", "content": self.system_prompt}]
|
|
if self.instruction:
|
|
messages.append({"role": "system", "content": self.instruction})
|
|
messages.extend(self.history)
|
|
if self.allowed_tools:
|
|
tool_list = "\n".join(
|
|
f"- {name}: {desc}" for name, desc in self.allowed_tools.items()
|
|
)
|
|
messages.append(
|
|
{
|
|
"role": "system",
|
|
"content": (
|
|
"Available tools (only use these names when type=tool):\n"
|
|
f"{tool_list}"
|
|
),
|
|
}
|
|
)
|
|
messages.append(
|
|
{
|
|
"role": "system",
|
|
"content": (
|
|
"Respond with JSON only. Schema: {\n"
|
|
" \"type\": \"tool\" or \"command\" or \"end\",\n"
|
|
" \"value\": string,\n"
|
|
" \"notes\": optional string\n}"""
|
|
),
|
|
}
|
|
)
|
|
if not self.history or self.history[-1]["role"] != "assistant":
|
|
messages.append(
|
|
{
|
|
"role": "assistant",
|
|
"content": "I will decide the next action now.",
|
|
}
|
|
)
|
|
messages.append(
|
|
{
|
|
"role": "user",
|
|
"content": "What is the next action you will take?",
|
|
}
|
|
)
|
|
|
|
print("[Agent] Intelligent agent finished.")
|
|
|
|
    def _trim_history(self) -> None:
        while len(self.history) > 50:
            self.history.popleft()

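    # Models often wrap JSON in a Markdown code fence; _extract_json strips the
    # fence and returns the outermost {...} span, for example:
    #   '```json\n{"type": "command", "value": "look"}\n```'
    #   -> '{"type": "command", "value": "look"}'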
    def _extract_json(self, content: str) -> str:
        text = content.strip()
        if text.startswith("```"):
            text = text.strip("` ")
            if "\n" in text:
                text = text.split("\n", 1)[1]
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end == -1 or end < start:
            raise json.JSONDecodeError("No JSON object found", text, 0)
        return text[start : end + 1]
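

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch. The stub callbacks below only print what they
    # receive; a real client would wire invoke_tool and send_command to an
    # actual MUD connection. Assumes the Agent base class requires no
    # constructor arguments beyond the fields declared above.
    agent = IntelligentAgent(
        instruction="Explore the starting room and report what you see.",
        allowed_tools={"look": "Describe the current room."},
    )
    agent.observe("You are standing in a quiet courtyard.")
    agent.run(
        invoke_tool=lambda name: print(f"[stub] tool -> {name}") or True,
        send_command=lambda cmd: print(f"[stub] command -> {cmd}"),
        stop_event=Event(),
    )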