"""NebulaGraph graph store index."""
|
|
import logging
|
|
import os
|
|
from string import Template
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|
|
|
from llama_index.graph_stores.types import GraphStore
|
|
|
|
QUOTE = '"'
|
|
RETRY_TIMES = 3
|
|
WAIT_MIN_SECONDS = 0.5
|
|
WAIT_MAX_SECONDS = 10
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
rel_query_sample_edge = Template(
    """
MATCH ()-[e:`$edge_type`]->()
RETURN [src(e), dst(e)] AS sample_edge LIMIT 1
"""
)

rel_query_edge_type = Template(
    """
MATCH (m)-[:`$edge_type`]->(n)
  WHERE id(m) == $quote$src_id$quote AND id(n) == $quote$dst_id$quote
RETURN "(:" + tags(m)[0] + ")-[:$edge_type]->(:" + tags(n)[0] + ")" AS rels
"""
)

def hash_string_to_rank(string: str) -> int:
    """Hash a string into a signed 64-bit integer usable as an edge rank."""
    # get Python's built-in (signed) hash value
    signed_hash = hash(string)

    # reduce the hash value to an unsigned 64-bit range
    mask = (1 << 64) - 1
    signed_hash &= mask

    # map values with the top bit set back into the signed 64-bit range
    # (two's complement), since NebulaGraph edge ranks are signed int64
    if signed_hash & (1 << 63):
        unsigned_hash = -((signed_hash ^ mask) + 1)
    else:
        unsigned_hash = signed_hash

    return unsigned_hash
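
# Illustrative only: Python's hash() is seeded per process (PYTHONHASHSEED),
# so the exact rank differs between runs, but it always lands in the signed
# 64-bit range that NebulaGraph accepts for edge ranks, e.g.:
#
#   rank = hash_string_to_rank("is_friend_of")
#   assert -(1 << 63) <= rank < (1 << 63)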


def prepare_subjs_param(
    subjs: Optional[List[str]], vid_type: str = "FIXED_STRING(256)"
) -> Dict:
    """Prepare parameters for query."""
    if subjs is None:
        return {}
    from nebula3.common import ttypes

    subjs_list = []
    subjs_byte = ttypes.Value()

    # filter out non-digit strings for INT64 vid type
    if vid_type == "INT64":
        subjs = [subj for subj in subjs if subj.isdigit()]
        if len(subjs) == 0:
            logger.warning(
                "Graph store has INT64 vid type, but no digit strings were "
                "provided. Returning empty subjs; no query will be executed. "
                f"subjs: {subjs}"
            )
            return {}
    for subj in subjs:
        if not isinstance(subj, str):
            raise TypeError(f"Subject should be str, but got {type(subj).__name__}.")
        subj_byte = ttypes.Value()
        if vid_type == "INT64":
            assert subj.isdigit(), (
                "Subject should be a digit string in current "
                "graph store, where vid type is INT64."
            )
            subj_byte.set_iVal(int(subj))
        else:
            subj_byte.set_sVal(subj)
        subjs_list.append(subj_byte)
    subjs_nlist = ttypes.NList(values=subjs_list)
    subjs_byte.set_lVal(subjs_nlist)
    return {"subjs": subjs_byte}


def escape_str(value: str) -> str:
    """Escape String for NebulaGraph Query."""
    patterns = {
        '"': " ",
    }
    for pattern in patterns:
        if pattern in value:
            value = value.replace(pattern, patterns[pattern])
    # guard against empty strings before indexing
    if value and (value[0] == " " or value[-1] == " "):
        value = value.strip()

    return value
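
# For example, double quotes are replaced with spaces (not backslash-escaped)
# and leading/trailing whitespace is then stripped:
#
#   escape_str('"Tim Duncan"')  ->  'Tim Duncan'
#   escape_str('say "hi"')      ->  'say  hi'   (note the doubled inner space)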


class NebulaGraphStore(GraphStore):
    """NebulaGraph graph store."""

    def __init__(
        self,
        session_pool: Optional[Any] = None,
        space_name: Optional[str] = None,
        edge_types: Optional[List[str]] = ["relationship"],
        rel_prop_names: Optional[List[str]] = ["relationship,"],
        tags: Optional[List[str]] = ["entity"],
        tag_prop_names: Optional[List[str]] = ["name,"],
        include_vid: bool = True,
        session_pool_kwargs: Optional[Dict[str, Any]] = {},
        **kwargs: Any,
    ) -> None:
        """Initialize NebulaGraph graph store.

        Args:
            session_pool: NebulaGraph session pool.
            space_name: NebulaGraph space name.
            edge_types: Edge types.
            rel_prop_names: Relation property names corresponding to edge types.
            tags: Tags.
            tag_prop_names: Tag property names corresponding to tags.
            include_vid: Whether to include the vid (vertex id) in entity
                representations.
            session_pool_kwargs: Keyword arguments for NebulaGraph session pool.
            **kwargs: Keyword arguments.
        """
        try:
            import nebula3  # noqa
        except ImportError:
            raise ImportError(
                "Please install NebulaGraph Python client first: "
                "`pip install nebula3-python`"
            )
        assert space_name is not None, "space_name should be provided."
        self._space_name = space_name
        self._session_pool_kwargs = session_pool_kwargs

        self._session_pool: Any = session_pool
        if self._session_pool is None:
            self.init_session_pool()

        self._vid_type = self._get_vid_type()

        self._tags = tags or ["entity"]
        self._edge_types = edge_types or ["rel"]
        self._rel_prop_names = rel_prop_names or ["predicate,"]
        if len(self._edge_types) != len(self._rel_prop_names):
            raise ValueError(
                "edge_types and rel_prop_names, which define relations and "
                "their property names, should both be provided and have the "
                "same length."
            )
        if len(self._edge_types) == 0:
            raise ValueError("Length of `edge_types` should be greater than 0.")

        if tag_prop_names is None or len(self._tags) != len(tag_prop_names):
            raise ValueError(
                "tags and tag_prop_names, which define tags and their property "
                "names, should both be provided and have the same length."
            )

        if len(self._tags) == 0:
            raise ValueError("Length of `tags` should be greater than 0.")

        # for building query
        self._edge_dot_rel = [
            f"`{edge_type}`.`{rel_prop_name}`"
            for edge_type, rel_prop_name in zip(self._edge_types, self._rel_prop_names)
        ]

        self._edge_prop_map = {}
        for edge_type, rel_prop_name in zip(self._edge_types, self._rel_prop_names):
            self._edge_prop_map[edge_type] = [
                prop.strip() for prop in rel_prop_name.split(",")
            ]

        # cypher string like: map{`follow`: "degree", `serve`: "start_year,end_year"}
        self._edge_prop_map_cypher_string = (
            "map{"
            + ", ".join(
                [
                    f"`{edge_type}`: \"{','.join(rel_prop_names)}\""
                    for edge_type, rel_prop_names in self._edge_prop_map.items()
                ]
            )
            + "}"
        )

        # build tag_prop_names map
        self._tag_prop_names_map = {}
        for tag, prop_names in zip(self._tags, tag_prop_names or []):
            if prop_names is not None:
                self._tag_prop_names_map[tag] = f"`{tag}`.`{prop_names}`"
        self._tag_prop_names: List[str] = list(
            {
                prop_name.strip()
                for prop_names in tag_prop_names or []
                if prop_names is not None
                for prop_name in prop_names.split(",")
            }
        )

        self._include_vid = include_vid
        # cached schema text, populated lazily by refresh_schema()
        self.schema = ""
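
    # A minimal construction sketch (assuming NEBULA_USER, NEBULA_PASSWORD and
    # NEBULA_ADDRESS are set in the environment, and that a space named
    # "llamaindex" exists whose schema matches the tags/edge types below):
    #
    #   store = NebulaGraphStore(
    #       space_name="llamaindex",
    #       edge_types=["relationship"],
    #       rel_prop_names=["relationship"],
    #       tags=["entity"],
    #       tag_prop_names=["name"],
    #   )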

    def init_session_pool(self) -> Any:
        """Return NebulaGraph session pool."""
        from nebula3.Config import SessionPoolConfig
        from nebula3.gclient.net.SessionPool import SessionPool

        # ensure "NEBULA_USER", "NEBULA_PASSWORD", "NEBULA_ADDRESS" are set
        # in environment variables
        if not all(
            key in os.environ
            for key in ["NEBULA_USER", "NEBULA_PASSWORD", "NEBULA_ADDRESS"]
        ):
            raise ValueError(
                "NEBULA_USER, NEBULA_PASSWORD, NEBULA_ADDRESS should be set in "
                "environment variables when NebulaGraph Session Pool is not "
                "directly passed."
            )
        # NEBULA_ADDRESS is expected as "host:port", e.g. "127.0.0.1:9669"
        graphd_host, graphd_port = os.environ["NEBULA_ADDRESS"].split(":")
        session_pool = SessionPool(
            os.environ["NEBULA_USER"],
            os.environ["NEBULA_PASSWORD"],
            self._space_name,
            [(graphd_host, int(graphd_port))],
        )

        session_pool_config = SessionPoolConfig()
        session_pool.init(session_pool_config)
        self._session_pool = session_pool
        return self._session_pool

    def _get_vid_type(self) -> str:
        """Get vid type."""
        return (
            self.execute(f"DESCRIBE SPACE {self._space_name}")
            .column_values("Vid Type")[0]
            .cast()
        )
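
    # Note: the vid type reported by `DESCRIBE SPACE` is typically
    # FIXED_STRING(<N>) or INT64; the value cached in self._vid_type decides
    # whether vids are quoted as strings or passed as integers throughout
    # this store.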

    def __del__(self) -> None:
        """Close NebulaGraph session pool."""
        # guard: __init__ may have raised before the pool was created
        if getattr(self, "_session_pool", None) is not None:
            self._session_pool.close()

    @retry(
        wait=wait_random_exponential(min=WAIT_MIN_SECONDS, max=WAIT_MAX_SECONDS),
        stop=stop_after_attempt(RETRY_TIMES),
    )
    def execute(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
        """Execute query.

        Args:
            query: Query.
            param_map: Parameter map.

        Returns:
            Query result.
        """
        from nebula3.Exception import IOErrorException
        from nebula3.fbthrift.transport.TTransport import TTransportException

        # Clean the query string by removing triple backticks
        query = query.replace("```", "").strip()

        try:
            result = self._session_pool.execute_parameter(query, param_map)
            if result is None:
                raise ValueError(f"Query failed. Query: {query}, Param: {param_map}")
            if not result.is_succeeded():
                raise ValueError(
                    f"Query failed. Query: {query}, Param: {param_map}, "
                    f"Error message: {result.error_msg()}"
                )
            return result
        except (TTransportException, IOErrorException, RuntimeError) as e:
            logger.error(
                f"Connection issue, trying to recreate session pool. "
                f"Query: {query}, Param: {param_map}, Error: {e}"
            )
            self.init_session_pool()
            logger.info(
                f"Session pool recreated. Query: {query}, Param: {param_map}. "
                f"This was due to error: {e}; retrying now."
            )
            raise
        except ValueError as e:
            # query failed on the db side
            logger.error(
                f"Query failed. Query: {query}, Param: {param_map}, "
                f"Error message: {e}"
            )
            raise
        except Exception as e:
            # other exceptions
            logger.error(
                f"Query failed. Query: {query}, Param: {param_map}, "
                f"Error message: {e}"
            )
            raise
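
    # Sketch of the retry behavior above: a transient TTransportException,
    # IOErrorException, or RuntimeError recreates the session pool and
    # re-raises, and tenacity then retries execute() up to RETRY_TIMES
    # attempts with random exponential backoff between WAIT_MIN_SECONDS and
    # WAIT_MAX_SECONDS, e.g. (`store` is a NebulaGraphStore instance):
    #
    #   result = store.execute("SHOW TAGS")  # retried on connection errors
    #   names = [v.cast() for v in result.column_values("Name")]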

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "GraphStore":
        """Initialize graph store from configuration dictionary.

        Args:
            config_dict: Configuration dictionary.

        Returns:
            Graph store.
        """
        return cls(**config_dict)

    @property
    def client(self) -> Any:
        """Return NebulaGraph session pool."""
        return self._session_pool

    @property
    def config_dict(self) -> dict:
        """Return configuration dictionary."""
        return {
            "session_pool": self._session_pool,
            "space_name": self._space_name,
            "edge_types": self._edge_types,
            "rel_prop_names": self._rel_prop_names,
            "session_pool_kwargs": self._session_pool_kwargs,
        }

    def get(self, subj: str) -> List[List[str]]:
        """Get triplets.

        Args:
            subj: Subject.

        Returns:
            Triplets.
        """
        rel_map = self.get_flat_rel_map([subj], depth=1)
        rels = list(rel_map.values())
        if len(rels) == 0:
            return []
        return rels[0]

    def get_flat_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get flat rel map."""
        # "Flat" means that for a multi-hop relation path we get knowledge
        # like: subj -rel-> obj -rel-> obj <-rel- obj.
        # This type of knowledge is useful for some tasks.
        # +---------------------+---------------------------------------------...-----+
        # | subj                | flattened_rels                              ...     |
        # +---------------------+---------------------------------------------...-----+
        # | "{name:Tony Parker}"| "{name: Tony Parker}-[follow:{degree:95}]-> ...ili}"|
        # | "{name:Tony Parker}"| "{name: Tony Parker}-[follow:{degree:95}]-> ...r}"  |
        # ...
        rel_map: Dict[Any, List[Any]] = {}
        if subjs is None or len(subjs) == 0:
            # unlike SimpleGraphStore, we don't do get_all here
            return rel_map

        # The generated query looks like the following reference example:
        #
        # WITH map{`true`: "-[", `false`: "<-["} AS arrow_l,
        #      map{`true`: "]->", `false`: "]-"} AS arrow_r,
        #      map{`follow`: "degree", `serve`: "start_year,end_year"} AS edge_type_map
        # MATCH p=(start)-[e:follow|serve*..2]-()
        #   WHERE id(start) IN ["player100", "player101"]
        # WITH start, id(start) AS vid, nodes(p) AS nodes, e AS rels,
        #   length(p) AS rel_count, arrow_l, arrow_r, edge_type_map
        # WITH
        #   REDUCE(s = vid + '{', key IN [key_ in ["name"]
        #     WHERE properties(start)[key_] IS NOT NULL] | s + key + ': ' +
        #     COALESCE(TOSTRING(properties(start)[key]), 'null') + ', ')
        #     + '}'
        #     AS subj,
        #   [item in [i IN RANGE(0, rel_count - 1) | [nodes[i], nodes[i + 1],
        #     rels[i], typeid(rels[i]) > 0, type(rels[i]) ]] | [
        #       arrow_l[tostring(item[3])] +
        #       item[4] + ':' +
        #       REDUCE(s = '{', key IN SPLIT(edge_type_map[item[4]], ',') |
        #         s + key + ': ' + COALESCE(TOSTRING(properties(item[2])[key]),
        #         'null') + ', ') + '}'
        #       +
        #       arrow_r[tostring(item[3])],
        #       REDUCE(s = id(item[1]) + '{', key IN [key_ in ["name"]
        #         WHERE properties(item[1])[key_] IS NOT NULL] | s + key + ': ' +
        #         COALESCE(TOSTRING(properties(item[1])[key]), 'null') + ', ') + '}'
        #     ]
        #   ] AS rels
        # WITH
        #   REPLACE(subj, ', }', '}') AS subj,
        #   REDUCE(acc = collect(NULL), l in rels | acc + l) AS flattened_rels
        # RETURN
        #   subj,
        #   REPLACE(REDUCE(acc = subj, l in flattened_rels | acc + ' ' + l),
        #     ', }', '}')
        #     AS flattened_rels
        # LIMIT 30

        # Based on self._include_vid, an entity is rendered as
        # {name: Tim Duncan} or player100{name: Tim Duncan}
        s_prefix = "vid + '{'" if self._include_vid else "'{'"
        s1 = "id(item[1]) + '{'" if self._include_vid else "'{'"

        query = (
            f"WITH map{{`true`: '-[', `false`: '<-['}} AS arrow_l,"
            f" map{{`true`: ']->', `false`: ']-'}} AS arrow_r,"
            f" {self._edge_prop_map_cypher_string} AS edge_type_map "
            f"MATCH p=(start)-[e:`{'`|`'.join(self._edge_types)}`*..{depth}]-() "
            f" WHERE id(start) IN $subjs "
            f"WITH start, id(start) AS vid, nodes(p) AS nodes, e AS rels,"
            f" length(p) AS rel_count, arrow_l, arrow_r, edge_type_map "
            f"WITH "
            f" REDUCE(s = {s_prefix}, key IN [key_ in {self._tag_prop_names!s} "
            f" WHERE properties(start)[key_] IS NOT NULL] | s + key + ': ' + "
            f" COALESCE(TOSTRING(properties(start)[key]), 'null') + ', ')"
            f" + '}}'"
            f" AS subj,"
            f" [item in [i IN RANGE(0, rel_count - 1)|[nodes[i], nodes[i + 1],"
            f" rels[i], typeid(rels[i]) > 0, type(rels[i]) ]] | ["
            f" arrow_l[tostring(item[3])] +"
            f" item[4] + ':' +"
            f" REDUCE(s = '{{', key IN SPLIT(edge_type_map[item[4]], ',') | "
            f" s + key + ': ' + COALESCE(TOSTRING(properties(item[2])[key]),"
            f" 'null') + ', ') + '}}'"
            f" +"
            f" arrow_r[tostring(item[3])],"
            f" REDUCE(s = {s1}, key IN [key_ in "
            f" {self._tag_prop_names!s} WHERE properties(item[1])[key_] "
            f" IS NOT NULL] | s + key + ': ' + "
            f" COALESCE(TOSTRING(properties(item[1])[key]), 'null') + ', ')"
            f" + '}}'"
            f" ]"
            f" ] AS rels "
            f"WITH "
            f" REPLACE(subj, ', }}', '}}') AS subj,"
            f" REDUCE(acc = collect(NULL), l in rels | acc + l) AS flattened_rels "
            f"RETURN "
            f" subj,"
            f" REPLACE(REDUCE(acc = subj, l in flattened_rels | acc + ' ' + l), "
            f" ', }}', '}}') "
            f" AS flattened_rels"
            f" LIMIT {limit}"
        )
        subjs_param = prepare_subjs_param(subjs, self._vid_type)
        logger.debug(f"get_flat_rel_map()\nsubjs: {subjs},\nquery: {query}")
        if subjs_param == {}:
            # This happens when subjs ends up empty after prepare_subjs_param(),
            # e.g. when the vid type is INT64 but no digit string was provided.
            return rel_map
        result = self.execute(query, subjs_param)
        if result is None:
            return rel_map

        # get raw data
        subjs_ = result.column_values("subj") or []
        rels_ = result.column_values("flattened_rels") or []

        for subj, rel in zip(subjs_, rels_):
            subj_ = subj.cast()
            rel_ = rel.cast()
            if subj_ not in rel_map:
                rel_map[subj_] = []
            rel_map[subj_].append(rel_)
        return rel_map
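
    # Illustrative shape of the returned map (data-dependent; shown here with
    # include_vid=False against NebulaGraph's basketballplayer sample data):
    #
    #   {
    #       "{name: Tony Parker}": [
    #           "{name: Tony Parker} -[follow:{degree: 95}]-> {name: ...}",
    #           ...
    #       ],
    #   }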

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get rel map."""
        # We put rels in one long list for depth >= 1; this differs from
        # SimpleGraphStore.get_rel_map(), but makes more sense for
        # multi-hop relation paths.

        if subjs is not None:
            subjs = [
                escape_str(subj) for subj in subjs if isinstance(subj, str) and subj
            ]
            if len(subjs) == 0:
                return {}

        return self.get_flat_rel_map(subjs, depth, limit)

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet."""
        # Note: to enable leveraging existing knowledge graphs, the
        # (triplet -> property graph) mapping is (n:1) for
        # edge_type.prop_name -> triplet.rel, so we have to assume rel maps
        # to the first edge_type.prop_name here in upsert_triplet().
        # The same applies to the entity types (tags) of subject and object,
        # so we have to assume subj maps to the first entity tag name.

        # escape subj, rel, obj for nGQL
        subj = escape_str(subj)
        rel = escape_str(rel)
        obj = escape_str(obj)
        if self._vid_type == "INT64":
            assert all(
                [subj.isdigit(), obj.isdigit()]
            ), "Subject and object should be digit strings in current graph store."
            subj_field = subj
            obj_field = obj
        else:
            subj_field = f"{QUOTE}{subj}{QUOTE}"
            obj_field = f"{QUOTE}{obj}{QUOTE}"
        edge_field = f"{subj_field}->{obj_field}"

        edge_type = self._edge_types[0]
        rel_prop_name = self._rel_prop_names[0]
        entity_type = self._tags[0]
        rel_hash = hash_string_to_rank(rel)
        dml_query = (
            f"INSERT VERTEX `{entity_type}`(name) "
            f"  VALUES {subj_field}:({QUOTE}{subj}{QUOTE});"
            f"INSERT VERTEX `{entity_type}`(name) "
            f"  VALUES {obj_field}:({QUOTE}{obj}{QUOTE});"
            f"INSERT EDGE `{edge_type}`(`{rel_prop_name}`) "
            f"  VALUES "
            f"{edge_field}"
            f"@{rel_hash}:({QUOTE}{rel}{QUOTE});"
        )
        logger.debug(f"upsert_triplet()\nDML query: {dml_query}")
        result = self.execute(dml_query)
        assert (
            result and result.is_succeeded()
        ), f"Failed to upsert triplet: {subj} {rel} {obj}, query: {dml_query}"

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete triplet.

        1. Similar to upsert_triplet(), we have to assume rel maps to the
           first edge_type.prop_name.
        2. After the edge is deleted, check whether subj or obj is left as
           an isolated vertex; if so, delete it, too.
        """
        # escape subj, rel, obj for nGQL
        subj = escape_str(subj)
        rel = escape_str(rel)
        obj = escape_str(obj)

        if self._vid_type == "INT64":
            assert all(
                [subj.isdigit(), obj.isdigit()]
            ), "Subject and object should be digit strings in current graph store."
            subj_field = subj
            obj_field = obj
        else:
            subj_field = f"{QUOTE}{subj}{QUOTE}"
            obj_field = f"{QUOTE}{obj}{QUOTE}"
        edge_field = f"{subj_field}->{obj_field}"

        # DELETE EDGE serve "player100" -> "team204"@7696463696635583936;
        edge_type = self._edge_types[0]
        rel_hash = hash_string_to_rank(rel)
        dml_query = f"DELETE EDGE `{edge_type}` {edge_field}@{rel_hash};"
        logger.debug(f"delete()\nDML query: {dml_query}")
        result = self.execute(dml_query)
        assert (
            result and result.is_succeeded()
        ), f"Failed to delete triplet: {subj} {rel} {obj}, query: {dml_query}"
        # Get isolated vertices to be deleted
        # MATCH (s) WHERE id(s) IN ["player700"] AND NOT (s)-[]-()
        # RETURN id(s) AS isolated
        query = (
            f"MATCH (s) "
            f"  WHERE id(s) IN [{subj_field}, {obj_field}] "
            f"  AND NOT (s)-[]-() "
            f"RETURN id(s) AS isolated"
        )
        result = self.execute(query)
        isolated = result.column_values("isolated")
        if not isolated:
            return
        # DELETE VERTEX "player700" or DELETE VERTEX 700
        quote_field = QUOTE if self._vid_type != "INT64" else ""
        vertex_ids = ",".join(
            [f"{quote_field}{v.cast()}{quote_field}" for v in isolated]
        )
        dml_query = f"DELETE VERTEX {vertex_ids};"

        result = self.execute(dml_query)
        assert (
            result and result.is_succeeded()
        ), f"Failed to delete isolated vertices: {isolated}, query: {dml_query}"

    def refresh_schema(self) -> None:
        """Refresh the NebulaGraph store schema."""
        tags_schema, edge_types_schema, relationships = [], [], []
        for tag in self.execute("SHOW TAGS").column_values("Name"):
            tag_name = tag.cast()
            tag_schema = {"tag": tag_name, "properties": []}
            r = self.execute(f"DESCRIBE TAG `{tag_name}`")
            props, types, comments = (
                r.column_values("Field"),
                r.column_values("Type"),
                r.column_values("Comment"),
            )
            for i in range(r.row_size()):
                # backward compatible with older versions of nebula-python
                property_definition = (
                    (props[i].cast(), types[i].cast())
                    if comments[i].is_empty()
                    else (props[i].cast(), types[i].cast(), comments[i].cast())
                )
                tag_schema["properties"].append(property_definition)
            tags_schema.append(tag_schema)
        for edge_type in self.execute("SHOW EDGES").column_values("Name"):
            edge_type_name = edge_type.cast()
            edge_schema = {"edge": edge_type_name, "properties": []}
            r = self.execute(f"DESCRIBE EDGE `{edge_type_name}`")
            props, types, comments = (
                r.column_values("Field"),
                r.column_values("Type"),
                r.column_values("Comment"),
            )
            for i in range(r.row_size()):
                # backward compatible with older versions of nebula-python
                property_definition = (
                    (props[i].cast(), types[i].cast())
                    if comments[i].is_empty()
                    else (props[i].cast(), types[i].cast(), comments[i].cast())
                )
                edge_schema["properties"].append(property_definition)
            edge_types_schema.append(edge_schema)

            # build relationship types from a sampled edge of each type
            sample_edge = self.execute(
                rel_query_sample_edge.substitute(edge_type=edge_type_name)
            ).column_values("sample_edge")
            if len(sample_edge) == 0:
                continue
            src_id, dst_id = sample_edge[0].cast()
            r = self.execute(
                rel_query_edge_type.substitute(
                    edge_type=edge_type_name,
                    src_id=src_id,
                    dst_id=dst_id,
                    quote="" if self._vid_type == "INT64" else QUOTE,
                )
            ).column_values("rels")
            if len(r) > 0:
                relationships.append(r[0].cast())

        self.schema = (
            f"Node properties: {tags_schema}\n"
            f"Edge properties: {edge_types_schema}\n"
            f"Relationships: {relationships}\n"
        )
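
    # The resulting self.schema is a plain-text summary, roughly of the form
    # (contents depend on the space's actual tags and edge types):
    #
    #   Node properties: [{'tag': 'entity', 'properties': [('name', 'string')]}]
    #   Edge properties: [{'edge': 'relationship', 'properties': [...]}]
    #   Relationships: ['(:entity)-[:relationship]->(:entity)']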

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the NebulaGraph store."""
        if self.schema and not refresh:
            return self.schema
        self.refresh_schema()
        logger.debug(f"get_schema()\nschema: {self.schema}")
        return self.schema

    def query(self, query: str, param_map: Optional[Dict[str, Any]] = {}) -> Any:
        """Execute a query and return the result as a dict of column lists."""
        result = self.execute(query, param_map)
        columns = result.keys()
        d: Dict[str, list] = {}
        for col_num in range(result.col_size()):
            col_name = columns[col_num]
            col_list = result.column_values(col_name)
            d[col_name] = [x.cast() for x in col_list]
        return d
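

# Usage sketch (assumes the environment variables required by
# init_session_pool() are set and that a space named "llamaindex" exists
# with the default schema; column names follow the RETURN clause):
#
#   store = NebulaGraphStore(space_name="llamaindex")
#   store.upsert_triplet("Tom", "chases", "Jerry")
#   print(store.get("Tom"))
#   rows = store.query(
#       "MATCH (m)-[e:`relationship`]->(n) "
#       "RETURN id(m) AS subj, id(n) AS obj LIMIT 3"
#   )
#   # rows == {"subj": [...], "obj": [...]}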
|