author     Your Name <you@example.com>  2026-01-04 14:37:35 +0100
committer  Your Name <you@example.com>  2026-01-04 14:37:35 +0100
commit     3a98216eccf326cfd322478cbf791232d3390c61 (patch)
tree       f24feec1d516db15dca75dacaed0e9f84f718ec6 /search.py
initial_commit
Diffstat (limited to 'search.py')
-rwxr-xr-x  search.py  272
1 file changed, 272 insertions(+), 0 deletions(-)
diff --git a/search.py b/search.py
new file mode 100755
index 0000000..c370ffe
--- /dev/null
+++ b/search.py
@@ -0,0 +1,272 @@
#!/usr/bin/env python3
"""
Multi-agent deep research assistant: plans subtopics, searches local documents
and the web, critiques the findings, and writes a short synthesis.
"""

import os
import re
import sys
from dataclasses import dataclass, field
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import chromadb
import httpx
from bs4 import BeautifulSoup
from ddgs import DDGS
17
# ==================== CONFIG ====================
@dataclass
class Config:
    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
    embedding_model: str = "all-MiniLM-L6-v2"
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
    docs_dir: str = field(default_factory=lambda: os.getenv("DOCS_DIR", "./documents"))
    max_critique_rounds: int = 3  # max critic -> follow-up-search rounds

CFG = Config()
log = lambda tag, msg: print(f"[{tag}] {msg}")
29
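# Example (illustrative; the path and query are hypothetical): the documents folder
# can be pointed elsewhere via the environment instead of editing Config:
#   DOCS_DIR=~/notes ./search.py "what do my notes say about embeddings?"
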
# ==================== MODEL init ====================
class _Models:
    """Lazy holder for the embedder, tokenizer, LLM and vector store (loaded on first use)."""

    def __init__(self):
        self._ready = False
        self.embedder = self.tokenizer = self.llm = self.collection = None

    def _init(self):
        if self._ready: return
        log("init", f"Device: {CFG.device}")
        log("init", f"Loading embedder ({CFG.embedding_model})...")
        self.embedder = SentenceTransformer(CFG.embedding_model, device=CFG.device)
        log("init", f"Loading LLM ({CFG.model_name})...")
        self.tokenizer = AutoTokenizer.from_pretrained(CFG.model_name)
        self.llm = AutoModelForCausalLM.from_pretrained(
            CFG.model_name, dtype=torch.float16, device_map="auto"
        )
        client = chromadb.Client()  # in-memory store, rebuilt on every run
        self.collection = client.get_or_create_collection("research_docs", metadata={"hnsw:space": "cosine"})
        log("init", "Ready.")
        self._ready = True

    def embed(self, text: str) -> list[float]:
        self._init()
        return self.embedder.encode(text).tolist()

    def generate(self, task: str, instructions: str = "", max_tokens: int = 512) -> str:
        self._init()
        msgs = ([{"role": "system", "content": instructions}] if instructions else []) + [{"role": "user", "content": task}]
        text = self.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(text, return_tensors="pt").to(self.llm.device)
        with torch.no_grad():
            out = self.llm.generate(**inputs, max_new_tokens=max_tokens, temperature=0.7,
                                    do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
        return self.tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()

M = _Models()
66
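# Example (illustrative, not part of the pipeline):
#   M.generate("Summarise HTTP/2 in one sentence.", instructions="Be terse.", max_tokens=64)
# loads everything lazily on the first call and returns only the newly generated text.
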
# ==================== MEMORY ====================
@dataclass
class Memory:
    findings: list = field(default_factory=list)  # each item: {"source", "query", "summary"}

    def save(self, source: str, query: str, content: str) -> str:
        summary = (content[:300].replace("\n", " ").strip() + "...") if len(content) > 300 else content
        self.findings.append({"source": source, "query": query, "summary": summary})
        return summary

    def by_source(self, src: str) -> list[dict]:
        return [f for f in self.findings if f["source"] == src]

    def all_summaries(self) -> str:
        return "\n".join(f"- [{f['source']}] {f['query']}: {f['summary']}" for f in self.findings)
82
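# Example (illustrative values): Memory().save("web", "mcdonalds menu", some_long_text)
# truncates anything over 300 chars and returns the stored summary string.
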
# ==================== TOOLS ====================
def web_search(query: str, max_results: int = 5) -> list[dict]:
    try:
        log("duck", f"Searching: {query}")
        with DDGS() as ddgs:
            raw = list(ddgs.text(query, max_results=max_results))
        results = [{"title": r.get("title", ""), "snippet": r.get("body", ""), "url": r.get("href", "")} for r in raw]
        for item in results[:2]:  # fetch full page text for the top two hits only
            item["content"] = fetch_url(item["url"])
        return results
    except Exception as e:
        log("duck", f"Error: {e}")
        return []

def fetch_url(url: str, max_chars: int = 3000) -> str:
    try:
        r = httpx.get(url, timeout=2, follow_redirects=True,
                      headers={"User-Agent": "Mozilla/5.0 (compatible; ResearchBot/1.0)"})
        soup = BeautifulSoup(r.text, "html.parser")
        for tag in soup(["script", "style", "nav", "header", "footer"]): tag.decompose()
        return soup.get_text(separator="\n", strip=True)[:max_chars]
    except Exception as e:
        log("fetch", f"Failed: {e}")
        return ""
107
def doc_search(query: str, n_results: int = 5) -> list[dict]:
    M._init()  # make sure the collection exists even when nothing has been indexed
    if M.collection.count() == 0: return []
    results = M.collection.query(query_embeddings=[M.embed(query)], n_results=n_results, include=["documents", "distances"])
    if not results["documents"] or not results["documents"][0]: return []
    docs = results["documents"][0]
    dists = results.get("distances", [[1.0] * len(docs)])[0]
    return [{"content": d, "score": 1 - dist} for d, dist in zip(docs, dists)]  # cosine distance -> similarity
115
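# Example (illustrative; the score is made up): doc_search("strawberry colour") might return
#   [{"content": "<full text of a .txt file>", "score": 0.42}, ...]
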
def index_documents(docs_dir: str | None = None):
    path = Path(docs_dir or CFG.docs_dir)
    if not path.exists():
        log("docs", f"Directory not found: {path}")
        return
    docs, ids = [], []
    for f in path.rglob("*"):
        if f.suffix not in {".txt"}: continue  # plain-text files only
        try:
            content = f.read_text()
            log("docs", f"Loading: {f.name} ({len(content)} chars)")
            docs.append(content)
            ids.append(str(f))
        except Exception as e:
            log("docs", f"Failed to load {f}: {e}")
    if docs:
        embeddings = [M.embed(d) for d in docs]  # triggers _init()
        M.collection.add(documents=docs, embeddings=embeddings, ids=ids)
        log("docs", f"Indexed {len(docs)} documents")
135
# ==================== AGENTS ====================
def parse_action(text: str) -> tuple[str | None, str]:
    """Parse an [[ACTION:argument]] or bare [[ACTION]] marker out of an agent response."""
    if m := re.search(r'\[\[(\w+):(.+?)\]\]', text, re.DOTALL): return m.group(1).upper(), m.group(2).strip()
    if m := re.search(r'\[\[(\w+)\]\]', text): return m.group(1).upper(), ""
    return None, ""

def extract_findings(resp: str) -> str:
    """Extract content from [[FINDINGS:...]] or return raw response."""
    if m := re.search(r'\[\[FINDINGS:(.*?)\]\]', resp, re.DOTALL):
        return m.group(1).strip()
    return resp
147
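# Examples (illustrative inputs):
#   parse_action("[[ISSUES:no dates for the attacks]]")  -> ("ISSUES", "no dates for the attacks")
#   parse_action("[[SATISFIED]]")                         -> ("SATISFIED", "")
#   parse_action("no marker at all")                      -> (None, "")
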
def agent(name: str, instructions: str, task: str, max_tokens: int = 512) -> str:
    resp = M.generate(task, instructions=instructions, max_tokens=max_tokens)
    log(name, resp[:1000] + ("..." if len(resp) > 1000 else ""))
    return resp
152
INSTRUCTIONS = {
    "planner": """You are a research planner. Break the query into 3 subtopics MAX.
Output EXACTLY: [[PLAN:\n- subtopic 1\n- subtopic 2\n]]\nKeep subtopics short (3-5 words). No explanations.""",

    "researcher": """You are a research agent. Be CONCISE - max 2-3 sentences.
Extract ANY facts from the documents that relate to the query.
Output format: [[FINDINGS:\nThe relevant facts found.\n]]""",

    "critic": """You are a research critic. Review findings for completeness and accuracy.
If sufficient: [[SATISFIED]]
If gaps exist: [[ISSUES:what specific information is missing]]
Be concise and specific.""",

    "writer": "You are a research writer. Be CONCISE and DIRECT. No fluff, no hedging. Just state the facts."
}
168
def plan(mem: Memory, query: str) -> list[str]:
    resp = agent("planner", INSTRUCTIONS["planner"], f"Research query: {query}")
    if m := re.search(r'\[\[PLAN:(.*?)\]\]', resp, re.DOTALL):
        subtopics = [l.strip().lstrip("-").strip() for l in m.group(1).strip().split("\n")]
        subtopics = [s for s in subtopics if len(s) > 3]  # drop empty or junk bullet lines
        if subtopics:
            mem.save("planner", query, "\n".join(subtopics))
            return subtopics
    return [query]  # unparseable plan: research the raw query instead
178
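# Example (illustrative): a planner response of "[[PLAN:\n- mcdonalds menu items\n- mcdonalds breakfast hours\n]]"
# yields ["mcdonalds menu items", "mcdonalds breakfast hours"]; an unparseable response falls back to [query].
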
def do_research(mem: Memory, query: str, source: str = "web"):
    log("research", f"Searching {source.upper()} for: {query}")
    results = doc_search(query) if source == "local" else web_search(query)
    if not results:
        log("research", f"No {source} results")
        return
    log("research", f"Found {len(results)} {source} results")
    if source == "local":
        content = "\n".join(f"[{i}] (sim: {r['score']:.2f})\n{r['content'][:1000]}" for i, r in enumerate(results, 1))
    else:
        content = "\n".join(f"[{i}] {r['title']}\n{r['url']}\n{r.get('content', r.get('snippet', ''))[:1000]}" for i, r in enumerate(results, 1))
    prompt = f"Research query: {query}\n\nResults:\n{content[:3000]}\n\nExtract key findings."
    findings = extract_findings(agent("research", INSTRUCTIONS["researcher"], prompt))
    mem.save(source, query, findings)
193
def critique(mem: Memory, query: str) -> tuple[bool, str]:
    prompt = f"Original query: {query}\n\nResearch so far:\n{mem.all_summaries()}\n\nIs this sufficient?"
    resp = agent("critic", INSTRUCTIONS["critic"], prompt, max_tokens=200)
    action, arg = parse_action(resp)
    if action == "SATISFIED": return True, "Research approved"
    if action == "ISSUES":
        mem.save("critic", "gap identified", arg)
        return False, arg
    return True, "Assumed complete"  # no recognisable marker: stop rather than loop
203
def write(mem: Memory, query: str) -> str:
    # one "## TITLE" block per source, falling back to a placeholder when that source found nothing
    fmt = lambda t, s, e: f"## {t}\n" + ("\n".join(f"- {f['summary']}" for f in mem.by_source(s)) or e)
    sections = [
        fmt("LOCAL DOCUMENTS", "local", "No relevant local documents."),
        fmt("WEB SEARCH", "web", "No relevant web results.")
    ]
    prompt = f"Query: {query}\n\nFindings:\n{mem.all_summaries()}\n\nWrite a 2-3 sentence answer."
    sections.append(f"## ANSWER\n{agent('writer', INSTRUCTIONS['writer'], prompt, 150)}")
    return "\n\n".join(sections)
213
# ==================== ORCHESTRATOR ====================
def research(query: str, verbose: bool = True) -> dict:
    vlog = (lambda phase, msg: print(f"\n[Phase {phase}] {msg}")) if verbose else (lambda *_: None)
    if verbose: print(f"\n{'='*60}\nRESEARCH: {query}\n{'='*60}\n")

    mem = Memory()
    index_documents()

    vlog(1, "Searching local documents...")
    do_research(mem, query, "local")

    vlog(2, "Planning web research...")
    subtopics = plan(mem, query)
    if verbose: print(f"Subtopics: {subtopics}\n")

    vlog(3, "Web research...")
    for topic in subtopics:
        if verbose: print(f"\n--- Web: {topic} ---")
        do_research(mem, topic, "web")

    vlog(4, "Critique loop...")
    for rnd in range(CFG.max_critique_rounds):
        if verbose: print(f"\n--- Critique round {rnd+1} ---")
        ok, feedback = critique(mem, query)
        if ok:
            if verbose: print("Critic satisfied")
            break
        if verbose: print(f"Gap: {feedback[:100]}...")
        do_research(mem, feedback, "web")

    vlog(5, "Writing synthesis...")
    return {"query": query, "subtopics": subtopics, "answer": write(mem, query)}
246
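# Example (illustrative): research("strawberries. what colour are they?") returns
#   {"query": ..., "subtopics": [...], "answer": "## LOCAL DOCUMENTS\n...\n\n## WEB SEARCH\n...\n\n## ANSWER\n..."}
# where "answer" is the markdown report assembled by write().
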
# ==================== CLI ====================
def test_model():
    log("test", "Loading model and asking: 'What is an apple?'")
    resp = M.generate("What is an apple? Answer in 2-3 sentences.", max_tokens=100)
    print(f"[response] {resp}\n[test] Done.")

if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("""
Installation (a newer PyTorch build should work as well; the cu121 wheels match the author's older GPU):
    pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121

Usage:
    ./search.py "your research question"
    ./search.py --test    # sanity check that the LLM loads and generates

Examples:
    ./search.py "what does mcdonalds serve?"
    ./search.py "is new york one of the locations of the sept 11 attacks?"
    ./search.py "strawberries. what colour are they?"
""")
        sys.exit(1)
    if sys.argv[1] == "--test":
        test_model()
    else:
        result = research(sys.argv[1])
        print(f"\n{'='*60}\nFINAL ANSWER\n{'='*60}\n{result['answer']}")