diff --git a/graph_viewer.html b/graph_viewer.html
new file mode 100644
index 0000000..403b6ce
--- /dev/null
+++ b/graph_viewer.html
@@ -0,0 +1,529 @@
+
+
+
+
+
+ GraphResearcher 图谱可视化
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/memory/graph.py b/src/memory/graph.py
index 3691b6e..677fbd7 100644
--- a/src/memory/graph.py
+++ b/src/memory/graph.py
@@ -1,8 +1,14 @@
# src/memory/graph.py
+import asyncio
+import hashlib
+import json
+import os
+import uuid
from dataclasses import dataclass, field
-from enum import StrEnum
-from typing import Any, Literal
from datetime import datetime
+from enum import StrEnum
+from pathlib import Path
+from typing import Any, Optional
class NodeType(StrEnum):
@@ -41,6 +47,21 @@ class GraphNode:
type: NodeType
properties: dict[str, Any] = field(default_factory=dict)
+ def to_dict(self) -> dict:
+ return {
+ "id": self.id,
+ "type": self.type.value,
+ **self.properties
+ }
+
+ @classmethod
+ def from_dict(cls, data: dict) -> "GraphNode":
+ return cls(
+ id=data["id"],
+ type=NodeType(data["type"]),
+ properties={k: v for k, v in data.items() if k not in ("id", "type")}
+ )
+
@dataclass
class GraphEdge:
@@ -49,12 +70,290 @@ class GraphEdge:
type: EdgeType
properties: dict[str, Any] = field(default_factory=dict)
+ def to_dict(self) -> dict:
+ return {
+ "source": self.source,
+ "target": self.target,
+ "type": self.type.value,
+ **self.properties
+ }
+
+ @classmethod
+ def from_dict(cls, data: dict) -> "GraphEdge":
+ return cls(
+ source=data["source"],
+ target=data["target"],
+ type=EdgeType(data["type"]),
+ properties={k: v for k, v in data.items() if k not in ("source", "target", "type")}
+ )
+
+
+FILE_MARKER = {"type": "_graphresearcher", "source": "graph-researcher"}
+
+
+def get_default_graph_cache_dir() -> Path:
+ if os.environ.get("REPO_BASE_DIR"):
+ return Path(os.environ["REPO_BASE_DIR"]) / ".graphresearcher"
+ return Path.home() / ".graphresearcher"
+
+
+def get_graph_file_path(cache_dir: Path, context: Optional[str] = None) -> Path:
+ filename = "research_graph.jsonl" if context is None else f"research_graph-{context}.jsonl"
+ return cache_dir / filename
+
+
+def get_index_file_path(cache_dir: Path, context: Optional[str] = None) -> Path:
+ filename = "index.pkl" if context is None else f"index-{context}.pkl"
+ return cache_dir / filename
+
class GraphManager:
- def add_node(self, node_type: str, attributes: dict, node_id: str | None = None) -> str: ...
- def upsert_node(self, node_type: str, attributes: dict, dedupe_key: str | None = None) -> str: ...
- def add_edge(self, source_id: str, target_id: str, edge_type: str, attributes: dict | None = None) -> str: ...
- def get_node(self, node_id: str) -> dict: ...
- def search_node(self, node_type: str, filter: dict) -> list[dict]: ...
- def get_subgraph(self, node_id: str, depth: int = 2) -> dict: ...
- def save(self) -> None: ...
+ def __init__(self, cache_dir: Optional[str] = None, context: Optional[str] = None):
+ if cache_dir:
+ self._cache_dir = Path(cache_dir)
+ else:
+ self._cache_dir = get_default_graph_cache_dir()
+ self._cache_dir.mkdir(parents=True, exist_ok=True)
+ self._context = context
+
+ self._nodes: dict[str, GraphNode] = {}
+ self._edges: list[GraphEdge] = []
+ self._loaded = False
+ self._file_mtime: Optional[float] = None
+
+ self._out_edges: dict[str, list[GraphEdge]] = {}
+ self._in_edges: dict[str, list[GraphEdge]] = {}
+
+ async def _ensure_loaded(self) -> None:
+ if self._loaded:
+ file_path = get_graph_file_path(self._cache_dir, self._context)
+ try:
+ current_mtime = os.path.getmtime(file_path) if file_path.exists() else 0.0
+ if current_mtime != self._file_mtime:
+ await self._load_from_file()
+ except OSError:
+ await self._load_from_file()
+ else:
+ await self._load_from_file()
+
+ async def _load_from_file(self) -> None:
+ file_path = get_graph_file_path(self._cache_dir, self._context)
+
+ self._nodes.clear()
+ self._edges.clear()
+ self._out_edges.clear()
+ self._in_edges.clear()
+
+ if not file_path.exists():
+ self._loaded = True
+ self._file_mtime = 0.0
+ return
+
+ try:
+ self._file_mtime = os.path.getmtime(file_path)
+
+ with open(file_path, 'r', encoding='utf-8') as f:
+ lines = [line.strip() for line in f if line.strip()]
+
+ if not lines:
+ self._loaded = True
+ return
+
+ first_line = json.loads(lines[0])
+ if first_line.get("type") != "_graphresearcher" or first_line.get("source") != "graph-researcher":
+ raise ValueError(f"Invalid file marker in {file_path}")
+
+ for line in lines[1:]:
+ try:
+ item = json.loads(line)
+ if item.get("kind") == "node":
+ node = GraphNode.from_dict(item)
+ self._nodes[node.id] = node
+ elif item.get("kind") == "edge":
+ edge = GraphEdge.from_dict(item)
+ self._edges.append(edge)
+ self._out_edges.setdefault(edge.source, []).append(edge)
+ self._in_edges.setdefault(edge.target, []).append(edge)
+ except (json.JSONDecodeError, KeyError, ValueError):
+ continue
+
+ self._loaded = True
+
+ except FileNotFoundError:
+ self._loaded = True
+ self._file_mtime = 0.0
+
+ async def _save_to_file(self) -> None:
+ file_path = get_graph_file_path(self._cache_dir, self._context)
+ file_path.parent.mkdir(parents=True, exist_ok=True)
+
+ lines = [json.dumps(FILE_MARKER)]
+
+ for node in self._nodes.values():
+ lines.append(json.dumps({"kind": "node", **node.to_dict()}))
+
+ for edge in self._edges:
+ lines.append(json.dumps({"kind": "edge", **edge.to_dict()}))
+
+ with open(file_path, 'w', encoding='utf-8') as f:
+ for line in lines:
+ f.write(line + '\n')
+
+ try:
+ self._file_mtime = os.path.getmtime(file_path)
+ except OSError:
+ pass
+
+ def _generate_id(self) -> str:
+ return f"{datetime.now().strftime('%Y%m%d%H%M%S%f')}_{uuid.uuid4().hex[:8]}"
+
+ async def add_node(self, node_type: str, attributes: dict, node_id: Optional[str] = None) -> str:
+ await self._ensure_loaded()
+
+ if node_id is None:
+ node_id = self._generate_id()
+
+ node = GraphNode(
+ id=node_id,
+ type=NodeType(node_type),
+ properties=attributes
+ )
+ self._nodes[node_id] = node
+ await self._save_to_file()
+ return node_id
+
+ async def upsert_node(self, node_type: str, attributes: dict, dedupe_key: Optional[str] = None) -> str:
+ await self._ensure_loaded()
+
+ if dedupe_key and dedupe_key in self._nodes:
+ node = self._nodes[dedupe_key]
+ node.properties.update(attributes)
+ await self._save_to_file()
+ return dedupe_key
+
+ return await self.add_node(node_type, attributes)
+
+ async def add_edge(self, source_id: str, target_id: str, edge_type: str, attributes: Optional[dict] = None) -> str:
+ await self._ensure_loaded()
+
+ edge = GraphEdge(
+ source=source_id,
+ target=target_id,
+ type=EdgeType(edge_type),
+ properties=attributes or {}
+ )
+
+ self._edges.append(edge)
+ self._out_edges.setdefault(source_id, []).append(edge)
+ self._in_edges.setdefault(target_id, []).append(edge)
+
+ await self._save_to_file()
+ return f"{source_id}_{edge_type}_{target_id}"
+
+ async def get_node(self, node_id: str) -> dict:
+ await self._ensure_loaded()
+
+ if node_id in self._nodes:
+ return self._nodes[node_id].to_dict()
+ return {}
+
+ async def search_node(self, node_type: str, filter: dict) -> list[dict]:
+ await self._ensure_loaded()
+
+ results = []
+ for node in self._nodes.values():
+ if node.type.value == node_type:
+ match = True
+ for key, value in filter.items():
+ if key not in node.properties or node.properties[key] != value:
+ match = False
+ break
+ if match:
+ results.append(node.to_dict())
+ return results
+
+ async def get_subgraph(self, node_id: str, depth: int = 2) -> dict:
+ await self._ensure_loaded()
+
+ if node_id not in self._nodes:
+ return {"nodes": [], "edges": []}
+
+ visited_nodes = set()
+ visited_edges = set()
+ nodes_to_visit = [(node_id, 0)]
+
+ while nodes_to_visit:
+ current_id, current_depth = nodes_to_visit.pop(0)
+ if current_id in visited_nodes or current_depth > depth:
+ continue
+
+ visited_nodes.add(current_id)
+
+ for edge in self._in_edges.get(current_id, []):
+ edge_key = (edge.source, edge.type.value, edge.target)
+ if edge_key not in visited_edges:
+ visited_edges.add(edge_key)
+ nodes_to_visit.append((edge.source, current_depth + 1))
+
+ for edge in self._out_edges.get(current_id, []):
+ edge_key = (edge.source, edge.type.value, edge.target)
+ if edge_key not in visited_edges:
+ visited_edges.add(edge_key)
+ nodes_to_visit.append((edge.target, current_depth + 1))
+
+ result_nodes = []
+ for nid in visited_nodes:
+ if nid in self._nodes:
+ result_nodes.append(self._nodes[nid].to_dict())
+
+ result_edges = []
+ for edge in self._edges:
+ edge_key = (edge.source, edge.type.value, edge.target)
+ if edge_key in visited_edges:
+ result_edges.append(edge.to_dict())
+
+ return {"nodes": result_nodes, "edges": result_edges}
+
+ async def save(self) -> None:
+ await self._ensure_loaded()
+ await self._save_to_file()
+
+ async def get_all_nodes(self) -> list[dict]:
+ await self._ensure_loaded()
+ return [node.to_dict() for node in self._nodes.values()]
+
+ async def get_all_edges(self) -> list[dict]:
+ await self._ensure_loaded()
+ return [edge.to_dict() for edge in self._edges]
+
+ async def delete_node(self, node_id: str) -> None:
+ await self._ensure_loaded()
+
+ if node_id not in self._nodes:
+ return
+
+ del self._nodes[node_id]
+
+ self._edges = [
+ e for e in self._edges
+ if e.source != node_id and e.target != node_id
+ ]
+
+ if node_id in self._out_edges:
+ del self._out_edges[node_id]
+ if node_id in self._in_edges:
+ del self._in_edges[node_id]
+
+ await self._save_to_file()
+
+ async def get_node_edges(self, node_id: str, direction: str = "both") -> list[dict]:
+ await self._ensure_loaded()
+
+ edges = []
+ if direction in ("out", "both"):
+ for edge in self._out_edges.get(node_id, []):
+ edges.append(edge.to_dict())
+ if direction in ("in", "both"):
+ for edge in self._in_edges.get(node_id, []):
+ edges.append(edge.to_dict())
+ return edges
\ No newline at end of file
diff --git a/src/tools/graph_tools.py b/src/tools/graph_tools.py
index 6ea0eed..e5e22e5 100644
--- a/src/tools/graph_tools.py
+++ b/src/tools/graph_tools.py
@@ -1 +1,600 @@
-# 参照 /mnt/h/DeepResearch/BioDSA/biodsa/memory/memory_graph/tool.py
+import asyncio
+import os
+from pathlib import Path
+from typing import Any, Optional
+
+try:
+ import nest_asyncio
+ nest_asyncio.apply()
+ HAS_NEST_ASYNCIO = True
+except ImportError:
+ HAS_NEST_ASYNCIO = False
+
+from memory.graph import NodeType, EdgeType, GraphNode, GraphEdge
+
+
+FILE_MARKER = {"type": "_graphresearcher", "source": "graph-researcher"}
+
+
+def get_default_graph_cache_dir() -> Path:
+ """
+ 获取默认的图谱缓存目录
+
+ 默认情况下,图谱数据会存储在用户主目录下的 `.graphresearcher` 文件夹中。
+ 如果设置了环境变量 `REPO_BASE_DIR`,则会存储在该目录下。
+
+ Returns:
+ Path: 默认缓存目录的路径
+ """
+ if os.environ.get("REPO_BASE_DIR"):
+ return Path(os.environ["REPO_BASE_DIR"]) / ".graphresearcher"
+ return Path.home() / ".graphresearcher"
+
+
+def get_graph_file_path(cache_dir: Path, context: Optional[str] = None) -> Path:
+ """
+ 获取图谱文件的路径
+
+ 根据上下文名称生成图谱文件名。如果没有提供上下文名称,使用默认文件名。
+
+ Args:
+ cache_dir: 缓存目录路径
+ context: 上下文名称(可选),用于区分不同的图谱实例
+
+ Returns:
+ Path: 图谱文件的完整路径
+ """
+ filename = "research_graph.jsonl" if context is None else f"research_graph-{context}.jsonl"
+ return cache_dir / filename
+
+
+def _run_async(coro):
+ """
+ 同步运行异步函数的内部辅助函数
+
+ 处理不同的异步运行时场景:
+ 1. 如果当前没有运行中的事件循环,直接使用 asyncio.run()
+ 2. 如果安装了 nest_asyncio,使用 asyncio.run()
+ 3. 否则使用线程池执行异步函数
+
+ Args:
+ coro: 异步协程对象
+
+ Returns:
+ 协程执行结果
+ """
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ return asyncio.run(coro)
+
+ if HAS_NEST_ASYNCIO:
+ return asyncio.run(coro)
+
+ import concurrent.futures
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future = executor.submit(asyncio.run, coro)
+ return future.result()
+
+
+_manager_cache = {}
+
+
+def _get_manager(cache_dir: Optional[str] = None):
+ """
+ 获取或创建 GraphManager 实例(内部缓存)
+
+ 使用单例模式管理 GraphManager 实例,避免重复创建。
+
+ Args:
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ GraphManager: 图谱管理器实例
+ """
+ from memory.graph import GraphManager
+ cache_key = cache_dir if cache_dir else "default"
+ if cache_key not in _manager_cache:
+ _manager_cache[cache_key] = GraphManager(cache_dir=cache_dir)
+ return _manager_cache[cache_key]
+
+
+def clear_manager_cache(cache_dir: Optional[str] = None):
+ """
+ 清除 GraphManager 缓存
+
+ 如果指定了缓存目录,则只清除该目录对应的管理器;否则清除所有缓存。
+
+ Args:
+ cache_dir: 缓存目录路径(可选)
+ """
+ if cache_dir:
+ cache_key = cache_dir if cache_dir else "default"
+ if cache_key in _manager_cache:
+ del _manager_cache[cache_key]
+ else:
+ _manager_cache.clear()
+
+
+def add_node(
+ node_type: str,
+ attributes: dict[str, Any],
+ node_id: Optional[str] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 添加节点到图谱
+
+ 在图谱中创建一个新节点。如果没有指定 node_id,系统会自动生成一个唯一ID。
+
+ Args:
+ node_type: 节点类型(如 "query", "evidence", "source" 等)
+ attributes: 节点属性字典
+ node_id: 节点ID(可选),如果不提供会自动生成
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的节点信息
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.add_node(node_type, attributes, node_id))
+
+
+def upsert_node(
+ node_type: str,
+ attributes: dict[str, Any],
+ dedupe_key: Optional[str] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 更新或插入节点
+
+ 如果节点已存在(通过 dedupe_key 匹配),则更新其属性;否则创建新节点。
+
+ Args:
+ node_type: 节点类型
+ attributes: 节点属性字典
+ dedupe_key: 去重键,通常是节点ID
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 更新或创建的节点信息
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.upsert_node(node_type, attributes, dedupe_key))
+
+
+def add_edge(
+ source_id: str,
+ target_id: str,
+ edge_type: str,
+ attributes: Optional[dict[str, Any]] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 添加边(关系)到图谱
+
+ 在两个节点之间创建一条边,定义它们之间的关系。
+
+ Args:
+ source_id: 源节点ID
+ target_id: 目标节点ID
+ edge_type: 边类型(如 "has_query", "supported_by", "uses" 等)
+ attributes: 边属性字典(可选)
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的边信息
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.add_edge(source_id, target_id, edge_type, attributes))
+
+
+def get_node(
+ node_id: str,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 根据ID获取节点
+
+ 从图谱中检索指定ID的节点信息。
+
+ Args:
+ node_id: 节点ID
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 节点信息,如果不存在返回空字典
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.get_node(node_id))
+
+
+def search_nodes(
+ node_type: str,
+ filter: dict[str, Any],
+ cache_dir: Optional[str] = None
+) -> list[dict[str, Any]]:
+ """
+ 搜索节点
+
+ 根据节点类型和过滤条件检索节点列表。
+
+ Args:
+ node_type: 节点类型
+ filter: 过滤条件字典,键为属性名,值为属性值
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ list: 匹配条件的节点列表
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.search_node(node_type, filter))
+
+
+def get_subgraph(
+ node_id: str,
+ depth: int = 2,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 获取子图(Ego-network)
+
+ 获取指定节点周围一定深度内的所有节点和边,形成子图。
+
+ Args:
+ node_id: 中心节点ID
+ depth: 搜索深度,默认为2(直接邻居及其邻居)
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 包含 nodes 和 edges 的子图数据
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.get_subgraph(node_id, depth))
+
+
+def save_graph(cache_dir: Optional[str] = None) -> None:
+ """
+ 保存图谱到磁盘
+
+ 将当前内存中的图谱数据持久化到文件。
+
+ Args:
+ cache_dir: 缓存目录路径(可选)
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.save())
+
+
+def create_user_request(
+ request: str,
+ research_goal: Optional[str] = None,
+ language: Optional[str] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 创建用户请求节点
+
+ 创建一个表示用户研究请求的节点,包含用户输入的查询内容。
+
+ Args:
+ request: 用户请求文本
+ research_goal: 研究目标(可选)
+ language: 语言标识(可选),如 "zh" 或 "en"
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的用户请求节点
+ """
+ manager = _get_manager(cache_dir)
+ attributes = {"request": request}
+ if research_goal:
+ attributes["research_goal"] = research_goal
+ if language:
+ attributes["language"] = language
+ return _run_async(manager.add_node(NodeType.USER_REQUEST, attributes))
+
+
+def create_query(
+ text: str,
+ status: str = "pending",
+ parent_id: Optional[str] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 创建查询节点
+
+ 创建一个研究查询节点,用于表示一个需要搜索的问题或主题。
+
+ Args:
+ text: 查询文本内容
+ status: 查询状态,默认为 "pending"(待处理)
+ parent_id: 父节点ID(可选),用于建立层级关系
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的查询节点
+ """
+ manager = _get_manager(cache_dir)
+ query_id = _run_async(manager.add_node(NodeType.QUERY, {"text": text, "status": status}))
+ if parent_id:
+ _run_async(manager.add_edge(parent_id, query_id, EdgeType.HAS_QUERY))
+ return query_id
+
+
+def create_source(
+ url: str,
+ title: Optional[str] = None,
+ source_type: Optional[str] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 创建来源节点
+
+ 创建一个表示信息来源的节点,如网页、论文、文档等。
+
+ Args:
+ url: 来源URL地址
+ title: 来源标题(可选)
+ source_type: 来源类型(可选),如 "webpage", "paper", "github" 等
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的来源节点
+ """
+ manager = _get_manager(cache_dir)
+ attributes = {"url": url}
+ if title:
+ attributes["title"] = title
+ if source_type:
+ attributes["source_type"] = source_type
+ return _run_async(manager.add_node(NodeType.SOURCE, attributes))
+
+
+def create_document(
+ text: str,
+ source_id: Optional[str] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 创建文档节点
+
+ 创建一个表示文档内容的节点,通常与来源节点关联。
+
+ Args:
+ text: 文档文本内容
+ source_id: 来源节点ID(可选),建立文档与来源的关联
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的文档节点
+ """
+ manager = _get_manager(cache_dir)
+ attributes = {"text": text}
+ if source_id:
+ attributes["source_id"] = source_id
+ doc_id = _run_async(manager.add_node(NodeType.DOCUMENT, attributes))
+ if source_id:
+ _run_async(manager.add_edge(source_id, doc_id, EdgeType.HAS_DOCUMENT))
+ return doc_id
+
+
+def create_evidence(
+ claim: str,
+ document_id: Optional[str] = None,
+ source_id: Optional[str] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 创建证据节点
+
+ 创建一个表示证据的节点,用于支持或反驳某个主张(claim)。
+
+ Args:
+ claim: 证据支持的主张文本
+ document_id: 文档节点ID(可选),表示证据来源于哪个文档
+ source_id: 来源节点ID(可选),表示证据来源于哪个来源
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的证据节点
+ """
+ manager = _get_manager(cache_dir)
+ evidence_id = _run_async(manager.add_node(NodeType.EVIDENCE, {"claim": claim, "status": "pending"}))
+ if document_id:
+ _run_async(manager.add_edge(document_id, evidence_id, EdgeType.EXTRACTED_STATEMENT))
+ if source_id:
+ _run_async(manager.add_edge(source_id, evidence_id, EdgeType.SUPPORTED_BY))
+ return evidence_id
+
+
+def create_conflict(
+ description: str,
+ evidence_ids: list[str],
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 创建冲突节点
+
+ 创建一个表示证据冲突的节点,用于记录多个证据之间的矛盾。
+
+ Args:
+ description: 冲突描述
+ evidence_ids: 涉及冲突的证据节点ID列表
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的冲突节点
+ """
+ manager = _get_manager(cache_dir)
+ conflict_id = _run_async(manager.add_node(NodeType.CONFLICT, {"description": description}))
+ for evidence_id in evidence_ids:
+ _run_async(manager.add_edge(conflict_id, evidence_id, EdgeType.HAS_CONFLICT))
+ return conflict_id
+
+
+def create_analysis(
+ text: str,
+ query_id: Optional[str] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 创建分析节点
+
+ 创建一个表示中间分析结果的节点,用于记录阶段性结论、框架或取舍方案。
+
+ Args:
+ text: 分析内容文本
+ query_id: 查询节点ID(可选),表示该分析解决了哪个查询的问题
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的分析节点
+ """
+ manager = _get_manager(cache_dir)
+ analysis_id = _run_async(manager.add_node(NodeType.ANALYSIS, {"text": text}))
+ if query_id:
+ _run_async(manager.add_edge(query_id, analysis_id, EdgeType.RESOLVES_GAP))
+ return analysis_id
+
+
+def create_report_section(
+ title: str,
+ content: str,
+ section_type: Optional[str] = None,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 创建报告章节节点
+
+ 创建一个表示报告章节的节点,用于存储报告的各个章节内容。
+
+ Args:
+ title: 章节标题
+ content: 章节内容
+ section_type: 章节类型(可选),如 "introduction", "methodology", "conclusion" 等
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 创建的报告章节节点
+ """
+ manager = _get_manager(cache_dir)
+ attributes = {"title": title, "content": content}
+ if section_type:
+ attributes["section_type"] = section_type
+ return _run_async(manager.add_node(NodeType.REPORT_SECTION, attributes))
+
+
+def bind_evidence_to_section(
+ section_id: str,
+ evidence_ids: list[str],
+ claim_ids: Optional[list[str]] = None,
+ cache_dir: Optional[str] = None
+) -> None:
+ """
+ 将证据绑定到报告章节
+
+ 建立报告章节与证据之间的引用关系,确保报告内容有证据支持。
+
+ Args:
+ section_id: 报告章节节点ID
+ evidence_ids: 证据节点ID列表
+ claim_ids: 主张节点ID列表(可选)
+ cache_dir: 缓存目录路径(可选)
+ """
+ manager = _get_manager(cache_dir)
+ for evidence_id in evidence_ids:
+ _run_async(manager.add_edge(section_id, evidence_id, EdgeType.USES))
+ if claim_ids:
+ for claim_id in claim_ids:
+ _run_async(manager.add_edge(section_id, claim_id, EdgeType.USES))
+
+
+def get_pending_queries(
+ cache_dir: Optional[str] = None
+) -> list[dict[str, Any]]:
+ """
+ 获取待处理的查询列表
+
+ 检索所有状态为 "pending" 的查询节点,用于调度处理。
+
+ Args:
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ list: 待处理查询节点列表
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.search_node(NodeType.QUERY, {"status": "pending"}))
+
+
+def get_query_subgraph(
+ query_id: str,
+ depth: int = 2,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 获取查询的子图
+
+ 获取指定查询节点及其相关的所有证据、文档、来源等形成的子图。
+
+ Args:
+ query_id: 查询节点ID
+ depth: 搜索深度,默认为2
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 查询的子图数据
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.get_subgraph(query_id, depth))
+
+
+def mark_query_status(
+ query_id: str,
+ status: str,
+ cache_dir: Optional[str] = None
+) -> dict[str, Any]:
+ """
+ 更新查询状态
+
+ 修改查询节点的状态,如从 "pending" 变为 "completed" 或 "blocked"。
+
+ Args:
+ query_id: 查询节点ID
+ status: 新状态
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 更新后的查询节点
+ """
+ manager = _get_manager(cache_dir)
+ return _run_async(manager.upsert_node(NodeType.QUERY, {"status": status}, dedupe_key=query_id))
+
+
+def list_graph_databases(cache_dir: Optional[str] = None) -> dict[str, Any]:
+ """
+ 列出可用的图谱数据库
+
+ 扫描缓存目录,列出所有已保存的图谱数据库。
+
+ Args:
+ cache_dir: 缓存目录路径(可选)
+
+ Returns:
+ dict: 包含数据库列表和位置信息的字典
+ """
+ cache_path = Path(cache_dir) if cache_dir else get_default_graph_cache_dir()
+ result = {"databases": [], "location": str(cache_path)}
+ try:
+ if cache_path.exists():
+ files = list(cache_path.glob("*.jsonl"))
+ for file in files:
+ if file.name == "research_graph.jsonl":
+ result["databases"].append("default")
+ elif file.name.startswith("research_graph-") and file.name.endswith(".jsonl"):
+ result["databases"].append(file.name[15:-6])
+ except (OSError, PermissionError):
+ result["databases"] = []
+ return result
\ No newline at end of file
diff --git a/tests/test_graph_tools.py b/tests/test_graph_tools.py
new file mode 100644
index 0000000..0d3d8f0
--- /dev/null
+++ b/tests/test_graph_tools.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python3
+"""
+测试 graph_tools.py 的功能
+"""
+
+import sys
+from pathlib import Path
+
+# 添加 src 目录到 Python 路径
+src_dir = Path(__file__).parent.parent / "src"
+sys.path.insert(0, str(src_dir))
+
+from tools.graph_tools import (
+ create_user_request,
+ create_query,
+ create_source,
+ create_document,
+ create_evidence,
+ create_conflict,
+ create_analysis,
+ create_report_section,
+ bind_evidence_to_section,
+ get_pending_queries,
+ get_query_subgraph,
+ mark_query_status,
+ get_node,
+ list_graph_databases,
+ save_graph,
+ clear_manager_cache
+)
+
+def test_basic_operations():
+ """测试基础的图谱操作"""
+ print("=" * 60)
+ print("测试 1: 创建用户请求和查询")
+ print("=" * 60)
+
+ # 1. 创建用户请求
+ user_request = create_user_request(
+ request="研究人工智能的发展趋势",
+ research_goal="了解AI在过去10年的发展历程和未来方向",
+ language="zh"
+ )
+ print(f"✓ 创建用户请求: {user_request}")
+
+ # 2. 创建查询
+ query1 = create_query(
+ text="AI发展历程 2014-2024",
+ parent_id=user_request
+ )
+ print(f"✓ 创建查询1: {query1}")
+
+ query2 = create_query(
+ text="AI未来发展趋势预测",
+ parent_id=user_request
+ )
+ print(f"✓ 创建查询2: {query2}")
+
+ print()
+ print("=" * 60)
+ print("测试 2: 创建来源、文档和证据")
+ print("=" * 60)
+
+ # 3. 创建来源
+ source1 = create_source(
+ url="https://example.com/ai-trends",
+ title="人工智能发展报告",
+ source_type="webpage"
+ )
+ print(f"✓ 创建来源1: {source1}")
+
+ # 4. 创建文档
+ doc1 = create_document(
+ text="人工智能在过去10年快速发展,特别是深度学习技术的突破。"
+ "GPT系列模型在2020年后迅速崛起,改变了自然语言处理领域。",
+ source_id=source1
+ )
+ print(f"✓ 创建文档1: {doc1}")
+
+ # 5. 创建证据
+ evidence1 = create_evidence(
+ claim="AI发展在过去10年加速,主要驱动力是深度学习",
+ document_id=doc1,
+ source_id=source1
+ )
+ print(f"✓ 创建证据1: {evidence1}")
+
+ evidence2 = create_evidence(
+ claim="GPT系列模型在2020年后对NLP领域产生重大影响",
+ document_id=doc1,
+ source_id=source1
+ )
+ print(f"✓ 创建证据2: {evidence2}")
+
+ print()
+ print("=" * 60)
+ print("测试 3: 创建冲突和分析")
+ print("=" * 60)
+
+ # 6. 创建另一个来源和冲突证据
+ source2 = create_source(
+ url="https://example.com/ai-skeptic",
+ title="AI发展的不同观点",
+ source_type="webpage"
+ )
+ print(f"✓ 创建来源2: {source2}")
+
+ doc2 = create_document(
+ text="有人认为AI的发展被过度炒作,实际进展有限。",
+ source_id=source2
+ )
+ print(f"✓ 创建文档2: {doc2}")
+
+ evidence3 = create_evidence(
+ claim="AI发展被过度炒作,实际进展有限",
+ document_id=doc2,
+ source_id=source2
+ )
+ print(f"✓ 创建证据3: {evidence3}")
+
+ # 7. 创建冲突
+ conflict1 = create_conflict(
+ description="关于AI发展速度的两种观点冲突",
+ evidence_ids=[evidence1, evidence3]
+ )
+ print(f"✓ 创建冲突1: {conflict1}")
+
+ # 8. 创建分析
+ analysis1 = create_analysis(
+ text="AI发展确实取得了显著进展,但也存在一定炒作。需要理性看待。",
+ query_id=query1
+ )
+ print(f"✓ 创建分析1: {analysis1}")
+
+ print()
+ print("=" * 60)
+ print("测试 4: 创建报告章节并绑定证据")
+ print("=" * 60)
+
+ # 9. 创建报告章节
+ section1 = create_report_section(
+ title="AI发展历程",
+ content="人工智能在过去10年经历了快速发展,深度学习技术是主要驱动力。",
+ section_type="introduction"
+ )
+ print(f"✓ 创建报告章节1: {section1}")
+
+ # 10. 绑定证据到章节
+ bind_evidence_to_section(
+ section_id=section1,
+ evidence_ids=[evidence1, evidence2]
+ )
+ print("✓ 绑定证据到报告章节")
+
+ print()
+ print("=" * 60)
+ print("测试 5: 查询和更新操作")
+ print("=" * 60)
+
+ # 11. 获取待处理查询
+ pending_queries = get_pending_queries()
+ print(f"✓ 待处理查询: {len(pending_queries)} 个")
+ for q in pending_queries:
+ print(f" - {q}")
+
+ # 12. 标记查询为已完成
+ updated_query1 = mark_query_status(query1, "completed")
+ print(f"✓ 更新查询状态: {updated_query1}")
+
+ # 13. 获取查询的子图
+ query_subgraph = get_query_subgraph(query1)
+ print(f"✓ 查询子图: {len(query_subgraph['nodes'])} 个节点, {len(query_subgraph['edges'])} 条边")
+
+ # 14. 获取节点详情
+ node_details = get_node(evidence1)
+ print(f"✓ 获取证据节点详情: {node_details}")
+
+ print()
+ print("=" * 60)
+ print("测试 6: 列出图谱数据库")
+ print("=" * 60)
+
+ databases = list_graph_databases()
+ print(f"✓ 图谱存储位置: {databases['location']}")
+ print(f"✓ 可用数据库: {databases['databases']}")
+
+ # 15. 保存图谱
+ save_graph()
+ print("✓ 图谱已保存到磁盘")
+
+ print()
+ print("=" * 60)
+ print("✅ 所有测试完成!")
+ print("=" * 60)
+
+ # 清理缓存
+ clear_manager_cache()
+
+if __name__ == "__main__":
+ print("\n开始测试 graph_tools.py...\n")
+ try:
+ test_basic_operations()
+ except Exception as e:
+ print(f"\n❌ 测试失败: {e}")
+ import traceback
+ traceback.print_exc()