mistle/mud_session.py

329 lines
9.6 KiB
Python

from __future__ import annotations
import select
import sys
import time
import unicodedata
from collections import deque
from threading import Event, Lock
from typing import Callable, Deque, List, Optional
from telnetclient import TelnetClient
from tools import Tool
# German umlauts and sharp s mapped to their conventional two-letter ASCII
# spellings; applied before the generic accent stripping in sanitize_for_mud.
_UMLAUT_TRANSLATION = str.maketrans(
    {
        "ä": "ae",
        "ö": "oe",
        "ü": "ue",
        "Ä": "Ae",
        "Ö": "Oe",
        "Ü": "Ue",
        "ß": "ss",
    }
)


def sanitize_for_mud(text: str) -> str:
    """Return ASCII-only text suitable for Silberland's input parser.

    German umlauts and ``ß`` become their two-letter spellings (``ä`` ->
    ``ae``). Other accented letters lose their accents, and compatibility
    characters are decomposed (``ﬁ`` -> ``fi``). Any character still outside
    printable ASCII (0x20-0x7E, so also tabs and newlines) is replaced by
    ``?``.

    Args:
        text: Raw user or tool input. Falsy values yield ``""``.

    Returns:
        A string made only of printable ASCII characters.
    """
    if not text:
        return ""
    # Compose first so a decomposed umlaut ("a" + U+0308, as some input
    # sources emit) matches the table and becomes "ae". Otherwise its
    # diaeresis would be stripped below, leaving a bare "a".
    composed = unicodedata.normalize("NFC", text)
    replaced = composed.translate(_UMLAUT_TRANSLATION)
    # NFKD splits the remaining accented letters into base + combining mark
    # and expands compatibility forms; the combining marks are then dropped.
    normalized = unicodedata.normalize("NFKD", replaced)
    return "".join(
        ch if 32 <= ord(ch) <= 126 else "?"
        for ch in normalized
        if not unicodedata.combining(ch)
    )
class SessionState:
    """Share Telnet session state safely across threads.

    Outgoing sends are serialized behind one lock. The most recent chunk of
    server output is remembered, and each new chunk is fanned out to every
    registered listener (each listener owns its own queue and wake-up event).
    """

    class _OutputListener:
        """Per-consumer queue of output chunks plus a wake-up event."""

        def __init__(self) -> None:
            self._queue: Deque[str] = deque()
            self._lock = Lock()
            self._event = Event()

        def publish(self, text: str) -> None:
            """Enqueue *text* and wake any waiter."""
            with self._lock:
                self._queue.append(text)
                self._event.set()

        def wait(self, timeout: float) -> bool:
            """Block up to *timeout* seconds; return True if the event is set."""
            return self._event.wait(timeout)

        def drain(self) -> List[str]:
            """Return and remove every queued chunk, resetting the event."""
            with self._lock:
                items = list(self._queue)
                self._queue.clear()
                self._event.clear()
            return items

        def close(self) -> None:
            """Discard queued chunks and wake any waiter."""
            with self._lock:
                self._queue.clear()
                self._event.set()

    def __init__(self) -> None:
        self._send_lock = Lock()
        self._output_lock = Lock()
        self._output_event = Event()
        self._last_output = ""
        # time.monotonic() stamp of the last tool send. -inf lets the first
        # tool send go out immediately whatever the monotonic clock's origin.
        self._last_tool_send = float("-inf")
        self._listeners: set[SessionState._OutputListener] = set()
        self._listeners_lock = Lock()

    def send(self, client: TelnetClient, message: str) -> None:
        """Sanitize *message* and send it while holding the send lock."""
        sanitized = sanitize_for_mud(message)
        with self._send_lock:
            client.send(sanitized)

    def tool_send(
        self,
        client: TelnetClient,
        message: str,
        *,
        min_interval: float,
        stop_event: Event,
    ) -> bool:
        """Send on behalf of the tool while respecting a minimum cadence.

        Waits until at least *min_interval* seconds have passed since the
        previous tool send, then sends the sanitized *message*.

        Returns:
            True if the message was sent, False if *stop_event* was set first.
        """
        sanitized = sanitize_for_mud(message)
        while not stop_event.is_set():
            with self._send_lock:
                # Monotonic clock: wall-clock jumps (NTP, manual changes)
                # must not stall or burst the tool's cadence.
                now = time.monotonic()
                elapsed = now - self._last_tool_send
                if elapsed >= min_interval:
                    client.send(sanitized)
                    self._last_tool_send = now
                    return True
                # elapsed < min_interval here, so this is strictly positive.
                wait_time = min_interval - elapsed
            # Wait outside the lock so other senders are not blocked; a set
            # stop_event ends the wait early.
            if stop_event.wait(wait_time):
                break
        return False

    def update_output(self, text: str) -> None:
        """Record *text* as the latest output and publish it to listeners."""
        if not text:
            return
        # Updating the last output and snapshotting the listeners under the
        # same lock register_listener holds ensures a concurrently registering
        # listener gets this chunk exactly once: either replayed there or
        # published here, never both.
        with self._output_lock:
            self._last_output = text
            with self._listeners_lock:
                listeners = list(self._listeners)
        for listener in listeners:
            listener.publish(text)
        self._output_event.set()

    def snapshot_output(self) -> str:
        """Return the most recently recorded output chunk ("" if none)."""
        with self._output_lock:
            return self._last_output

    def register_listener(self) -> "SessionState._OutputListener":
        """Create a listener, pre-loaded with the latest output if any."""
        listener = SessionState._OutputListener()
        with self._output_lock:
            with self._listeners_lock:
                self._listeners.add(listener)
            if self._last_output:
                listener.publish(self._last_output)
        return listener

    def remove_listener(self, listener: "SessionState._OutputListener") -> None:
        """Unregister *listener* and close it; unknown listeners are ignored."""
        with self._listeners_lock:
            try:
                self._listeners.remove(listener)
            except KeyError:
                return
        listener.close()

    def wait_for_output(self, timeout: float) -> bool:
        """Block up to *timeout* seconds for the shared output event."""
        return self._output_event.wait(timeout)

    def clear_output_event(self) -> None:
        """Reset the shared output event."""
        self._output_event.clear()
def run_tool_loop(
    client: TelnetClient,
    state: SessionState,
    tool: Tool,
    stop_event: Event,
    *,
    idle_delay: float = 0.5,
    min_send_interval: float = 1.0,
    auto_stop: bool = False,
    auto_stop_idle: float = 2.0,
) -> None:
    """Invoke *tool* whenever new output arrives and send its response.

    Each loop pass asks ``tool.decide()`` for a command and sends it through
    ``state.tool_send``. It then waits up to *idle_delay* seconds for output,
    feeding every non-empty chunk to ``tool.observe()``. The listener is
    always removed on exit.

    Args:
        client: Connection passed through to ``state.tool_send``.
        state: Shared session state providing listeners and throttled sends.
        tool: Object exposing ``decide()`` and ``observe(chunk)``.
        stop_event: Set by another thread to end the loop.
        idle_delay: Seconds to wait for output per loop pass.
        min_send_interval: Minimum spacing between tool sends, in seconds.
        auto_stop: If True, stop once no output or sends happen for a while.
        auto_stop_idle: Quiet time (measured from the first empty wait) after
            which *auto_stop* ends the loop.
    """
    # Monotonic start of the current quiet period; None while active. A
    # monotonic clock keeps wall-clock adjustments from distorting auto-stop.
    idle_started: Optional[float] = None
    listener = state.register_listener()

    def maybe_send() -> None:
        """Ask the tool for a command and send it; a send resets the idle timer."""
        nonlocal idle_started
        try:
            command = tool.decide()
        except Exception as exc:  # pragma: no cover - defensive logging
            print(f"[Tool] Failed: {exc}", file=sys.stderr)
            return
        if not command:
            return
        sent = state.tool_send(
            client,
            command,
            min_interval=min_send_interval,
            stop_event=stop_event,
        )
        if sent:
            idle_started = None

    try:
        while not stop_event.is_set():
            maybe_send()
            if stop_event.is_set():
                break
            triggered = listener.wait(timeout=idle_delay)
            if stop_event.is_set():
                break
            if not triggered:
                if auto_stop:
                    now = time.monotonic()
                    if idle_started is None:
                        idle_started = now
                    elif now - idle_started >= auto_stop_idle:
                        break
                continue
            outputs = listener.drain()
            if not outputs:
                continue
            idle_started = None
            for chunk in outputs:
                if not chunk:
                    continue
                try:
                    tool.observe(chunk)
                except Exception as exc:  # pragma: no cover - defensive logging
                    print(f"[Tool] Failed during observe: {exc}", file=sys.stderr)
            # NOTE(review): the loop head calls maybe_send() again right after
            # this, so decide() runs twice per batch of output. Confirm that
            # Tool.decide() tolerates repeated calls without new observations.
            maybe_send()
    finally:
        state.remove_listener(listener)
def login(
    client: TelnetClient,
    *,
    user: str,
    password: str,
    login_prompt: str,
    banner_timeout: float = 10.0,
    response_timeout: float = 2.0,
    state: Optional[SessionState] = None,
    credential_delay: float = 0.2,
) -> None:
    """Handle the banner/prompt exchange and send credentials.

    Reads the banner (up to *login_prompt* if given, otherwise a single
    receive), echoes it, sends *user* and *password* when non-empty, and
    echoes the server's response.

    Args:
        client: Connected Telnet client.
        user: Login name; skipped when empty.
        password: Password; skipped when empty.
        login_prompt: Text to read the banner up to; empty means one receive.
        banner_timeout: Timeout for reading up to *login_prompt*.
        response_timeout: Timeout for plain receives.
        state: Optional shared session state. When given, sends go through its
            send lock and echoed text is recorded as session output.
        credential_delay: Pause in seconds after sending *user* (must be >= 0).
    """

    def transmit(text: str) -> None:
        # Use the shared send lock when available so credentials cannot
        # interleave with other senders on the same connection.
        if state is not None:
            state.send(client, text)
        else:
            client.send(sanitize_for_mud(text))

    def echo(text: str) -> None:
        if not text:
            return
        print(text, end="" if text.endswith("\n") else "\n")
        if state is not None:
            state.update_output(text)

    if login_prompt:
        banner = client.read_until(login_prompt, timeout=banner_timeout)
    else:
        banner = client.receive(timeout=response_timeout)
    echo(banner)
    if user:
        transmit(user)
        time.sleep(credential_delay)
    if password:
        transmit(password)
    echo(client.receive(timeout=response_timeout))
def _is_command_word(line: str, keyword: str) -> bool:
    """Return True if *line* begins with *keyword* (case-insensitive) as a whole word."""
    head = line[: len(keyword)]
    tail = line[len(keyword):]
    return head.lower() == keyword and (not tail or tail[0].isspace())


def interactive_session(
    client: TelnetClient,
    state: SessionState,
    stop_event: Event,
    *,
    poll_interval: float = 0.2,
    receive_timeout: float = 0.2,
    exit_command: str,
    tool_command: Optional[Callable[[str], None]] = None,
    agent_command: Optional[Callable[[str], None]] = None,
) -> None:
    """Keep the Telnet session running, proxying input/output until interrupted.

    Each pass prints and records any incoming server text, then polls stdin.
    Lines starting with the word ``#agent`` or ``#execute`` are dispatched to
    *agent_command* / *tool_command* (when provided). Other non-empty lines
    are sent to the server. EOF on stdin sets *stop_event* and returns.

    Args:
        client: Connected Telnet client.
        state: Shared session state used for sends and output tracking.
        stop_event: Set (here on stdin EOF, or by another thread) to stop.
        poll_interval: Seconds to wait for stdin per pass.
        receive_timeout: Seconds to wait for server output per pass.
        exit_command: Only used in the greeting message here.
        tool_command: Handler for ``#execute <tool_spec>``.
        agent_command: Handler for ``#agent <agent_spec>``.
    """
    if exit_command:
        print(f"Connected. Press Ctrl-C to exit (will send {exit_command!r}).")
    else:
        print("Connected. Press Ctrl-C to exit.")
    while not stop_event.is_set():
        incoming = client.receive(timeout=receive_timeout)
        if incoming:
            print(incoming, end="" if incoming.endswith("\n") else "\n")
            state.update_output(incoming)
        readable, _, _ = select.select([sys.stdin], [], [], poll_interval)
        if sys.stdin not in readable:
            continue
        line = sys.stdin.readline()
        if line == "":
            # readline() returns "" only at EOF (e.g. Ctrl-D or closed pipe).
            stop_event.set()
            break
        line = line.rstrip("\r\n")
        if not line:
            continue
        # Whole-word match: "#agentx ..." or "#executed" are ordinary MUD
        # input and must not be swallowed as meta-commands.
        if agent_command and _is_command_word(line, "#agent"):
            parts = line.split(maxsplit=1)
            if len(parts) == 1:
                print("[Agent] Usage: #agent <agent_spec>")
            else:
                agent_command(parts[1])
            continue
        if tool_command and _is_command_word(line, "#execute"):
            parts = line.split(maxsplit=1)
            if len(parts) == 1:
                print("[Tool] Usage: #execute <tool_spec>")
            else:
                tool_command(parts[1])
            continue
        state.send(client, line)
def graceful_shutdown(
    client: TelnetClient,
    exit_command: str,
    *,
    state: Optional[SessionState] = None,
) -> None:
    """Send *exit_command* and echo the server's farewell, best effort.

    Does nothing when *exit_command* is empty. When *state* is given, the
    command goes through its locked, sanitizing ``send``, and the farewell is
    recorded as session output. Any exception raised while sending or reading
    the reply is reported on stderr instead of propagating.
    """
    if not exit_command:
        return
    try:
        if not state:
            client.send(sanitize_for_mud(exit_command))
        else:
            state.send(client, exit_command)
        reply = client.receive(timeout=2.0)
        if not reply:
            return
        terminator = "" if reply.endswith("\n") else "\n"
        print(reply, end=terminator)
        if state:
            state.update_output(reply)
    except Exception as exc:  # pragma: no cover - best effort logging
        print(f"Failed to send exit command: {exc}", file=sys.stderr)