-rw-r--r--   documents/doc1.txt    3
-rw-r--r--   documents/doc2.txt    3
-rw-r--r--   readme               16
-rw-r--r--   requirements.txt      9
-rwxr-xr-x   search.py           272
5 files changed, 303 insertions(+), 0 deletions(-)
diff --git a/documents/doc1.txt b/documents/doc1.txt
new file mode 100644
index 0000000..56f5cd5
--- /dev/null
+++ b/documents/doc1.txt
@@ -0,0 +1,3 @@
horse number 77331893112373437 jumped over a greyhound
strawberries are red in colour
mcdonalds serve fast food
diff --git a/documents/doc2.txt b/documents/doc2.txt
new file mode 100644
index 0000000..7d5a79d
--- /dev/null
+++ b/documents/doc2.txt
@@ -0,0 +1,3 @@
john built a building named "wonkystairs" in 2002
one of the locations of the september 11 attacks is new york
"The Straits Times" is a newspaper from singapore
diff --git a/readme b/readme
new file mode 100644
--- /dev/null
+++ b/readme
@@ -0,0 +1,16 @@
#Installation: (a newer PyTorch version should work as well; the cu121 wheels are for my older GPU)
python3 -m venv venv && source venv/bin/activate
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121

#Usage:
./search.py "your research question"
./search.py --test    # sanity check that the local LLM loads and generates
./search.py "what does mcdonalds serve?"
./search.py "is new york one of the locations of the sept 11 attacks?"
./search.py "strawberries. what colour are they?"

#Documentation:
This program uses two research tools:
- local document search (in ./documents) using an embedding model + ChromaDB for semantic retrieval
- web search using DuckDuckGo
A local LLM orchestrates the research loop, decides when enough information has been gathered, and then provides a final concise summary.
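
Optional: the DOCS_DIR environment variable (read by the Config dataclass in search.py) points local search at a different folder of .txt files; the path below is only an example:
DOCS_DIR=./my_notes ./search.py "your research question"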
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..d003057
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
torch==2.5.1+cu121
torchvision==0.20.1+cu121
transformers
accelerate
sentence-transformers
chromadb
httpx
beautifulsoup4
ddgs
diff --git a/search.py b/search.py
new file mode 100755
index 0000000..c370ffe
--- /dev/null
+++ b/search.py
@@ -0,0 +1,272 @@
#!/usr/bin/env python3
"""
multi-agent deep research thingy
"""


import os, re, sys
from dataclasses import dataclass, field
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import chromadb
import httpx
from bs4 import BeautifulSoup
from ddgs import DDGS

# ==================== CONFIG ====================
@dataclass
class Config:
    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
    embedding_model: str = "all-MiniLM-L6-v2"
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
    docs_dir: str = field(default_factory=lambda: os.getenv("DOCS_DIR", "./documents"))
    max_critique_rounds: int = 3

CFG = Config()
log = lambda tag, msg: print(f"[{tag}] {msg}")

# ==================== MODEL init ====================
class _Models:
    def __init__(self):
        self._ready = False
        self.embedder = self.tokenizer = self.llm = self.collection = None

    def _init(self):
        if self._ready: return
        log("init", f"Device: {CFG.device}")
        log("init", f"Loading embedder ({CFG.embedding_model})...")
        self.embedder = SentenceTransformer(CFG.embedding_model, device=CFG.device)
        log("init", f"Loading LLM ({CFG.model_name})...")
        self.tokenizer = AutoTokenizer.from_pretrained(CFG.model_name)
        self.llm = AutoModelForCausalLM.from_pretrained(
            CFG.model_name, dtype=torch.float16, device_map="auto"
        )
        client = chromadb.Client()
        self.collection = client.get_or_create_collection("research_docs", metadata={"hnsw:space": "cosine"})
        log("init", "Ready.")
        self._ready = True

    def embed(self, text: str) -> list[float]:
        self._init()
        return self.embedder.encode(text).tolist()

    def generate(self, task: str, instructions: str = "", max_tokens: int = 512) -> str:
        self._init()
        msgs = ([{"role": "system", "content": instructions}] if instructions else []) + [{"role": "user", "content": task}]
        text = self.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(text, return_tensors="pt").to(self.llm.device)
        with torch.no_grad():
            out = self.llm.generate(**inputs, max_new_tokens=max_tokens, temperature=0.7,
                                    do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
        return self.tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()

M = _Models()
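# Usage sketch: M.generate("Summarize the findings", instructions="Be brief") returns only the
# newly generated text; the prompt tokens are sliced off before decoding.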
| 66 | |||
| 67 | # ==================== MEMORY ==================== | ||
| 68 | @dataclass | ||
| 69 | class Memory: | ||
| 70 | findings: list = field(default_factory=list) | ||
| 71 | |||
| 72 | def save(self, source: str, query: str, content: str) -> str: | ||
| 73 | summary = (content[:300].replace("\n", " ").strip() + "...") if len(content) > 300 else content | ||
| 74 | self.findings.append({"source": source, "query": query, "summary": summary}) | ||
| 75 | return summary | ||
| 76 | |||
| 77 | def by_source(self, src: str) -> list[dict]: | ||
| 78 | return [f for f in self.findings if f["source"] == src] | ||
| 79 | |||
| 80 | def all_summaries(self) -> str: | ||
| 81 | return "\n".join(f"- [{f['source']}] {f['query']}: {f['summary']}" for f in self.findings) | ||
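    # Findings are plain dicts: {"source": "local"/"web"/"planner"/"critic", "query": ..., "summary": ...}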
| 82 | |||
| 83 | # ==================== TOOLS ==================== | ||
| 84 | def web_search(query: str, max_results: int = 5) -> list[dict]: | ||
| 85 | try: | ||
| 86 | log("duck", f"Searching: {query}") | ||
| 87 | with DDGS() as ddgs: | ||
| 88 | raw = list(ddgs.text(query, max_results=max_results)) | ||
| 89 | results = [{"title": r.get("title", ""), "snippet": r.get("body", ""), "url": r.get("href", "")} for r in raw] | ||
| 90 | for item in results[:2]: | ||
| 91 | item["content"] = fetch_url(item["url"]) | ||
| 92 | return results | ||
| 93 | except Exception as e: | ||
| 94 | log("duck", f"Error: {e}") | ||
| 95 | return [] | ||
| 96 | |||
| 97 | def fetch_url(url: str, max_chars: int = 3000) -> str: | ||
| 98 | try: | ||
| 99 | r = httpx.get(url, timeout=2, follow_redirects=True, | ||
| 100 | headers={"User-Agent": "Mozilla/5.0 (compatible; ResearchBot/1.0)"}) | ||
| 101 | soup = BeautifulSoup(r.text, "html.parser") | ||
| 102 | for tag in soup(["script", "style", "nav", "header", "footer"]): tag.decompose() | ||
| 103 | return soup.get_text(separator="\n", strip=True)[:max_chars] | ||
| 104 | except Exception as e: | ||
| 105 | log("fetch", f"Failed: {e}") | ||
| 106 | return "" | ||
| 107 | |||
| 108 | def doc_search(query: str, n_results: int = 5) -> list[dict]: | ||
| 109 | if M.collection.count() == 0: return [] | ||
| 110 | results = M.collection.query(query_embeddings=[M.embed(query)], n_results=n_results, include=["documents", "distances"]) | ||
| 111 | if not results["documents"] or not results["documents"][0]: return [] | ||
| 112 | docs = results["documents"][0] | ||
| 113 | dists = results.get("distances", [[1.0] * len(docs)])[0] | ||
| 114 | return [{"content": d, "score": 1 - dist} for d, dist in zip(docs, dists)] | ||
| 115 | |||
| 116 | def index_documents(docs_dir: str = None): | ||
| 117 | path = Path(docs_dir or CFG.docs_dir) | ||
| 118 | if not path.exists(): | ||
| 119 | log("docs", f"Directory not found: {path}") | ||
| 120 | return | ||
| 121 | docs, ids = [], [] | ||
| 122 | for f in path.rglob("*"): | ||
| 123 | if f.suffix not in {".txt"}: continue | ||
| 124 | try: | ||
| 125 | content = f.read_text() | ||
| 126 | log("docs", f"Loading: {f.name} ({len(content)} chars)") | ||
| 127 | docs.append(content) | ||
| 128 | ids.append(str(f)) | ||
| 129 | except Exception as e: | ||
| 130 | log("docs", f"Failed to load {f}: {e}") | ||
| 131 | if docs: | ||
| 132 | embeddings = [M.embed(d) for d in docs] # triggers _init() | ||
| 133 | M.collection.add(documents=docs, embeddings=embeddings, ids=ids) | ||
| 134 | log("docs", f"Indexed {len(docs)} chunks") | ||
| 135 | |||
| 136 | # ==================== AGENTS ==================== | ||
| 137 | def parse_action(text: str) -> tuple[str | None, str]: | ||
| 138 | if m := re.search(r'\[\[(\w+):(.+?)\]\]', text, re.DOTALL): return m.group(1).upper(), m.group(2).strip() | ||
| 139 | if m := re.search(r'\[\[(\w+)\]\]', text): return m.group(1).upper(), "" | ||
| 140 | return None, "" | ||
| 141 | |||
| 142 | def extract_findings(resp: str) -> str: | ||
| 143 | """Extract content from [[FINDINGS:...]] or return raw response.""" | ||
| 144 | if m := re.search(r'\[\[FINDINGS:(.*?)\]\]', resp, re.DOTALL): | ||
| 145 | return m.group(1).strip() | ||
| 146 | return resp | ||
| 147 | |||
| 148 | def agent(name: str, instructions: str, task: str, max_tokens: int = 512) -> str: | ||
| 149 | resp = M.generate(task, instructions=instructions, max_tokens=max_tokens) | ||
| 150 | log(name, resp[:1000] + ("..." if len(resp) > 1000 else "")) | ||
| 151 | return resp | ||
| 152 | |||
| 153 | INSTRUCTIONS = { | ||
| 154 | "planner": """You are a research planner. Break the query into 3 subtopics MAX. | ||
| 155 | Output EXACTLY: [[PLAN:\n- subtopic 1\n- subtopic 2\n]]\nKeep subtopics short (3-5 words). No explanations.""", | ||
| 156 | |||
| 157 | "researcher": """You are a research agent. Be CONCISE - max 2-3 sentences. | ||
| 158 | Extract ANY facts from the documents that relate to the query. | ||
| 159 | Output format: [[FINDINGS:\nThe relevant facts found.\n]]""", | ||
| 160 | |||
| 161 | "critic": """You are a research critic. Review findings for completeness and accuracy. | ||
| 162 | If sufficient: [[SATISFIED]] | ||
| 163 | If gaps exist: [[ISSUES:what specific information is missing]] | ||
| 164 | Be concise and specific.""", | ||
| 165 | |||
| 166 | "writer": "You are a research writer. Be CONCISE and DIRECT. No fluff, no hedging. Just state the facts." | ||
| 167 | } | ||
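# The [[...]] markers in these prompts are what parse_action() and extract_findings() look for;
# the roles run in the order used by research() below: planner -> researcher -> critic -> writer.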
| 168 | |||
| 169 | def plan(mem: Memory, query: str) -> list[str]: | ||
| 170 | resp = agent("planner", INSTRUCTIONS["planner"], f"Research query: {query}") | ||
| 171 | if m := re.search(r'\[\[PLAN:(.*?)\]\]', resp, re.DOTALL): | ||
| 172 | subtopics = [l.strip().lstrip("-").strip() for l in m.group(1).strip().split("\n")] | ||
| 173 | subtopics = [s for s in subtopics if len(s) > 3] | ||
| 174 | if subtopics: | ||
| 175 | mem.save("planner", query, "\n".join(subtopics)) | ||
| 176 | return subtopics | ||
| 177 | return [query] | ||
| 178 | |||
| 179 | def do_research(mem: Memory, query: str, source: str = "web"): | ||
| 180 | log("research", f"Searching {source.upper()} for: {query}") | ||
| 181 | results = doc_search(query) if source == "local" else web_search(query) | ||
| 182 | if not results: | ||
| 183 | log("research", f"No {source} results") | ||
| 184 | return | ||
| 185 | log("research", f"Found {len(results)} {source} results") | ||
| 186 | if source == "local": | ||
| 187 | content = "\n".join(f"[{i}] (sim: {r['score']:.2f})\n{r['content'][:1000]}" for i, r in enumerate(results, 1)) | ||
| 188 | else: | ||
| 189 | content = "\n".join(f"[{i}] {r['title']}\n{r['url']}\n{r.get('content', r.get('snippet', ''))[:1000]}" for i, r in enumerate(results, 1)) | ||
| 190 | prompt = f"Research query: {query}\n\nResults:\n{content[:3000]}\n\nExtract key findings." | ||
| 191 | findings = extract_findings(agent("research", INSTRUCTIONS["researcher"], prompt)) | ||
| 192 | mem.save(source, query, findings) | ||
| 193 | |||
| 194 | def critique(mem: Memory, query: str) -> tuple[bool, str]: | ||
| 195 | prompt = f"Original query: {query}\n\nResearch so far:\n{mem.all_summaries()}\n\nIs this sufficient?" | ||
| 196 | resp = agent("critic", INSTRUCTIONS["critic"], prompt, max_tokens=200) | ||
| 197 | action, arg = parse_action(resp) | ||
| 198 | if action == "SATISFIED": return True, "Research approved" | ||
| 199 | if action == "ISSUES": | ||
| 200 | mem.save("critic", "gap identified", arg) | ||
| 201 | return False, arg | ||
| 202 | return True, "Assumed complete" | ||
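# If the critic's reply carries no parseable [[...]] marker, the loop treats the research as
# sufficient rather than spending another round on it.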
| 203 | |||
| 204 | def write(mem: Memory, query: str) -> str: | ||
| 205 | fmt = lambda t, s, e: f"## {t}\n" + ("\n".join(f"- {f['summary']}" for f in mem.by_source(s)) or e) | ||
| 206 | sections = [ | ||
| 207 | fmt("LOCAL DOCUMENTS", "local", "No relevant local documents."), | ||
| 208 | fmt("WEB SEARCH", "web", "No relevant web results.") | ||
| 209 | ] | ||
| 210 | prompt = f"Query: {query}\n\nFindings:\n{mem.all_summaries()}\n\nWrite a 2-3 sentence answer." | ||
| 211 | sections.append(f"## ANSWER\n{agent('writer', INSTRUCTIONS['writer'], prompt, 150)}") | ||
| 212 | return "\n\n".join(sections) | ||
| 213 | |||
| 214 | # ==================== ORCHESTRATOR ==================== | ||
| 215 | def research(query: str, verbose: bool = True) -> dict: | ||
| 216 | vlog = (lambda phase, msg: print(f"\n[Phase {phase}] {msg}")) if verbose else (lambda *_: None) | ||
| 217 | if verbose: print(f"\n{'='*60}\nRESEARCH: {query}\n{'='*60}\n") | ||
| 218 | |||
| 219 | mem = Memory() | ||
| 220 | index_documents() | ||
| 221 | |||
| 222 | vlog(1, "Searching local documents...") | ||
| 223 | do_research(mem, query, "local") | ||
| 224 | |||
| 225 | vlog(2, "Planning web research...") | ||
| 226 | subtopics = plan(mem, query) | ||
| 227 | if verbose: print(f"Subtopics: {subtopics}\n") | ||
| 228 | |||
| 229 | vlog(3, "Web research...") | ||
| 230 | for topic in subtopics: | ||
| 231 | if verbose: print(f"\n--- Web: {topic} ---") | ||
| 232 | do_research(mem, topic, "web") | ||
| 233 | |||
| 234 | vlog(4, "Critique loop...") | ||
| 235 | for rnd in range(CFG.max_critique_rounds): | ||
| 236 | if verbose: print(f"\n--- Critique round {rnd+1} ---") | ||
| 237 | ok, feedback = critique(mem, query) | ||
| 238 | if ok: | ||
| 239 | if verbose: print("Critic satisfied") | ||
| 240 | break | ||
| 241 | if verbose: print(f"Gap: {feedback[:100]}...") | ||
| 242 | do_research(mem, feedback, "web") | ||
| 243 | |||
| 244 | vlog(5, "Writing synthesis...") | ||
| 245 | return {"query": query, "subtopics": subtopics, "answer": write(mem, query)} | ||
| 246 | |||
| 247 | # ==================== CLI ==================== | ||
| 248 | def test_model(): | ||
| 249 | log("test", "Loading model and asking: 'What is an apple?'") | ||
| 250 | resp = M.generate("What is an apple? Answer in 2-3 sentences.", max_tokens=100) | ||
| 251 | print(f"[response] {resp}\n[test] Done.") | ||
| 252 | |||
| 253 | if __name__ == "__main__": | ||
| 254 | if len(sys.argv) < 2: | ||
| 255 | print(""" | ||
| 256 | Installation: (pytorch higher version should work as well, the gpu i have is a bit old) | ||
| 257 | pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 | ||
| 258 | |||
| 259 | Usage: | ||
| 260 | ./search.py "your research question" | ||
| 261 | ./search.py --test # sanity check to test llm workability | ||
| 262 | |||
| 263 | ./search.py "what does mcdonalds serve?" | ||
| 264 | ./search.py "is new york one of the locations of the sept 11 attacks?" | ||
| 265 | ./search.py "strawberries. what colour are they?" | ||
| 266 | """) | ||
| 267 | sys.exit(1) | ||
| 268 | if sys.argv[1] == "--test": | ||
| 269 | test_model() | ||
| 270 | else: | ||
| 271 | result = research(sys.argv[1]) | ||
| 272 | print(f"\n{'='*60}\nFINAL ANSWER\n{'='*60}\n{result['answer']}") | ||
