# faiss_rag_enterprise/llama_index/graph_stores/kuzu.py
#
# 229 lines
# 8.6 KiB
# Python

"""Kùzu graph store index."""
from typing import Any, Dict, List, Optional
from llama_index.graph_stores.types import GraphStore
class KuzuGraphStore(GraphStore):
    """Graph store backed by a Kùzu database.

    Each triplet ``(subj, rel, obj)`` is stored as two rows of one node table
    (single ``ID STRING`` primary key) joined by a row of one relationship
    table whose ``predicate`` property holds ``rel``.

    NOTE: ``node_table_name`` and ``rel_table_name`` are interpolated directly
    into Cypher text (only entity/predicate values are bound as parameters),
    so they must be trusted identifiers, never user input.

    Args:
        database: The database object passed to ``kuzu.Connection`` (e.g. the
            ``kuzu.Database`` built by ``from_persist_dir``).
        node_table_name: Name of the node table holding entities.
        rel_table_name: Name of the relationship table holding predicates.
        **kwargs: Accepted for interface compatibility; currently unused.
    """

    def __init__(
        self,
        database: Any,
        node_table_name: str = "entity",
        rel_table_name: str = "links",
        **kwargs: Any,
    ) -> None:
        try:
            import kuzu
        except ImportError as err:
            raise ImportError("Please install kuzu: pip install kuzu") from err
        self.database = database
        self.connection = kuzu.Connection(database)
        self.node_table_name = node_table_name
        self.rel_table_name = rel_table_name
        self.init_schema()

    def init_schema(self) -> None:
        """Create the node and relationship tables if they do not exist.

        The node table has a single ``ID STRING`` primary key. The relationship
        table links the node table to itself and carries a ``predicate STRING``
        property.
        """
        # NOTE(review): _get_node_table_names / _get_rel_table_names are private
        # kuzu Connection helpers. This code relies on the former returning a
        # list of names and the latter a list of dicts with a "name" key.
        node_tables = self.connection._get_node_table_names()
        if self.node_table_name not in node_tables:
            self.connection.execute(
                "CREATE NODE TABLE %s (ID STRING, PRIMARY KEY(ID))"
                % self.node_table_name
            )
        rel_tables = self.connection._get_rel_table_names()
        rel_tables = [rel_table["name"] for rel_table in rel_tables]
        if self.rel_table_name not in rel_tables:
            self.connection.execute(
                "CREATE REL TABLE {} (FROM {} TO {}, predicate STRING)".format(
                    self.rel_table_name, self.node_table_name, self.node_table_name
                )
            )

    @property
    def client(self) -> Any:
        """The underlying ``kuzu.Connection``."""
        return self.connection

    def get(self, subj: str) -> List[List[str]]:
        """Return the outgoing triplets of ``subj``.

        Args:
            subj: ID of the subject node.

        Returns:
            One ``[predicate, object_id]`` pair per outgoing relationship of
            ``subj``. The list is empty if there are none.
        """
        query = """
            MATCH (n1:%s)-[r:%s]->(n2:%s)
            WHERE n1.ID = $subj
            RETURN r.predicate, n2.ID;
        """
        prepared_statement = self.connection.prepare(
            query % (self.node_table_name, self.rel_table_name, self.node_table_name)
        )
        query_result = self.connection.execute(prepared_statement, [("subj", subj)])
        retval = []
        while query_result.has_next():
            row = query_result.get_next()
            retval.append([row[0], row[1]])
        return retval

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Return paths of 1..``depth`` hops grouped by their starting subject.

        Args:
            subjs: Subject IDs to start from. ``None`` allows any starting
                node. An empty list matches no start node and returns ``{}``.
            depth: Maximum number of relationship hops per path.
            limit: Maximum number of paths the query returns (``LIMIT``).

        Returns:
            Mapping from starting subject ID to a list of flattened paths.
            Each path alternates the source-node ID and predicate of every
            relationship on the path, followed by the final object ID, e.g.
            ``[subj, pred1, mid, pred2, obj]``. This assumes Kùzu reports
            ``_rels`` in path order; TODO confirm.
        """
        if subjs is not None and not subjs:
            # No start nodes means no paths. Without this guard the WHERE
            # clause below would be left undefined.
            return {}

        rel_wildcard = "r:%s*1..%d" % (self.rel_table_name, depth)
        match_clause = "MATCH (n1:{})-[{}]->(n2:{})".format(
            self.node_table_name,
            rel_wildcard,
            self.node_table_name,
        )
        return_clause = "RETURN n1, r, n2 LIMIT %d" % limit

        where_clause = ""
        params = []
        if subjs is not None:
            # Bind each subject as a parameter named "0", "1", ... rather than
            # splicing the values into the query text.
            where_clause = "WHERE " + " OR ".join(
                "n1.ID = $%d" % i for i in range(len(subjs))
            )
            params = [(str(i), curr_subj) for i, curr_subj in enumerate(subjs)]

        query = f"{match_clause} {where_clause} {return_clause}"
        prepared_statement = self.connection.prepare(query)
        if subjs is not None:
            query_result = self.connection.execute(prepared_statement, params)
        else:
            query_result = self.connection.execute(prepared_statement)

        retval: Dict[str, List[List[str]]] = {}
        while query_result.has_next():
            row = query_result.get_next()
            subj = row[0]
            recursive_rel = row[1]
            obj = row[2]
            # Relationships identify their endpoints by internal
            # (table, offset) ids, so map those ids back to user-facing IDs.
            nodes_map = {}
            for node in (subj, obj, *recursive_rel["_nodes"]):
                nodes_map[(node["_id"]["table"], node["_id"]["offset"])] = node["ID"]
            curr_path = []
            for rel in recursive_rel["_rels"]:
                predicate = rel["predicate"]
                src_key = (rel["_src"]["table"], rel["_src"]["offset"])
                curr_path.append(nodes_map[src_key])
                curr_path.append(predicate)
            # Add the last node
            curr_path.append(obj["ID"])
            retval.setdefault(subj["ID"], []).append(curr_path)
        return retval

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add the triplet ``(subj, rel, obj)`` unless it is already stored.

        Missing subject/object nodes are created first. The relationship is
        checked for duplicates only when both nodes already existed. If either
        node was just created, no edge between them can exist yet.
        """

        def check_entity_exists(connection: Any, entity: str) -> bool:
            # True if a node with ID == entity exists.
            is_exists_result = connection.execute(
                "MATCH (n:%s) WHERE n.ID = $entity RETURN n.ID" % self.node_table_name,
                [("entity", entity)],
            )
            return is_exists_result.has_next()

        def create_entity(connection: Any, entity: str) -> None:
            # Insert a node with ID == entity.
            connection.execute(
                "CREATE (n:%s {ID: $entity})" % self.node_table_name,
                [("entity", entity)],
            )

        def check_rel_exists(connection: Any, subj: str, obj: str, rel: str) -> bool:
            # True if an edge subj -[predicate == rel]-> obj already exists.
            is_exists_result = connection.execute(
                (
                    "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID = "
                    "$obj AND r.predicate = $pred RETURN r.predicate"
                ).format(
                    self.node_table_name, self.rel_table_name, self.node_table_name
                ),
                [("subj", subj), ("obj", obj), ("pred", rel)],
            )
            return is_exists_result.has_next()

        def create_rel(connection: Any, subj: str, obj: str, rel: str) -> None:
            # Create the edge subj -[predicate = rel]-> obj between existing nodes.
            connection.execute(
                (
                    "MATCH (n1:{}), (n2:{}) WHERE n1.ID = $subj AND n2.ID = $obj "
                    "CREATE (n1)-[r:{} {{predicate: $pred}}]->(n2)"
                ).format(
                    self.node_table_name, self.node_table_name, self.rel_table_name
                ),
                [("subj", subj), ("obj", obj), ("pred", rel)],
            )

        is_subj_exists = check_entity_exists(self.connection, subj)
        is_obj_exists = check_entity_exists(self.connection, obj)

        if not is_subj_exists:
            create_entity(self.connection, subj)
        if not is_obj_exists:
            create_entity(self.connection, obj)

        if is_subj_exists and is_obj_exists:
            is_rel_exists = check_rel_exists(self.connection, subj, obj, rel)
            if is_rel_exists:
                return

        create_rel(self.connection, subj, obj, rel)

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete the triplet ``(subj, rel, obj)``.

        After the matching edge is removed, ``subj`` and ``obj`` are deleted
        too if they no longer take part in any relationship, in either
        direction.
        """

        def delete_rel(connection: Any, subj: str, obj: str, rel: str) -> None:
            # Remove every edge subj -[predicate == rel]-> obj.
            connection.execute(
                (
                    "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID"
                    " = $obj AND r.predicate = $pred DELETE r"
                ).format(
                    self.node_table_name, self.rel_table_name, self.node_table_name
                ),
                [("subj", subj), ("obj", obj), ("pred", rel)],
            )

        def delete_entity(connection: Any, entity: str) -> None:
            # Remove the node with ID == entity.
            connection.execute(
                "MATCH (n:%s) WHERE n.ID = $entity DELETE n" % self.node_table_name,
                [("entity", entity)],
            )

        def check_edges(connection: Any, entity: str) -> bool:
            # Undirected pattern: True if entity has any incoming or outgoing edge.
            is_exists_result = connection.execute(
                "MATCH (n1:{})-[r:{}]-(n2:{}) WHERE n2.ID = $entity RETURN r.predicate".format(
                    self.node_table_name, self.rel_table_name, self.node_table_name
                ),
                [("entity", entity)],
            )
            return is_exists_result.has_next()

        delete_rel(self.connection, subj, obj, rel)

        # Remove endpoints that became isolated by deleting this edge.
        if not check_edges(self.connection, subj):
            delete_entity(self.connection, subj)
        if not check_edges(self.connection, obj):
            delete_entity(self.connection, obj)

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str,
        node_table_name: str = "entity",
        rel_table_name: str = "links",
    ) -> "KuzuGraphStore":
        """Open (or create) a Kùzu database at ``persist_dir`` and wrap it.

        Args:
            persist_dir: Path handed to ``kuzu.Database``.
            node_table_name: Name of the node table holding entities.
            rel_table_name: Name of the relationship table holding predicates.

        Returns:
            A ``KuzuGraphStore`` connected to that database.
        """
        try:
            import kuzu
        except ImportError as err:
            raise ImportError("Please install kuzu: pip install kuzu") from err
        database = kuzu.Database(persist_dir)
        return cls(database, node_table_name, rel_table_name)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "KuzuGraphStore":
        """Initialize graph store from configuration dictionary.

        Args:
            config_dict: Keyword arguments forwarded unchanged to the constructor.

        Returns:
            Graph store.
        """
        return cls(**config_dict)