# faiss_rag_enterprise/llama_index/graph_stores/nebulagraph.py
"""NebulaGraph graph store index."""
import logging
import os
from string import Template
from typing import Any, Dict, List, Optional
from tenacity import retry, stop_after_attempt, wait_random_exponential
from llama_index.graph_stores.types import GraphStore
QUOTE = '"'
RETRY_TIMES = 3
WAIT_MIN_SECONDS = 0.5
WAIT_MAX_SECONDS = 10
logger = logging.getLogger(__name__)
rel_query_sample_edge = Template(
"""
MATCH ()-[e:`$edge_type`]->()
RETURN [src(e), dst(e)] AS sample_edge LIMIT 1
"""
)
rel_query_edge_type = Template(
"""
MATCH (m)-[:`$edge_type`]->(n)
WHERE id(m) == $quote$src_id$quote AND id(n) == $quote$dst_id$quote
RETURN "(:" + tags(m)[0] + ")-[:$edge_type]->(:" + tags(n)[0] + ")" AS rels
"""
)
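# Illustrative substitution (assumes an edge type named "follow"):
# rel_query_sample_edge.substitute(edge_type="follow") renders as
#     MATCH ()-[e:`follow`]->()
#     RETURN [src(e), dst(e)] AS sample_edge LIMIT 1
# which fetches one (src, dst) vid pair; refresh_schema() then feeds that
# pair to rel_query_edge_type to discover the source and destination tags.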
def hash_string_to_rank(string: str) -> int:
    """Hash a string into NebulaGraph's signed 64-bit edge-rank range."""
    # Python's hash() returns a signed value; mask it down to 64 bits,
    # which yields an unsigned 64-bit integer.
    unsigned_hash = hash(string) & ((1 << 64) - 1)
    # Re-interpret the unsigned value as a signed 64-bit integer
    # (two's complement), since NebulaGraph edge ranks are signed int64.
    if unsigned_hash & (1 << 63):
        return unsigned_hash - (1 << 64)
    return unsigned_hash
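# A minimal sketch of how the rank is used (the relation string is
# illustrative):
#     rank = hash_string_to_rank("works_at")
#     assert -(1 << 63) <= rank < (1 << 63)
# The rank distinguishes parallel edges between the same two vertices, so
# different relation strings become different edges of one edge type.
# Caveat: Python salts str hashes per process unless PYTHONHASHSEED is
# fixed, so the same relation can hash to different ranks across runs.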
def prepare_subjs_param(
subjs: Optional[List[str]], vid_type: str = "FIXED_STRING(256)"
) -> Dict:
"""Prepare parameters for query."""
if subjs is None:
return {}
from nebula3.common import ttypes
subjs_list = []
subjs_byte = ttypes.Value()
# filter non-digit string for INT64 vid type
if vid_type == "INT64":
subjs = [subj for subj in subjs if subj.isdigit()]
if len(subjs) == 0:
        logger.warning(
            "KG has INT64 vid type, but no digit-string subjects were "
            "provided. Returning empty subjs; no query will be executed. "
            f"subjs: {subjs}"
        )
return {}
for subj in subjs:
if not isinstance(subj, str):
raise TypeError(f"Subject should be str, but got {type(subj).__name__}.")
subj_byte = ttypes.Value()
if vid_type == "INT64":
assert subj.isdigit(), (
"Subject should be a digit string in current "
"graph store, where vid type is INT64."
)
subj_byte.set_iVal(int(subj))
else:
subj_byte.set_sVal(subj)
subjs_list.append(subj_byte)
subjs_nlist = ttypes.NList(values=subjs_list)
subjs_byte.set_lVal(subjs_nlist)
return {"subjs": subjs_byte}
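# Usage sketch (assumes nebula3-python is installed; vids are illustrative):
#     prepare_subjs_param(["player100", "player101"])
#         -> {"subjs": <Value wrapping an NList of string Values>}
#     prepare_subjs_param(["100", "abc"], vid_type="INT64")
#         -> {"subjs": <Value wrapping an NList with iVal 100>}  # "abc" dropped
# The returned dict is passed as param_map to execute(), where the nGQL
# references it as $subjs.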
def escape_str(value: str) -> str:
    """Escape a string for a NebulaGraph query."""
    # Replace characters that would break quoting; currently only '"'.
    patterns = {
        '"': " ",
    }
    for pattern, replacement in patterns.items():
        if pattern in value:
            value = value.replace(pattern, replacement)
    # Guard against empty strings before indexing.
    if value and (value[0] == " " or value[-1] == " "):
        value = value.strip()
    return value
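# Example: escape_str('say "hello"') -> 'say  hello'; double quotes become
# spaces and any resulting leading/trailing whitespace is stripped, keeping
# the value safe to embed between QUOTE characters in generated nGQL.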
class NebulaGraphStore(GraphStore):
"""NebulaGraph graph store."""
def __init__(
self,
session_pool: Optional[Any] = None,
space_name: Optional[str] = None,
edge_types: Optional[List[str]] = ["relationship"],
rel_prop_names: Optional[List[str]] = ["relationship,"],
tags: Optional[List[str]] = ["entity"],
tag_prop_names: Optional[List[str]] = ["name,"],
include_vid: bool = True,
session_pool_kwargs: Optional[Dict[str, Any]] = {},
**kwargs: Any,
) -> None:
"""Initialize NebulaGraph graph store.
Args:
session_pool: NebulaGraph session pool.
space_name: NebulaGraph space name.
edge_types: Edge types.
rel_prop_names: Relation property names corresponding to edge types.
tags: Tags.
tag_prop_names: Tag property names corresponding to tags.
session_pool_kwargs: Keyword arguments for NebulaGraph session pool.
**kwargs: Keyword arguments.
"""
try:
import nebula3 # noqa
except ImportError:
raise ImportError(
"Please install NebulaGraph Python client first: "
"`pip install nebula3-python`"
)
assert space_name is not None, "space_name should be provided."
self._space_name = space_name
self._session_pool_kwargs = session_pool_kwargs
self._session_pool: Any = session_pool
if self._session_pool is None:
self.init_session_pool()
self._vid_type = self._get_vid_type()
self._tags = tags or ["entity"]
self._edge_types = edge_types or ["rel"]
self._rel_prop_names = rel_prop_names or ["predicate,"]
        if len(self._edge_types) != len(self._rel_prop_names):
            raise ValueError(
                "edge_types and rel_prop_names (which define relations and "
                "their property names) must have the same length."
            )
if len(self._edge_types) == 0:
raise ValueError("Length of `edge_types` should be greater than 0.")
        if tag_prop_names is None or len(self._tags) != len(tag_prop_names):
            raise ValueError(
                "tag_prop_names (which define tags and their property names) "
                "must be provided with the same length as tags."
            )
if len(self._tags) == 0:
raise ValueError("Length of `tags` should be greater than 0.")
# for building query
self._edge_dot_rel = [
f"`{edge_type}`.`{rel_prop_name}`"
for edge_type, rel_prop_name in zip(self._edge_types, self._rel_prop_names)
]
self._edge_prop_map = {}
for edge_type, rel_prop_name in zip(self._edge_types, self._rel_prop_names):
self._edge_prop_map[edge_type] = [
prop.strip() for prop in rel_prop_name.split(",")
]
# cypher string like: map{`follow`: "degree", `serve`: "start_year,end_year"}
self._edge_prop_map_cypher_string = (
"map{"
+ ", ".join(
[
f"`{edge_type}`: \"{','.join(rel_prop_names)}\""
for edge_type, rel_prop_names in self._edge_prop_map.items()
]
)
+ "}"
)
# build tag_prop_names map
self._tag_prop_names_map = {}
for tag, prop_names in zip(self._tags, tag_prop_names or []):
if prop_names is not None:
self._tag_prop_names_map[tag] = f"`{tag}`.`{prop_names}`"
self._tag_prop_names: List[str] = list(
{
prop_name.strip()
for prop_names in tag_prop_names or []
if prop_names is not None
for prop_name in prop_names.split(",")
}
)
        self._include_vid = include_vid
        # Cached schema string; populated lazily by refresh_schema().
        self.schema = ""
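    # Construction sketch (assumes NEBULA_USER, NEBULA_PASSWORD and
    # NEBULA_ADDRESS are set, and that a space named "llamaindex" exists
    # with tag `entity(name string)` and edge `relationship(relationship string)`):
    #     store = NebulaGraphStore(
    #         space_name="llamaindex",
    #         edge_types=["relationship"],
    #         rel_prop_names=["relationship"],
    #         tags=["entity"],
    #         tag_prop_names=["name"],
    #     )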
def init_session_pool(self) -> Any:
"""Return NebulaGraph session pool."""
from nebula3.Config import SessionPoolConfig
from nebula3.gclient.net.SessionPool import SessionPool
# ensure "NEBULA_USER", "NEBULA_PASSWORD", "NEBULA_ADDRESS" are set
# in environment variables
if not all(
key in os.environ
for key in ["NEBULA_USER", "NEBULA_PASSWORD", "NEBULA_ADDRESS"]
):
raise ValueError(
"NEBULA_USER, NEBULA_PASSWORD, NEBULA_ADDRESS should be set in "
"environment variables when NebulaGraph Session Pool is not "
"directly passed."
)
graphd_host, graphd_port = os.environ["NEBULA_ADDRESS"].split(":")
session_pool = SessionPool(
os.environ["NEBULA_USER"],
os.environ["NEBULA_PASSWORD"],
self._space_name,
[(graphd_host, int(graphd_port))],
)
        session_pool_config = SessionPoolConfig()
        session_pool.init(session_pool_config)
self._session_pool = session_pool
return self._session_pool
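    # Environment sketch for the pool above (values are illustrative; the
    # port must be graphd's, 9669 by default):
    #     export NEBULA_USER=root
    #     export NEBULA_PASSWORD=nebula
    #     export NEBULA_ADDRESS=127.0.0.1:9669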
def _get_vid_type(self) -> str:
"""Get vid type."""
return (
self.execute(f"DESCRIBE SPACE {self._space_name}")
.column_values("Vid Type")[0]
.cast()
)
    def __del__(self) -> None:
        """Close the NebulaGraph session pool, if one exists."""
        if getattr(self, "_session_pool", None) is not None:
            self._session_pool.close()
@retry(
wait=wait_random_exponential(min=WAIT_MIN_SECONDS, max=WAIT_MAX_SECONDS),
stop=stop_after_attempt(RETRY_TIMES),
)
def execute(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
"""Execute query.
Args:
query: Query.
param_map: Parameter map.
Returns:
Query result.
"""
from nebula3.Exception import IOErrorException
from nebula3.fbthrift.transport.TTransport import TTransportException
# Clean the query string by removing triple backticks
query = query.replace("```", "").strip()
try:
result = self._session_pool.execute_parameter(query, param_map)
if result is None:
raise ValueError(f"Query failed. Query: {query}, Param: {param_map}")
if not result.is_succeeded():
                raise ValueError(
                    f"Query failed. Query: {query}, Param: {param_map}, "
                    f"Error message: {result.error_msg()}"
                )
return result
except (TTransportException, IOErrorException, RuntimeError) as e:
            logger.error(
                f"Connection issue; recreating session pool. Query: {query}, "
                f"Param: {param_map}, Error: {e}"
            )
self.init_session_pool()
            logger.info(
                f"Session pool recreated after error: {e}; retrying. "
                f"Query: {query}, Param: {param_map}"
            )
raise
except ValueError as e:
# query failed on db side
            logger.error(
                f"Query failed. Query: {query}, Param: {param_map}, "
                f"Error message: {e}"
            )
raise
except Exception as e:
# other exceptions
            logger.error(
                f"Query failed. Query: {query}, Param: {param_map}, "
                f"Error message: {e}"
            )
raise
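    # Call sketch: plain statements run as-is, and $-parameters come from
    # param_map, e.g. one built with prepare_subjs_param() above:
    #     result = store.execute(
    #         "MATCH (v) WHERE id(v) IN $subjs RETURN v LIMIT 10",
    #         prepare_subjs_param(["player100"]),
    #     )
    # Any exception triggers up to RETRY_TIMES attempts with random
    # exponential backoff between WAIT_MIN_SECONDS and WAIT_MAX_SECONDS.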
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "GraphStore":
"""Initialize graph store from configuration dictionary.
Args:
config_dict: Configuration dictionary.
Returns:
Graph store.
"""
return cls(**config_dict)
@property
def client(self) -> Any:
"""Return NebulaGraph session pool."""
return self._session_pool
@property
def config_dict(self) -> dict:
"""Return configuration dictionary."""
return {
"session_pool": self._session_pool,
"space_name": self._space_name,
"edge_types": self._edge_types,
"rel_prop_names": self._rel_prop_names,
"session_pool_kwargs": self._session_pool_kwargs,
}
def get(self, subj: str) -> List[List[str]]:
"""Get triplets.
Args:
subj: Subject.
Returns:
Triplets.
"""
rel_map = self.get_flat_rel_map([subj], depth=1)
rels = list(rel_map.values())
if len(rels) == 0:
return []
return rels[0]
def get_flat_rel_map(
self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
) -> Dict[str, List[List[str]]]:
"""Get flat rel map."""
# The flat means for multi-hop relation path, we could get
# knowledge like: subj -rel-> obj -rel-> obj <-rel- obj.
# This type of knowledge is useful for some tasks.
# +---------------------+---------------------------------------------...-----+
# | subj | flattened_rels ... |
# +---------------------+---------------------------------------------...-----+
# | "{name:Tony Parker}"| "{name: Tony Parker}-[follow:{degree:95}]-> ...ili}"|
# | "{name:Tony Parker}"| "{name: Tony Parker}-[follow:{degree:95}]-> ...r}" |
# ...
rel_map: Dict[Any, List[Any]] = {}
if subjs is None or len(subjs) == 0:
# unlike simple graph_store, we don't do get_all here
return rel_map
# WITH map{`true`: "-[", `false`: "<-["} AS arrow_l,
# map{`true`: "]->", `false`: "]-"} AS arrow_r,
# map{`follow`: "degree", `serve`: "start_year,end_year"} AS edge_type_map
# MATCH p=(start)-[e:follow|serve*..2]-()
# WHERE id(start) IN ["player100", "player101"]
# WITH start, id(start) AS vid, nodes(p) AS nodes, e AS rels,
# length(p) AS rel_count, arrow_l, arrow_r, edge_type_map
# WITH
# REDUCE(s = vid + '{', key IN [key_ in ["name"]
# WHERE properties(start)[key_] IS NOT NULL] | s + key + ': ' +
# COALESCE(TOSTRING(properties(start)[key]), 'null') + ', ')
# + '}'
# AS subj,
# [item in [i IN RANGE(0, rel_count - 1) | [nodes[i], nodes[i + 1],
# rels[i], typeid(rels[i]) > 0, type(rels[i]) ]] | [
# arrow_l[tostring(item[3])] +
# item[4] + ':' +
# REDUCE(s = '{', key IN SPLIT(edge_type_map[item[4]], ',') |
# s + key + ': ' + COALESCE(TOSTRING(properties(item[2])[key]),
# 'null') + ', ') + '}'
# +
# arrow_r[tostring(item[3])],
# REDUCE(s = id(item[1]) + '{', key IN [key_ in ["name"]
# WHERE properties(item[1])[key_] IS NOT NULL] | s + key + ': ' +
# COALESCE(TOSTRING(properties(item[1])[key]), 'null') + ', ') + '}'
# ]
# ] AS rels
# WITH
# REPLACE(subj, ', }', '}') AS subj,
# REDUCE(acc = collect(NULL), l in rels | acc + l) AS flattened_rels
# RETURN
# subj,
# REPLACE(REDUCE(acc = subj,l in flattened_rels|acc + ' ' + l),
# ', }', '}')
# AS flattened_rels
# LIMIT 30
# Based on self._include_vid
# {name: Tim Duncan} or player100{name: Tim Duncan} for entity
s_prefix = "vid + '{'" if self._include_vid else "'{'"
s1 = "id(item[1]) + '{'" if self._include_vid else "'{'"
query = (
f"WITH map{{`true`: '-[', `false`: '<-['}} AS arrow_l,"
f" map{{`true`: ']->', `false`: ']-'}} AS arrow_r,"
f" {self._edge_prop_map_cypher_string} AS edge_type_map "
f"MATCH p=(start)-[e:`{'`|`'.join(self._edge_types)}`*..{depth}]-() "
f" WHERE id(start) IN $subjs "
f"WITH start, id(start) AS vid, nodes(p) AS nodes, e AS rels,"
f" length(p) AS rel_count, arrow_l, arrow_r, edge_type_map "
f"WITH "
f" REDUCE(s = {s_prefix}, key IN [key_ in {self._tag_prop_names!s} "
f" WHERE properties(start)[key_] IS NOT NULL] | s + key + ': ' + "
f" COALESCE(TOSTRING(properties(start)[key]), 'null') + ', ')"
f" + '}}'"
f" AS subj,"
f" [item in [i IN RANGE(0, rel_count - 1)|[nodes[i], nodes[i + 1],"
f" rels[i], typeid(rels[i]) > 0, type(rels[i]) ]] | ["
f" arrow_l[tostring(item[3])] +"
f" item[4] + ':' +"
f" REDUCE(s = '{{', key IN SPLIT(edge_type_map[item[4]], ',') | "
f" s + key + ': ' + COALESCE(TOSTRING(properties(item[2])[key]),"
f" 'null') + ', ') + '}}'"
f" +"
f" arrow_r[tostring(item[3])],"
f" REDUCE(s = {s1}, key IN [key_ in "
f" {self._tag_prop_names!s} WHERE properties(item[1])[key_] "
f" IS NOT NULL] | s + key + ': ' + "
f" COALESCE(TOSTRING(properties(item[1])[key]), 'null') + ', ')"
f" + '}}'"
f" ]"
f" ] AS rels "
f"WITH "
f" REPLACE(subj, ', }}', '}}') AS subj,"
f" REDUCE(acc = collect(NULL), l in rels | acc + l) AS flattened_rels "
f"RETURN "
f" subj,"
f" REPLACE(REDUCE(acc = subj, l in flattened_rels | acc + ' ' + l), "
f" ', }}', '}}') "
f" AS flattened_rels"
f" LIMIT {limit}"
)
subjs_param = prepare_subjs_param(subjs, self._vid_type)
logger.debug(f"get_flat_rel_map()\nsubjs_param: {subjs},\nquery: {query}")
        if subjs_param == {}:
            # prepare_subjs_param() produced no usable parameters, most
            # likely because the vid type is INT64 but no digit-string
            # subjects were provided, so skip the query.
            return rel_map
result = self.execute(query, subjs_param)
if result is None:
return rel_map
# get raw data
subjs_ = result.column_values("subj") or []
rels_ = result.column_values("flattened_rels") or []
for subj, rel in zip(subjs_, rels_):
subj_ = subj.cast()
rel_ = rel.cast()
if subj_ not in rel_map:
rel_map[subj_] = []
rel_map[subj_].append(rel_)
return rel_map
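    # Shape of the result, as a sketch (values mirror the commented query
    # above; vids appear because include_vid defaults to True):
    #     {
    #         "player100{name: Tim Duncan}": [
    #             "player100{name: Tim Duncan} -[follow:{degree: 95}]-> "
    #             "player101{name: Tony Parker}",
    #         ],
    #     }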
def get_rel_map(
self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
) -> Dict[str, List[List[str]]]:
"""Get rel map."""
        # For depth >= 1 we flatten rels into one long list per subject.
        # This differs from SimpleGraphStore.get_rel_map() but reads better
        # for multi-hop relation paths.
if subjs is not None:
subjs = [
escape_str(subj) for subj in subjs if isinstance(subj, str) and subj
]
if len(subjs) == 0:
return {}
return self.get_flat_rel_map(subjs, depth, limit)
def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
"""Add triplet."""
        # Note: to leverage an existing knowledge graph, the mapping from
        # triplets to the property graph is n:1 (several edge_type.prop_name
        # combinations can represent triplet.rel), so upsert_triplet()
        # assumes rel is stored under the first edge type's property name.
        # The same applies to entity tags: subj and obj are assumed to use
        # the first tag.
        # Escape subj, rel, obj for safe embedding in the query.
        subj = escape_str(subj)
        rel = escape_str(rel)
        obj = escape_str(obj)
if self._vid_type == "INT64":
assert all(
[subj.isdigit(), obj.isdigit()]
), "Subject and object should be digit strings in current graph store."
subj_field = subj
obj_field = obj
else:
subj_field = f"{QUOTE}{subj}{QUOTE}"
obj_field = f"{QUOTE}{obj}{QUOTE}"
edge_field = f"{subj_field}->{obj_field}"
edge_type = self._edge_types[0]
rel_prop_name = self._rel_prop_names[0]
entity_type = self._tags[0]
rel_hash = hash_string_to_rank(rel)
dml_query = (
f"INSERT VERTEX `{entity_type}`(name) "
f" VALUES {subj_field}:({QUOTE}{subj}{QUOTE});"
f"INSERT VERTEX `{entity_type}`(name) "
f" VALUES {obj_field}:({QUOTE}{obj}{QUOTE});"
f"INSERT EDGE `{edge_type}`(`{rel_prop_name}`) "
f" VALUES "
f"{edge_field}"
f"@{rel_hash}:({QUOTE}{rel}{QUOTE});"
)
logger.debug(f"upsert_triplet()\nDML query: {dml_query}")
result = self.execute(dml_query)
assert (
result and result.is_succeeded()
), f"Failed to upsert triplet: {subj} {rel} {obj}, query: {dml_query}"
    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete triplet.

        1. As in upsert_triplet(), rel is assumed to be the first
           edge_type.prop_name.
        2. After the edge is deleted, check whether subj or obj have become
           isolated vertices; if so, delete them too.
        """
        # Escape subj, rel, obj for safe embedding in the query.
subj = escape_str(subj)
rel = escape_str(rel)
obj = escape_str(obj)
if self._vid_type == "INT64":
assert all(
[subj.isdigit(), obj.isdigit()]
), "Subject and object should be digit strings in current graph store."
subj_field = subj
obj_field = obj
else:
subj_field = f"{QUOTE}{subj}{QUOTE}"
obj_field = f"{QUOTE}{obj}{QUOTE}"
edge_field = f"{subj_field}->{obj_field}"
        # e.g. DELETE EDGE serve "player100" -> "team204"@7696463696635583936;
        edge_type = self._edge_types[0]
        rel_hash = hash_string_to_rank(rel)
        dml_query = f"DELETE EDGE `{edge_type}` {edge_field}@{rel_hash};"
logger.debug(f"delete()\nDML query: {dml_query}")
result = self.execute(dml_query)
assert (
result and result.is_succeeded()
), f"Failed to delete triplet: {subj} {rel} {obj}, query: {dml_query}"
# Get isolated vertices to be deleted
# MATCH (s) WHERE id(s) IN ["player700"] AND NOT (s)-[]-()
# RETURN id(s) AS isolated
query = (
f"MATCH (s) "
f" WHERE id(s) IN [{subj_field}, {obj_field}] "
f" AND NOT (s)-[]-() "
f"RETURN id(s) AS isolated"
)
result = self.execute(query)
isolated = result.column_values("isolated")
if not isolated:
return
# DELETE VERTEX "player700" or DELETE VERTEX 700
quote_field = QUOTE if self._vid_type != "INT64" else ""
vertex_ids = ",".join(
[f"{quote_field}{v.cast()}{quote_field}" for v in isolated]
)
dml_query = f"DELETE VERTEX {vertex_ids};"
result = self.execute(dml_query)
assert (
result and result.is_succeeded()
), f"Failed to delete isolated vertices: {isolated}, query: {dml_query}"
def refresh_schema(self) -> None:
"""
Refreshes the NebulaGraph Store Schema.
"""
tags_schema, edge_types_schema, relationships = [], [], []
for tag in self.execute("SHOW TAGS").column_values("Name"):
tag_name = tag.cast()
tag_schema = {"tag": tag_name, "properties": []}
r = self.execute(f"DESCRIBE TAG `{tag_name}`")
props, types, comments = (
r.column_values("Field"),
r.column_values("Type"),
r.column_values("Comment"),
)
for i in range(r.row_size()):
                # backward compatible with old versions of nebula-python
                property_definition = (
                    (props[i].cast(), types[i].cast())
                    if comments[i].is_empty()
                    else (props[i].cast(), types[i].cast(), comments[i].cast())
                )
                tag_schema["properties"].append(property_definition)
tags_schema.append(tag_schema)
for edge_type in self.execute("SHOW EDGES").column_values("Name"):
edge_type_name = edge_type.cast()
edge_schema = {"edge": edge_type_name, "properties": []}
r = self.execute(f"DESCRIBE EDGE `{edge_type_name}`")
props, types, comments = (
r.column_values("Field"),
r.column_values("Type"),
r.column_values("Comment"),
)
for i in range(r.row_size()):
                # backward compatible with old versions of nebula-python
                property_definition = (
                    (props[i].cast(), types[i].cast())
                    if comments[i].is_empty()
                    else (props[i].cast(), types[i].cast(), comments[i].cast())
                )
                edge_schema["properties"].append(property_definition)
edge_types_schema.append(edge_schema)
# build relationships types
sample_edge = self.execute(
rel_query_sample_edge.substitute(edge_type=edge_type_name)
).column_values("sample_edge")
if len(sample_edge) == 0:
continue
src_id, dst_id = sample_edge[0].cast()
r = self.execute(
rel_query_edge_type.substitute(
edge_type=edge_type_name,
src_id=src_id,
dst_id=dst_id,
quote="" if self._vid_type == "INT64" else QUOTE,
)
).column_values("rels")
if len(r) > 0:
relationships.append(r[0].cast())
self.schema = (
f"Node properties: {tags_schema}\n"
f"Edge properties: {edge_types_schema}\n"
f"Relationships: {relationships}\n"
)
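    # Resulting schema string, as a sketch for the single-tag setup from the
    # construction sketch above:
    #     Node properties: [{'tag': 'entity', 'properties': [('name', 'string')]}]
    #     Edge properties: [{'edge': 'relationship',
    #                        'properties': [('relationship', 'string')]}]
    #     Relationships: ['(:entity)-[:relationship]->(:entity)']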
def get_schema(self, refresh: bool = False) -> str:
"""Get the schema of the NebulaGraph store."""
if self.schema and not refresh:
return self.schema
self.refresh_schema()
logger.debug(f"get_schema()\nschema: {self.schema}")
return self.schema
    def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
        """Run a query and return the result as {column_name: [values]}."""
result = self.execute(query, param_map)
columns = result.keys()
d: Dict[str, list] = {}
for col_num in range(result.col_size()):
col_name = columns[col_num]
col_list = result.column_values(col_name)
d[col_name] = [x.cast() for x in col_list]
return d
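

# A minimal end-to-end sketch, not part of the library API. It assumes a
# reachable NebulaGraph with the NEBULA_* environment variables set and a
# space named "llamaindex" matching the construction sketch in
# NebulaGraphStore.
if __name__ == "__main__":
    store = NebulaGraphStore(
        space_name="llamaindex",
        edge_types=["relationship"],
        rel_prop_names=["relationship"],
        tags=["entity"],
        tag_prop_names=["name"],
    )
    store.upsert_triplet("Tom", "works_at", "Acme")
    print(store.get("Tom"))
    print(store.get_schema(refresh=True))
    # Deleting in the same process reuses the same salted hash for the rank.
    store.delete("Tom", "works_at", "Acme")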