229 lines
8.6 KiB
Python
229 lines
8.6 KiB
Python
"""Kùzu graph store index."""
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from llama_index.graph_stores.types import GraphStore
|
|
|
|
|
|
class KuzuGraphStore(GraphStore):
|
|
def __init__(
|
|
self,
|
|
database: Any,
|
|
node_table_name: str = "entity",
|
|
rel_table_name: str = "links",
|
|
**kwargs: Any,
|
|
) -> None:
|
|
try:
|
|
import kuzu
|
|
except ImportError:
|
|
raise ImportError("Please install kuzu: pip install kuzu")
|
|
self.database = database
|
|
self.connection = kuzu.Connection(database)
|
|
self.node_table_name = node_table_name
|
|
self.rel_table_name = rel_table_name
|
|
self.init_schema()
|
|
|
|
def init_schema(self) -> None:
|
|
"""Initialize schema if the tables do not exist."""
|
|
node_tables = self.connection._get_node_table_names()
|
|
if self.node_table_name not in node_tables:
|
|
self.connection.execute(
|
|
"CREATE NODE TABLE %s (ID STRING, PRIMARY KEY(ID))"
|
|
% self.node_table_name
|
|
)
|
|
rel_tables = self.connection._get_rel_table_names()
|
|
rel_tables = [rel_table["name"] for rel_table in rel_tables]
|
|
if self.rel_table_name not in rel_tables:
|
|
self.connection.execute(
|
|
"CREATE REL TABLE {} (FROM {} TO {}, predicate STRING)".format(
|
|
self.rel_table_name, self.node_table_name, self.node_table_name
|
|
)
|
|
)
|
|
|
|
@property
|
|
def client(self) -> Any:
|
|
return self.connection
|
|
|
|
def get(self, subj: str) -> List[List[str]]:
|
|
"""Get triplets."""
|
|
query = """
|
|
MATCH (n1:%s)-[r:%s]->(n2:%s)
|
|
WHERE n1.ID = $subj
|
|
RETURN r.predicate, n2.ID;
|
|
"""
|
|
prepared_statement = self.connection.prepare(
|
|
query % (self.node_table_name, self.rel_table_name, self.node_table_name)
|
|
)
|
|
query_result = self.connection.execute(prepared_statement, [("subj", subj)])
|
|
retval = []
|
|
while query_result.has_next():
|
|
row = query_result.get_next()
|
|
retval.append([row[0], row[1]])
|
|
return retval
|
|
|
|
def get_rel_map(
|
|
self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
|
|
) -> Dict[str, List[List[str]]]:
|
|
"""Get depth-aware rel map."""
|
|
rel_wildcard = "r:%s*1..%d" % (self.rel_table_name, depth)
|
|
match_clause = "MATCH (n1:{})-[{}]->(n2:{})".format(
|
|
self.node_table_name,
|
|
rel_wildcard,
|
|
self.node_table_name,
|
|
)
|
|
return_clause = "RETURN n1, r, n2 LIMIT %d" % limit
|
|
params = []
|
|
if subjs is not None:
|
|
for i, curr_subj in enumerate(subjs):
|
|
if i == 0:
|
|
where_clause = "WHERE n1.ID = $%d" % i
|
|
else:
|
|
where_clause += " OR n1.ID = $%d" % i
|
|
params.append((str(i), curr_subj))
|
|
else:
|
|
where_clause = ""
|
|
query = f"{match_clause} {where_clause} {return_clause}"
|
|
prepared_statement = self.connection.prepare(query)
|
|
if subjs is not None:
|
|
query_result = self.connection.execute(prepared_statement, params)
|
|
else:
|
|
query_result = self.connection.execute(prepared_statement)
|
|
retval: Dict[str, List[List[str]]] = {}
|
|
while query_result.has_next():
|
|
row = query_result.get_next()
|
|
curr_path = []
|
|
subj = row[0]
|
|
recursive_rel = row[1]
|
|
obj = row[2]
|
|
nodes_map = {}
|
|
nodes_map[(subj["_id"]["table"], subj["_id"]["offset"])] = subj["ID"]
|
|
nodes_map[(obj["_id"]["table"], obj["_id"]["offset"])] = obj["ID"]
|
|
for node in recursive_rel["_nodes"]:
|
|
nodes_map[(node["_id"]["table"], node["_id"]["offset"])] = node["ID"]
|
|
for rel in recursive_rel["_rels"]:
|
|
predicate = rel["predicate"]
|
|
curr_subj_id = nodes_map[(rel["_src"]["table"], rel["_src"]["offset"])]
|
|
curr_path.append(curr_subj_id)
|
|
curr_path.append(predicate)
|
|
# Add the last node
|
|
curr_path.append(obj["ID"])
|
|
if subj["ID"] not in retval:
|
|
retval[subj["ID"]] = []
|
|
retval[subj["ID"]].append(curr_path)
|
|
return retval
|
|
|
|
def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
|
|
"""Add triplet."""
|
|
|
|
def check_entity_exists(connection: Any, entity: str) -> bool:
|
|
is_exists_result = connection.execute(
|
|
"MATCH (n:%s) WHERE n.ID = $entity RETURN n.ID" % self.node_table_name,
|
|
[("entity", entity)],
|
|
)
|
|
return is_exists_result.has_next()
|
|
|
|
def create_entity(connection: Any, entity: str) -> None:
|
|
connection.execute(
|
|
"CREATE (n:%s {ID: $entity})" % self.node_table_name,
|
|
[("entity", entity)],
|
|
)
|
|
|
|
def check_rel_exists(connection: Any, subj: str, obj: str, rel: str) -> bool:
|
|
is_exists_result = connection.execute(
|
|
(
|
|
"MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID = "
|
|
"$obj AND r.predicate = $pred RETURN r.predicate"
|
|
).format(
|
|
self.node_table_name, self.rel_table_name, self.node_table_name
|
|
),
|
|
[("subj", subj), ("obj", obj), ("pred", rel)],
|
|
)
|
|
return is_exists_result.has_next()
|
|
|
|
def create_rel(connection: Any, subj: str, obj: str, rel: str) -> None:
|
|
connection.execute(
|
|
(
|
|
"MATCH (n1:{}), (n2:{}) WHERE n1.ID = $subj AND n2.ID = $obj "
|
|
"CREATE (n1)-[r:{} {{predicate: $pred}}]->(n2)"
|
|
).format(
|
|
self.node_table_name, self.node_table_name, self.rel_table_name
|
|
),
|
|
[("subj", subj), ("obj", obj), ("pred", rel)],
|
|
)
|
|
|
|
is_subj_exists = check_entity_exists(self.connection, subj)
|
|
is_obj_exists = check_entity_exists(self.connection, obj)
|
|
|
|
if not is_subj_exists:
|
|
create_entity(self.connection, subj)
|
|
if not is_obj_exists:
|
|
create_entity(self.connection, obj)
|
|
|
|
if is_subj_exists and is_obj_exists:
|
|
is_rel_exists = check_rel_exists(self.connection, subj, obj, rel)
|
|
if is_rel_exists:
|
|
return
|
|
|
|
create_rel(self.connection, subj, obj, rel)
|
|
|
|
def delete(self, subj: str, rel: str, obj: str) -> None:
|
|
"""Delete triplet."""
|
|
|
|
def delete_rel(connection: Any, subj: str, obj: str, rel: str) -> None:
|
|
connection.execute(
|
|
(
|
|
"MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID"
|
|
" = $obj AND r.predicate = $pred DELETE r"
|
|
).format(
|
|
self.node_table_name, self.rel_table_name, self.node_table_name
|
|
),
|
|
[("subj", subj), ("obj", obj), ("pred", rel)],
|
|
)
|
|
|
|
def delete_entity(connection: Any, entity: str) -> None:
|
|
connection.execute(
|
|
"MATCH (n:%s) WHERE n.ID = $entity DELETE n" % self.node_table_name,
|
|
[("entity", entity)],
|
|
)
|
|
|
|
def check_edges(connection: Any, entity: str) -> bool:
|
|
is_exists_result = connection.execute(
|
|
"MATCH (n1:{})-[r:{}]-(n2:{}) WHERE n2.ID = $entity RETURN r.predicate".format(
|
|
self.node_table_name, self.rel_table_name, self.node_table_name
|
|
),
|
|
[("entity", entity)],
|
|
)
|
|
return is_exists_result.has_next()
|
|
|
|
delete_rel(self.connection, subj, obj, rel)
|
|
if not check_edges(self.connection, subj):
|
|
delete_entity(self.connection, subj)
|
|
if not check_edges(self.connection, obj):
|
|
delete_entity(self.connection, obj)
|
|
|
|
@classmethod
|
|
def from_persist_dir(
|
|
cls,
|
|
persist_dir: str,
|
|
node_table_name: str = "entity",
|
|
rel_table_name: str = "links",
|
|
) -> "KuzuGraphStore":
|
|
"""Load from persist dir."""
|
|
try:
|
|
import kuzu
|
|
except ImportError:
|
|
raise ImportError("Please install kuzu: pip install kuzu")
|
|
database = kuzu.Database(persist_dir)
|
|
return cls(database, node_table_name, rel_table_name)
|
|
|
|
@classmethod
|
|
def from_dict(cls, config_dict: Dict[str, Any]) -> "KuzuGraphStore":
|
|
"""Initialize graph store from configuration dictionary.
|
|
|
|
Args:
|
|
config_dict: Configuration dictionary.
|
|
|
|
Returns:
|
|
Graph store.
|
|
"""
|
|
return cls(**config_dict)
|