# faiss_rag_enterprise/llama_index/graph_stores/simple.py
#
# 182 lines
# 6.0 KiB
# Python

"""Simple graph store index."""
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import fsspec
from dataclasses_json import DataClassJsonMixin
from llama_index.graph_stores.types import (
DEFAULT_PERSIST_DIR,
DEFAULT_PERSIST_FNAME,
GraphStore,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@dataclass
class SimpleGraphStoreData(DataClassJsonMixin):
    """Simple Graph Store Data container.

    Args:
        graph_dict (Optional[dict]): dict mapping each subject to the list of
            its outgoing ``[relation, object]`` pairs.
    """

    graph_dict: Dict[str, List[List[str]]] = field(default_factory=dict)

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get subjects' rel map in max depth.

        Args:
            subjs: subjects to start from. When ``None``, every subject
                currently stored in ``graph_dict`` is used.
            depth: maximum number of hops to follow from each subject.
            limit: maximum number of outgoing relations expanded per visited
                node, and also the maximum total number of triplets returned
                across all subjects.

        Returns:
            dict mapping each processed subject to a flat list of
            ``[subject, relation, object]`` triplets.
        """
        if subjs is None:
            subjs = list(self.graph_dict.keys())

        rel_map = {
            subj: self._get_rel_map(subj, depth=depth, limit=limit)
            for subj in subjs
        }

        # TBD, truncate the rel_map in a spread way, now just truncate based
        # on iteration order
        rel_count = 0
        return_map = {}
        for subj, triplets in rel_map.items():
            if rel_count + len(triplets) > limit:
                # Budget exhausted: keep only what still fits, then stop.
                return_map[subj] = triplets[: limit - rel_count]
                break
            return_map[subj] = triplets
            rel_count += len(triplets)
        return return_map

    def _get_rel_map(
        self, subj: str, depth: int = 2, limit: int = 30
    ) -> List[List[str]]:
        """Get one subject's rel map in max depth.

        Walks the graph depth-first from ``subj``. Each triplet is followed
        immediately by the triplets reachable from its object.

        Args:
            subj: node to start from.
            depth: remaining number of hops; a non-positive value yields no
                triplets (this also keeps a negative depth from recursing
                without bound on a cyclic graph).
            limit: maximum number of outgoing relations expanded per node.
                It is forwarded to the recursive calls so nested nodes honor
                the caller's value rather than the default.

        Returns:
            flat list of ``[subject, relation, object]`` triplets.
        """
        if depth <= 0:
            return []

        rel_map: List[List[str]] = []
        for rel_count, (rel, obj) in enumerate(self.graph_dict.get(subj, [])):
            if rel_count >= limit:
                break
            rel_map.append([subj, rel, obj])
            rel_map += self._get_rel_map(obj, depth=depth - 1, limit=limit)
        return rel_map
class SimpleGraphStore(GraphStore):
    """Simple Graph Store.

    In this graph store, triplets are stored within a simple, in-memory dictionary.

    Args:
        data (Optional[SimpleGraphStoreData]): container holding the
            triplets. See SimpleGraphStoreData for more details. An empty
            container is created when omitted.
        fs (Optional[fsspec.AbstractFileSystem]): filesystem used by
            ``persist`` when none is passed to it. Defaults to the local
            ``"file"`` filesystem.
    """

    def __init__(
        self,
        data: Optional[SimpleGraphStoreData] = None,
        fs: Optional[fsspec.AbstractFileSystem] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._data = data or SimpleGraphStoreData()
        self._fs = fs or fsspec.filesystem("file")

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> "SimpleGraphStore":
        """Load from persist dir.

        The store file is expected at ``persist_dir/DEFAULT_PERSIST_FNAME``.
        """
        persist_path = os.path.join(persist_dir, DEFAULT_PERSIST_FNAME)
        return cls.from_persist_path(persist_path, fs=fs)

    @property
    def client(self) -> None:
        """Get client.

        Not applicable for this store.
        """
        return None

    def get(self, subj: str) -> List[List[str]]:
        """Get triplets.

        Returns the stored ``[relation, object]`` pairs for ``subj``, or an
        empty list when the subject is unknown.
        """
        return self._data.graph_dict.get(subj, [])

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get depth-aware rel map."""
        return self._data.get_rel_map(subjs=subjs, depth=depth, limit=limit)

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet.

        The ``[rel, obj]`` pair is appended to the subject's list only when
        it is not already present, so repeated upserts do not create
        duplicates.
        """
        pairs = self._data.graph_dict.setdefault(subj, [])
        target = [rel, obj]
        # Pairs are stored as lists, and a tuple never compares equal to a
        # list, so each entry is normalised with list() before comparing.
        if not any(list(pair) == target for pair in pairs):
            pairs.append(target)

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete triplet.

        Removes every stored occurrence of ``[rel, obj]`` under ``subj``.
        When that leaves the subject without relations, the subject key is
        dropped too. Unknown subjects or pairs are ignored.
        """
        pairs = self._data.graph_dict.get(subj)
        if pairs is None:
            return

        target = [rel, obj]
        remaining = [pair for pair in pairs if list(pair) != target]
        if len(remaining) == len(pairs):
            # Nothing matched; leave the store untouched.
            return

        # Mutate in place so references handed out by get() stay consistent.
        pairs[:] = remaining
        if not pairs:
            del self._data.graph_dict[subj]

    def persist(
        self,
        persist_path: str = os.path.join(DEFAULT_PERSIST_DIR, DEFAULT_PERSIST_FNAME),
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> None:
        """Persist the SimpleGraphStore to a file as JSON.

        Args:
            persist_path: destination file path. Its parent directory is
                created when missing.
            fs: filesystem to write to; falls back to the one given at
                construction time.
        """
        fs = fs or self._fs
        dirpath = os.path.dirname(persist_path)
        # A bare file name has an empty dirname; there is no directory to
        # create in that case.
        if dirpath and not fs.exists(dirpath):
            fs.makedirs(dirpath)

        with fs.open(persist_path, "w") as f:
            json.dump(self._data.to_dict(), f)

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the Simple Graph store."""
        raise NotImplementedError("SimpleGraphStore does not support get_schema")

    def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
        """Query the Simple Graph store."""
        # The signature (including the default) is kept as-is; the default is
        # never read or mutated because this method always raises.
        raise NotImplementedError("SimpleGraphStore does not support query")

    @classmethod
    def from_persist_path(
        cls, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
    ) -> "SimpleGraphStore":
        """Create a SimpleGraphStore from a persisted JSON file.

        Args:
            persist_path: path of the file written by ``persist``.
            fs: filesystem to read from. Defaults to the local ``"file"``
                filesystem.

        Returns:
            the loaded store, or a new empty store (after logging a warning)
            when ``persist_path`` does not exist.
        """
        fs = fs or fsspec.filesystem("file")
        if not fs.exists(persist_path):
            logger.warning(
                f"No existing {__name__} found at {persist_path}. "
                "Initializing a new graph_store from scratch. "
            )
            return cls()

        logger.debug(f"Loading {__name__} from {persist_path}.")
        with fs.open(persist_path, "rb") as f:
            data_dict = json.load(f)
        data = SimpleGraphStoreData.from_dict(data_dict)
        return cls(data)

    @classmethod
    def from_dict(cls, save_dict: dict) -> "SimpleGraphStore":
        """Build a store from a dict as produced by ``to_dict``."""
        data = SimpleGraphStoreData.from_dict(save_dict)
        return cls(data)

    def to_dict(self) -> dict:
        """Return the underlying data container as a plain dict."""
        return self._data.to_dict()