summaryrefslogtreecommitdiff
path: root/agent/agent.py
diff options
context:
space:
mode:
authorYour Name <you@example.com>2026-04-26 21:02:47 +0800
committerYour Name <you@example.com>2026-04-26 21:02:47 +0800
commitd3e770254de0bb301815ca87257c8b1a357d06c4 (patch)
tree358c814be2a06b9e2009905f14938243286b8d82 /agent/agent.py
Diffstat (limited to 'agent/agent.py')
-rw-r--r--agent/agent.py162
1 files changed, 162 insertions, 0 deletions
diff --git a/agent/agent.py b/agent/agent.py
new file mode 100644
index 0000000..12ad9d6
--- /dev/null
+++ b/agent/agent.py
@@ -0,0 +1,162 @@
1"""Tool-using agent over an OpenAI-compatible backend.
2
3Uses the standard OpenAI tools API (function calling). vLLM maps this to the
4model's native tool-call template (Qwen here), so small models follow the
5protocol much more reliably than a hand-rolled text convention.
6
7POST /ask {"question": "..."} -> {"answer": "...", "transcript": [...]}
8GET /health -> "ok"
9"""
import ast
import json
import os
import re
from http.server import BaseHTTPRequestHandler, HTTPServer

from openai import OpenAI
16
# OpenAI-compatible client; OPENAI_BASE_URL is required (fail fast at import),
# the API key defaults to a placeholder since local backends accept any key.
client = OpenAI(
    base_url=os.environ["OPENAI_BASE_URL"],
    api_key=os.environ.get("OPENAI_API_KEY", "sk-local"),
)
# Model name as registered with the backend; override via MODEL for other deployments.
MODEL = os.environ.get("MODEL", "Qwen2.5-1.5B-Instruct")
# Hard cap on assistant turns per question so a looping model cannot spin forever.
MAX_STEPS = int(os.environ.get("MAX_STEPS", "6"))

# System prompt: steer the model to delegate all arithmetic to the 'calc' tool
# rather than computing in its head (small models are unreliable at arithmetic).
SYSTEM = (
    "You are a careful math assistant. When the user asks any arithmetic question, "
    "call the 'calc' tool with the exact expression. Do not compute arithmetic in your head. "
    "After you receive the tool result, give a concise final answer."
)

# Tool schema advertised to the model on every request, in the standard OpenAI
# function-calling format (vLLM maps this onto the model's native template).
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "calc",
            "description": "Evaluate a safe arithmetic expression and return the numeric result.",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {
                        "type": "string",
                        "description": "Arithmetic expression using only digits, spaces, and + - * / . ( )",
                    }
                },
                "required": ["expression"],
            },
        },
    }
]
49
# Character allow-list: digits, whitespace, + - * / . ( ). This alone is NOT
# sufficient to make eval() safe: '**' is two allowed characters, so e.g.
# 9**9**9**9 would pass the filter and hang the process.
SAFE_EXPR = re.compile(r"^[\d\s+\-*/().]+$")


def calc(expression: str) -> str:
    """Evaluate a basic arithmetic expression and return the result as a string.

    The expression is first checked against SAFE_EXPR, then evaluated by
    walking its AST instead of calling eval(). Only numeric literals, the
    binary operators + - * /, and unary +/- are executed; anything else
    (notably '**', which passes the character filter but enables exponent
    bombs) is rejected.

    Returns:
        str(result) on success, or an "ERROR: ..." message on rejection or
        evaluation failure (e.g. "ERROR: division by zero").
    """
    if not SAFE_EXPR.fullmatch(expression):
        return "ERROR: disallowed characters"

    binops = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.Div: lambda a, b: a / b,
    }

    def ev(node):
        # Recursive interpreter over the tiny allowed AST subset.
        if isinstance(node, ast.Expression):
            return ev(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binops:
            return binops[type(node.op)](ev(node.left), ev(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            value = ev(node.operand)
            return value if isinstance(node.op, ast.UAdd) else -value
        raise ValueError("unsupported expression")

    try:
        return str(ev(ast.parse(expression, mode="eval")))
    except Exception as e:
        # SyntaxError, ZeroDivisionError, ValueError, ... all become tool errors.
        return f"ERROR: {e}"
60
61
def run_agent(question: str) -> dict:
    """Drive one question through the tool-calling loop.

    Repeatedly asks the model; whenever it emits tool calls, runs them and
    feeds the results back as 'tool' messages. Stops when the model answers
    without a tool call, or after MAX_STEPS turns.

    Returns a dict with "answer" (None if the step budget ran out),
    "steps", "transcript", and on timeout a "note".
    """
    transcript: list = []
    history = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": question},
    ]

    for turn in range(MAX_STEPS):
        completion = client.chat.completions.create(
            model=MODEL,
            messages=history,
            tools=TOOLS,
            tool_choice="auto",
            temperature=0.0,
            max_tokens=256,
        )
        reply = completion.choices[0].message
        calls = list(reply.tool_calls or [])

        # The assistant message (including any tool_calls) must always go back
        # into the history so the conversation stays well-formed.
        entry: dict = {"role": "assistant", "content": reply.content or ""}
        if calls:
            entry["tool_calls"] = [
                {
                    "id": call.id,
                    "type": "function",
                    "function": {"name": call.function.name, "arguments": call.function.arguments},
                }
                for call in calls
            ]
        history.append(entry)

        transcript.append(
            {
                "step": turn + 1,
                "content": reply.content,
                "tool_calls": [
                    {"name": call.function.name, "arguments": call.function.arguments}
                    for call in calls
                ],
            }
        )

        if not calls:
            # No tool call -> the model committed to a final answer.
            return {"answer": (reply.content or "").strip(), "steps": turn + 1, "transcript": transcript}

        for call in calls:
            if call.function.name != "calc":
                result = f"ERROR: unknown tool {call.function.name}"
            else:
                try:
                    args = json.loads(call.function.arguments)
                except json.JSONDecodeError:
                    result = "ERROR: bad JSON arguments"
                else:
                    result = calc(args.get("expression", ""))
            transcript.append({"tool_result": {"name": call.function.name, "result": result}})
            history.append(
                {"role": "tool", "tool_call_id": call.id, "content": result}
            )

    return {"answer": None, "steps": MAX_STEPS, "note": "MAX_STEPS reached", "transcript": transcript}
125
126
class Handler(BaseHTTPRequestHandler):
    """Minimal JSON HTTP front-end for the agent loop.

    POST /ask {"question": "..."} -> JSON result from run_agent
    GET  /health                  -> "ok"
    """

    def _send_json(self, code: int, obj: dict) -> None:
        # Centralized JSON response: always sets Content-Type and
        # Content-Length so keep-alive clients can frame the body.
        # (The old 400 path wrote a body with no headers at all.)
        payload = json.dumps(obj).encode()
        self.send_response(code)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def do_POST(self):  # noqa: N802
        if self.path != "/ask":
            self.send_response(404)
            self.end_headers()
            return
        try:
            n = int(self.headers.get("Content-Length", "0"))
        except ValueError:
            # A malformed Content-Length previously raised out of the handler
            # and killed the request with a traceback; answer 400 instead.
            self._send_json(400, {"error": "invalid content-length"})
            return
        try:
            body = json.loads(self.rfile.read(n) or b"{}")
        except json.JSONDecodeError:
            self._send_json(400, {"error": "invalid json"})
            return
        q = body.get("question", "")
        try:
            result = run_agent(q)
            code = 200
        except Exception as e:
            # Top-level boundary: report the failure rather than crash the server.
            result = {"error": str(e), "type": type(e).__name__}
            code = 500
        self._send_json(code, result)

    def do_GET(self):  # noqa: N802
        if self.path == "/health":
            self.send_response(200)
            self.end_headers()
            self.wfile.write(b"ok")
            return
        self.send_response(404)
        self.end_headers()

    def log_message(self, fmt, *args):
        # Route access logs to stderr in a compact single-line format.
        import sys

        print(f"{self.address_string()} {fmt % args}", file=sys.stderr)
158
159
160if __name__ == "__main__":
161 print(f"agent starting on :8001, model={MODEL}, backend={os.environ['OPENAI_BASE_URL']}")
162 HTTPServer(("0.0.0.0", 8001), Handler).serve_forever()