From d3e770254de0bb301815ca87257c8b1a357d06c4 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 26 Apr 2026 21:02:47 +0800 Subject: hehe --- agent/agent.py | 162 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 agent/agent.py (limited to 'agent/agent.py') diff --git a/agent/agent.py b/agent/agent.py new file mode 100644 index 0000000..12ad9d6 --- /dev/null +++ b/agent/agent.py @@ -0,0 +1,162 @@ +"""Tool-using agent over an OpenAI-compatible backend. + +Uses the standard OpenAI tools API (function calling). vLLM maps this to the +model's native tool-call template (Qwen here), so small models follow the +protocol much more reliably than a hand-rolled text convention. + +POST /ask {"question": "..."} -> {"answer": "...", "transcript": [...]} +GET /health -> "ok" +""" +import json +import os +import re +from http.server import BaseHTTPRequestHandler, HTTPServer + +from openai import OpenAI + +client = OpenAI( + base_url=os.environ["OPENAI_BASE_URL"], + api_key=os.environ.get("OPENAI_API_KEY", "sk-local"), +) +MODEL = os.environ.get("MODEL", "Qwen2.5-1.5B-Instruct") +MAX_STEPS = int(os.environ.get("MAX_STEPS", "6")) + +SYSTEM = ( + "You are a careful math assistant. When the user asks any arithmetic question, " + "call the 'calc' tool with the exact expression. Do not compute arithmetic in your head. " + "After you receive the tool result, give a concise final answer." +) + +TOOLS = [ + { + "type": "function", + "function": { + "name": "calc", + "description": "Evaluate a safe arithmetic expression and return the numeric result.", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Arithmetic expression using only digits, spaces, and + - * / . ( )", + } + }, + "required": ["expression"], + }, + }, + } +] + +SAFE_EXPR = re.compile(r"^[\d\s+\-*/().]+$") + + +def calc(expression: str) -> str: + if not SAFE_EXPR.fullmatch(expression): + return "ERROR: disallowed characters" + try: + return str(eval(expression, {"__builtins__": {}}, {})) # noqa: S307 + except Exception as e: + return f"ERROR: {e}" + + +def run_agent(question: str) -> dict: + messages = [ + {"role": "system", "content": SYSTEM}, + {"role": "user", "content": question}, + ] + transcript: list = [] + + for step in range(MAX_STEPS): + resp = client.chat.completions.create( + model=MODEL, + messages=messages, + tools=TOOLS, + tool_choice="auto", + temperature=0.0, + max_tokens=256, + ) + msg = resp.choices[0].message + + # Always append the assistant message (with any tool_calls) to history. + assistant_entry = {"role": "assistant", "content": msg.content or ""} + if msg.tool_calls: + assistant_entry["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": {"name": tc.function.name, "arguments": tc.function.arguments}, + } + for tc in msg.tool_calls + ] + messages.append(assistant_entry) + + transcript.append( + { + "step": step + 1, + "content": msg.content, + "tool_calls": [ + {"name": tc.function.name, "arguments": tc.function.arguments} + for tc in (msg.tool_calls or []) + ], + } + ) + + if msg.tool_calls: + for tc in msg.tool_calls: + if tc.function.name != "calc": + result = f"ERROR: unknown tool {tc.function.name}" + else: + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError: + result = "ERROR: bad JSON arguments" + else: + result = calc(args.get("expression", "")) + transcript.append({"tool_result": {"name": tc.function.name, "result": result}}) + messages.append( + {"role": "tool", "tool_call_id": tc.id, "content": result} + ) + continue + + # No tool call -> model produced a final answer. + return {"answer": (msg.content or "").strip(), "steps": step + 1, "transcript": transcript} + + return {"answer": None, "steps": MAX_STEPS, "note": "MAX_STEPS reached", "transcript": transcript} + + +class Handler(BaseHTTPRequestHandler): + def do_POST(self): # noqa: N802 + if self.path != "/ask": + self.send_response(404); self.end_headers(); return + n = int(self.headers.get("Content-Length", "0")) + try: + body = json.loads(self.rfile.read(n) or b"{}") + except json.JSONDecodeError: + self.send_response(400); self.end_headers(); self.wfile.write(b'{"error":"invalid json"}'); return + q = body.get("question", "") + try: + result = run_agent(q) + code = 200 + except Exception as e: + result = {"error": str(e), "type": type(e).__name__} + code = 500 + payload = json.dumps(result).encode() + self.send_response(code) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(payload))) + self.end_headers() + self.wfile.write(payload) + + def do_GET(self): # noqa: N802 + if self.path == "/health": + self.send_response(200); self.end_headers(); self.wfile.write(b"ok"); return + self.send_response(404); self.end_headers() + + def log_message(self, fmt, *args): + import sys + print(f"{self.address_string()} {fmt % args}", file=sys.stderr) + + +if __name__ == "__main__": + print(f"agent starting on :8001, model={MODEL}, backend={os.environ['OPENAI_BASE_URL']}") + HTTPServer(("0.0.0.0", 8001), Handler).serve_forever() -- cgit