| | @@ -310,10 +310,138 @@ |
| 310 | 310 | data=locations, |
| 311 | 311 | query_type="filter", |
| 312 | 312 | raw_query=f"provenance({entity_name!r})", |
| 313 | 313 | explanation=f"Found {len(locations)} provenance records for '{entity_name}'", |
| 314 | 314 | ) |
| 315 | + |
def shortest_path(self, start: str, end: str, max_depth: int = 6) -> QueryResult:
    """Find the shortest path between two entities via breadth-first search.

    Relationships are traversed in both directions, and node names are
    compared case-insensitively (lowercased) while walking the graph.

    Args:
        start: Name of the entity the path starts from.
        end: Name of the entity the path should reach.
        max_depth: Maximum number of relationship hops to explore.

    Returns:
        A QueryResult whose ``data`` lists, in path order, the start entity,
        then each relationship followed by the entity it leads to, ending
        with the end entity. Intermediate entities that ``store.get_entity``
        cannot resolve are omitted. ``data`` is empty when either endpoint is
        missing or no path exists within ``max_depth`` hops.
    """
    from collections import deque

    raw_query = f"shortest_path({start!r}, {end!r})"

    start_entity = self.store.get_entity(start)
    end_entity = self.store.get_entity(end)
    if not start_entity:
        return QueryResult(
            data=[],
            query_type="filter",
            raw_query=raw_query,
            explanation=f"Entity '{start}' not found",
        )
    if not end_entity:
        return QueryResult(
            data=[],
            query_type="filter",
            raw_query=raw_query,
            explanation=f"Entity '{end}' not found",
        )

    start_l = start.lower()
    end_l = end.lower()
    # Checked before loading relationships: nothing to search for.
    if start_l == end_l:
        return QueryResult(
            data=[start_entity],
            query_type="filter",
            raw_query=raw_query,
            explanation="Start and end are the same entity",
        )

    # Undirected adjacency list: each relationship is reachable from both ends.
    adj: dict[str, list[tuple[str, dict]]] = {}
    for rel in self.store.get_all_relationships():
        src_l = rel["source"].lower()
        tgt_l = rel["target"].lower()
        adj.setdefault(src_l, []).append((tgt_l, rel))
        adj.setdefault(tgt_l, []).append((src_l, rel))

    # BFS that records, for every discovered node, the predecessor node and
    # the relationship used to reach it. This avoids copying a path list per
    # queued node and lets the path be rebuilt in order afterwards.
    depth: dict[str, int] = {start_l: 0}
    parent: dict[str, tuple[str, dict]] = {}
    queue = deque([start_l])
    while queue and end_l not in parent:
        current = queue.popleft()
        # Nodes at max_depth are not expanded, so paths have <= max_depth hops.
        if depth[current] >= max_depth:
            continue
        for neighbor, rel in adj.get(current, []):
            if neighbor in depth:
                continue
            depth[neighbor] = depth[current] + 1
            parent[neighbor] = (current, rel)
            if neighbor == end_l:
                break
            queue.append(neighbor)

    if end_l not in parent:
        return QueryResult(
            data=[],
            query_type="filter",
            raw_query=raw_query,
            explanation=f"No path found between '{start}' and '{end}' within {max_depth} hops",
        )

    # Walk parent pointers back from the end, then reverse, giving
    # (relationship, node reached by it) pairs in start -> end order.
    steps: list[tuple[dict, str]] = []
    node = end_l
    while node != start_l:
        prev, rel = parent[node]
        steps.append((rel, node))
        node = prev
    steps.reverse()

    data: list = [start_entity]
    for rel, node in steps[:-1]:
        data.append(rel)
        # Use the relationship's original-case name for the node this hop
        # lands on (relationships may be traversed against their direction).
        name = rel["target"] if rel["target"].lower() == node else rel["source"]
        entity = self.store.get_entity(name)
        if entity:
            data.append(entity)
    data.append(steps[-1][0])
    data.append(end_entity)

    return QueryResult(
        data=data,
        query_type="filter",
        raw_query=raw_query,
        explanation=f"Path found: {len(steps)} hops",
    )
| 401 | + |
def clusters(self) -> QueryResult:
    """Group the graph's nodes into connected components.

    Relationships are treated as undirected and node names are lowercased,
    so matching is case-insensitive. Every entity returned by the store
    appears in some cluster; entities with no relationships form
    single-member clusters.

    Returns:
        A QueryResult whose ``data`` is a list of dicts with keys
        ``cluster_id``, ``size`` and ``members``. The list is ordered largest
        cluster first, with ties kept in discovery order. ``members`` is the
        sorted list of lowercased node names in that cluster.
    """
    entities = self.store.get_all_entities()
    relationships = self.store.get_all_relationships()

    # Register entities first so isolated ones still become their own cluster.
    graph: dict[str, set[str]] = {entity["name"].lower(): set() for entity in entities}
    for rel in relationships:
        a = rel["source"].lower()
        b = rel["target"].lower()
        graph.setdefault(a, set()).add(b)
        graph.setdefault(b, set()).add(a)

    seen: set[str] = set()
    groups: list[list[str]] = []
    for root in graph:
        if root in seen:
            continue
        # Iterative flood fill; nodes are marked as seen when pushed so each
        # one is pushed at most once.
        seen.add(root)
        frontier = [root]
        members: list[str] = []
        while frontier:
            node = frontier.pop()
            members.append(node)
            for nxt in graph[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    frontier.append(nxt)
        groups.append(sorted(members))

    # Stable sort: equally sized clusters keep their discovery order.
    groups.sort(key=len, reverse=True)

    data = []
    for idx, members in enumerate(groups):
        data.append({"cluster_id": idx, "size": len(members), "members": members})

    return QueryResult(
        data=data,
        query_type="filter",
        raw_query="clusters()",
        explanation=f"Found {len(groups)} clusters",
    )
| 315 | 443 | |
| 316 | 444 | def sql(self, query: str) -> QueryResult: |
| 317 | 445 | """Execute a raw SQL query (SQLite only).""" |
| 318 | 446 | result = self.store.raw_query(query) |
| 319 | 447 | return QueryResult( |
| | @@ -350,10 +478,12 @@ |
| 350 | 478 | f"Graph stats: {json.dumps(stats)}\n\n" |
| 351 | 479 | "Available actions (pick exactly one):\n" |
| 352 | 480 | '- {{"action": "entities", "name": "...", "entity_type": "..."}}\n' |
| 353 | 481 | '- {{"action": "relationships", "source": "...", "target": "...", "rel_type": "..."}}\n' |
| 354 | 482 | '- {{"action": "neighbors", "entity_name": "...", "depth": 1}}\n' |
| 483 | + '- {{"action": "shortest_path", "start": "...", "end": "..."}}\n' |
| 484 | + '- {{"action": "clusters"}}\n' |
| 355 | 485 | '- {{"action": "stats"}}\n\n' |
| 356 | 486 | f"User question: {question}\n\n" |
| 357 | 487 | "Return ONLY a JSON object with the action. Omit optional fields you don't need." |
| 358 | 488 | ) |
| 359 | 489 | |
| | @@ -398,10 +528,17 @@ |
| 398 | 528 | elif action == "neighbors": |
| 399 | 529 | result = self.neighbors( |
| 400 | 530 | entity_name=plan.get("entity_name", ""), |
| 401 | 531 | depth=plan.get("depth", 1), |
| 402 | 532 | ) |
| 533 | + elif action == "shortest_path": |
| 534 | + result = self.shortest_path( |
| 535 | + start=plan.get("start", ""), |
| 536 | + end=plan.get("end", ""), |
| 537 | + ) |
| 538 | + elif action == "clusters": |
| 539 | + result = self.clusters() |
| 403 | 540 | elif action == "stats": |
| 404 | 541 | result = self.stats() |
| 405 | 542 | else: |
| 406 | 543 | return QueryResult( |
| 407 | 544 | data=None, |
| 408 | 545 | |