182 lines
6.0 KiB
Python
182 lines
6.0 KiB
Python
"""Simple graph store index."""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import fsspec
|
|
from dataclasses_json import DataClassJsonMixin
|
|
|
|
from llama_index.graph_stores.types import (
|
|
DEFAULT_PERSIST_DIR,
|
|
DEFAULT_PERSIST_FNAME,
|
|
GraphStore,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
|
|
class SimpleGraphStoreData(DataClassJsonMixin):
|
|
"""Simple Graph Store Data container.
|
|
|
|
Args:
|
|
graph_dict (Optional[dict]): dict mapping subject to
|
|
"""
|
|
|
|
graph_dict: Dict[str, List[List[str]]] = field(default_factory=dict)
|
|
|
|
def get_rel_map(
|
|
self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
|
|
) -> Dict[str, List[List[str]]]:
|
|
"""Get subjects' rel map in max depth."""
|
|
if subjs is None:
|
|
subjs = list(self.graph_dict.keys())
|
|
rel_map = {}
|
|
for subj in subjs:
|
|
rel_map[subj] = self._get_rel_map(subj, depth=depth, limit=limit)
|
|
# TBD, truncate the rel_map in a spread way, now just truncate based
|
|
# on iteration order
|
|
rel_count = 0
|
|
return_map = {}
|
|
for subj in rel_map:
|
|
if rel_count + len(rel_map[subj]) > limit:
|
|
return_map[subj] = rel_map[subj][: limit - rel_count]
|
|
break
|
|
else:
|
|
return_map[subj] = rel_map[subj]
|
|
rel_count += len(rel_map[subj])
|
|
return return_map
|
|
|
|
def _get_rel_map(
|
|
self, subj: str, depth: int = 2, limit: int = 30
|
|
) -> List[List[str]]:
|
|
"""Get one subect's rel map in max depth."""
|
|
if depth == 0:
|
|
return []
|
|
rel_map = []
|
|
rel_count = 0
|
|
if subj in self.graph_dict:
|
|
for rel, obj in self.graph_dict[subj]:
|
|
if rel_count >= limit:
|
|
break
|
|
rel_map.append([subj, rel, obj])
|
|
rel_map += self._get_rel_map(obj, depth=depth - 1)
|
|
rel_count += 1
|
|
return rel_map
|
|
|
|
|
|
class SimpleGraphStore(GraphStore):
|
|
"""Simple Graph Store.
|
|
|
|
In this graph store, triplets are stored within a simple, in-memory dictionary.
|
|
|
|
Args:
|
|
simple_graph_store_data_dict (Optional[dict]): data dict
|
|
containing the triplets. See SimpleGraphStoreData
|
|
for more details.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
data: Optional[SimpleGraphStoreData] = None,
|
|
fs: Optional[fsspec.AbstractFileSystem] = None,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Initialize params."""
|
|
self._data = data or SimpleGraphStoreData()
|
|
self._fs = fs or fsspec.filesystem("file")
|
|
|
|
@classmethod
|
|
def from_persist_dir(
|
|
cls,
|
|
persist_dir: str = DEFAULT_PERSIST_DIR,
|
|
fs: Optional[fsspec.AbstractFileSystem] = None,
|
|
) -> "SimpleGraphStore":
|
|
"""Load from persist dir."""
|
|
persist_path = os.path.join(persist_dir, DEFAULT_PERSIST_FNAME)
|
|
return cls.from_persist_path(persist_path, fs=fs)
|
|
|
|
@property
|
|
def client(self) -> None:
|
|
"""Get client.
|
|
Not applicable for this store.
|
|
"""
|
|
return
|
|
|
|
def get(self, subj: str) -> List[List[str]]:
|
|
"""Get triplets."""
|
|
return self._data.graph_dict.get(subj, [])
|
|
|
|
def get_rel_map(
|
|
self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
|
|
) -> Dict[str, List[List[str]]]:
|
|
"""Get depth-aware rel map."""
|
|
return self._data.get_rel_map(subjs=subjs, depth=depth, limit=limit)
|
|
|
|
def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
|
|
"""Add triplet."""
|
|
if subj not in self._data.graph_dict:
|
|
self._data.graph_dict[subj] = []
|
|
if (rel, obj) not in self._data.graph_dict[subj]:
|
|
self._data.graph_dict[subj].append([rel, obj])
|
|
|
|
def delete(self, subj: str, rel: str, obj: str) -> None:
|
|
"""Delete triplet."""
|
|
if subj in self._data.graph_dict:
|
|
if (rel, obj) in self._data.graph_dict[subj]:
|
|
self._data.graph_dict[subj].remove([rel, obj])
|
|
if len(self._data.graph_dict[subj]) == 0:
|
|
del self._data.graph_dict[subj]
|
|
|
|
def persist(
|
|
self,
|
|
persist_path: str = os.path.join(DEFAULT_PERSIST_DIR, DEFAULT_PERSIST_FNAME),
|
|
fs: Optional[fsspec.AbstractFileSystem] = None,
|
|
) -> None:
|
|
"""Persist the SimpleGraphStore to a directory."""
|
|
fs = fs or self._fs
|
|
dirpath = os.path.dirname(persist_path)
|
|
if not fs.exists(dirpath):
|
|
fs.makedirs(dirpath)
|
|
|
|
with fs.open(persist_path, "w") as f:
|
|
json.dump(self._data.to_dict(), f)
|
|
|
|
def get_schema(self, refresh: bool = False) -> str:
|
|
"""Get the schema of the Simple Graph store."""
|
|
raise NotImplementedError("SimpleGraphStore does not support get_schema")
|
|
|
|
def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
|
|
"""Query the Simple Graph store."""
|
|
raise NotImplementedError("SimpleGraphStore does not support query")
|
|
|
|
@classmethod
|
|
def from_persist_path(
|
|
cls, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
|
|
) -> "SimpleGraphStore":
|
|
"""Create a SimpleGraphStore from a persist directory."""
|
|
fs = fs or fsspec.filesystem("file")
|
|
if not fs.exists(persist_path):
|
|
logger.warning(
|
|
f"No existing {__name__} found at {persist_path}. "
|
|
"Initializing a new graph_store from scratch. "
|
|
)
|
|
return cls()
|
|
|
|
logger.debug(f"Loading {__name__} from {persist_path}.")
|
|
with fs.open(persist_path, "rb") as f:
|
|
data_dict = json.load(f)
|
|
data = SimpleGraphStoreData.from_dict(data_dict)
|
|
return cls(data)
|
|
|
|
@classmethod
|
|
def from_dict(cls, save_dict: dict) -> "SimpleGraphStore":
|
|
data = SimpleGraphStoreData.from_dict(save_dict)
|
|
return cls(data)
|
|
|
|
def to_dict(self) -> dict:
|
|
return self._data.to_dict()
|