diff options
Diffstat (limited to 'agent')
| -rw-r--r-- | agent/Dockerfile | 6 | ||||
| -rw-r--r-- | agent/agent.py | 162 |
2 files changed, 168 insertions, 0 deletions
diff --git a/agent/Dockerfile b/agent/Dockerfile new file mode 100644 index 0000000..509c3b6 --- /dev/null +++ b/agent/Dockerfile | |||
| @@ -0,0 +1,6 @@ | |||
# Minimal runtime image for the tool-using agent HTTP service (see agent.py).
FROM python:3.12-slim
# httpx is pinned below 0.28 alongside openai<2 — presumably a compatibility
# pin between these two releases; TODO confirm before bumping either.
RUN pip install --no-cache-dir 'openai>=1.59.2,<2' 'httpx<0.28'
WORKDIR /app
COPY agent.py /app/agent.py
# agent.py serves HTTP on 8001.
EXPOSE 8001
CMD ["python", "/app/agent.py"]
diff --git a/agent/agent.py b/agent/agent.py new file mode 100644 index 0000000..12ad9d6 --- /dev/null +++ b/agent/agent.py | |||
| @@ -0,0 +1,162 @@ | |||
| 1 | """Tool-using agent over an OpenAI-compatible backend. | ||
| 2 | |||
| 3 | Uses the standard OpenAI tools API (function calling). vLLM maps this to the | ||
| 4 | model's native tool-call template (Qwen here), so small models follow the | ||
| 5 | protocol much more reliably than a hand-rolled text convention. | ||
| 6 | |||
| 7 | POST /ask {"question": "..."} -> {"answer": "...", "transcript": [...]} | ||
| 8 | GET /health -> "ok" | ||
| 9 | """ | ||
import ast
import json
import operator
import os
import re
from http.server import BaseHTTPRequestHandler, HTTPServer

from openai import OpenAI
| 16 | |||
# OpenAI-compatible client; OPENAI_BASE_URL is required (vLLM or similar
# backend), the API key defaults to a dummy value for local backends.
client = OpenAI(
    base_url=os.environ["OPENAI_BASE_URL"],
    api_key=os.environ.get("OPENAI_API_KEY", "sk-local"),
)
# Model name as served by the backend.
MODEL = os.environ.get("MODEL", "Qwen2.5-1.5B-Instruct")
# Upper bound on model turns per question so a confused model cannot loop forever.
MAX_STEPS = int(os.environ.get("MAX_STEPS", "6"))

# System prompt: steer the model to delegate all arithmetic to the 'calc' tool.
SYSTEM = (
    "You are a careful math assistant. When the user asks any arithmetic question, "
    "call the 'calc' tool with the exact expression. Do not compute arithmetic in your head. "
    "After you receive the tool result, give a concise final answer."
)

# OpenAI function-calling schema for the single 'calc' tool.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "calc",
            "description": "Evaluate a safe arithmetic expression and return the numeric result.",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {
                        "type": "string",
                        "description": "Arithmetic expression using only digits, spaces, and + - * / . ( )",
                    }
                },
                "required": ["expression"],
            },
        },
    }
]
| 49 | |||
SAFE_EXPR = re.compile(r"^[\d\s+\-*/().]+$")

# Operator whitelists for the AST evaluator. Exponentiation is deliberately
# absent: '**' is two '*' characters, so "9**9**9" passes the character filter,
# and under eval() it would lock up the process / exhaust memory (DoS).
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,  # '//' also passes the filter; keep it working
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _eval_node(node):
    """Recursively evaluate one whitelisted arithmetic AST node."""
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
        return _BIN_OPS[type(node.op)](_eval_node(node.left), _eval_node(node.right))
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
        return _UNARY_OPS[type(node.op)](_eval_node(node.operand))
    raise ValueError(f"unsupported syntax: {type(node).__name__}")


def calc(expression: str) -> str:
    """Evaluate a basic arithmetic expression and return the result as a string.

    Never raises: every failure comes back as an "ERROR: ..." string so it can
    be fed directly back to the model as a tool result.
    """
    if not SAFE_EXPR.fullmatch(expression):
        return "ERROR: disallowed characters"
    try:
        # Parse + whitelist-walk instead of eval(): even with the character
        # filter, eval() accepts '**' chains that can hang the process.
        tree = ast.parse(expression, mode="eval")
        return str(_eval_node(tree.body))
    except RecursionError:
        return "ERROR: expression too deeply nested"
    except Exception as e:
        return f"ERROR: {e}"
| 60 | |||
| 61 | |||
def run_agent(question: str) -> dict:
    """Run the tool-calling loop for one user question.

    Returns a dict with "answer" (None when MAX_STEPS is exhausted), "steps",
    and a "transcript" recording every model turn and tool result.
    """
    messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": question},
    ]
    transcript: list = []

    for step in range(MAX_STEPS):
        resp = client.chat.completions.create(
            model=MODEL,
            messages=messages,
            tools=TOOLS,
            tool_choice="auto",
            temperature=0.0,  # deterministic decoding for reproducible tool use
            max_tokens=256,
        )
        msg = resp.choices[0].message

        # Always append the assistant message (with any tool_calls) to history.
        # The tool_calls are echoed back verbatim so the follow-up "tool"
        # messages can reference their ids.
        assistant_entry = {"role": "assistant", "content": msg.content or ""}
        if msg.tool_calls:
            assistant_entry["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                }
                for tc in msg.tool_calls
            ]
        messages.append(assistant_entry)

        # Record the raw model turn for the response transcript.
        transcript.append(
            {
                "step": step + 1,
                "content": msg.content,
                "tool_calls": [
                    {"name": tc.function.name, "arguments": tc.function.arguments}
                    for tc in (msg.tool_calls or [])
                ],
            }
        )

        if msg.tool_calls:
            # Execute every requested call, answer each with a "tool" message
            # keyed by its tool_call_id, then loop for the model's next turn.
            for tc in msg.tool_calls:
                if tc.function.name != "calc":
                    result = f"ERROR: unknown tool {tc.function.name}"
                else:
                    try:
                        args = json.loads(tc.function.arguments)
                    except json.JSONDecodeError:
                        result = "ERROR: bad JSON arguments"
                    else:
                        result = calc(args.get("expression", ""))
                transcript.append({"tool_result": {"name": tc.function.name, "result": result}})
                messages.append(
                    {"role": "tool", "tool_call_id": tc.id, "content": result}
                )
            continue

        # No tool call -> model produced a final answer.
        return {"answer": (msg.content or "").strip(), "steps": step + 1, "transcript": transcript}

    return {"answer": None, "steps": MAX_STEPS, "note": "MAX_STEPS reached", "transcript": transcript}
| 125 | |||
| 126 | |||
class Handler(BaseHTTPRequestHandler):
    """Minimal JSON front-end: POST /ask runs the agent, GET /health pings."""

    def do_POST(self):  # noqa: N802
        """Handle POST /ask with body {"question": "..."}; reply with JSON."""
        if self.path != "/ask":
            self.send_response(404); self.end_headers(); return
        # A malformed Content-Length must not crash the handler with a traceback.
        try:
            n = int(self.headers.get("Content-Length", "0"))
        except ValueError:
            self._send_json(411, {"error": "bad Content-Length"})
            return
        try:
            body = json.loads(self.rfile.read(n) or b"{}")
        except json.JSONDecodeError:
            self._send_json(400, {"error": "invalid json"})
            return
        q = body.get("question", "")
        # Reject non-string payloads here rather than deep inside run_agent.
        if not isinstance(q, str):
            self._send_json(400, {"error": "'question' must be a string"})
            return
        try:
            result = run_agent(q)
            code = 200
        except Exception as e:  # top-level boundary: report backend failure as 500
            result = {"error": str(e), "type": type(e).__name__}
            code = 500
        self._send_json(code, result)

    def _send_json(self, code: int, obj: dict) -> None:
        """Serialize obj and write one complete JSON response with headers."""
        payload = json.dumps(obj).encode()
        self.send_response(code)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def do_GET(self):  # noqa: N802
        """GET /health -> 200 "ok"; any other path -> 404."""
        if self.path == "/health":
            self.send_response(200); self.end_headers(); self.wfile.write(b"ok"); return
        self.send_response(404); self.end_headers()

    def log_message(self, fmt, *args):
        """Write access logs to stderr instead of the default destination."""
        import sys
        print(f"{self.address_string()} {fmt % args}", file=sys.stderr)
| 158 | |||
| 159 | |||
if __name__ == "__main__":
    # Announce configuration up front, then serve forever on all interfaces.
    backend = os.environ["OPENAI_BASE_URL"]
    print(f"agent starting on :8001, model={MODEL}, backend={backend}")
    server = HTTPServer(("0.0.0.0", 8001), Handler)
    server.serve_forever()
