faiss_rag_enterprise/llama_index/vector_stores/azureaisearch.py

"""Azure AI Search vector store."""
import enum
import json
import logging
from enum import auto
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
ExactMatchFilter,
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import (
legacy_metadata_dict_to_node,
metadata_dict_to_node,
node_to_metadata_dict,
)
logger = logging.getLogger(__name__)
class MetadataIndexFieldType(int, enum.Enum):
"""
Enumeration representing the supported types for metadata fields in an
Azure AI Search Index, corresponds with types supported in a flat
metadata dictionary.
"""
STRING = auto() # "Edm.String"
BOOLEAN = auto() # "Edm.Boolean"
INT32 = auto() # "Edm.Int32"
INT64 = auto() # "Edm.Int64"
DOUBLE = auto() # "Edm.Double"
class IndexManagement(int, enum.Enum):
"""Enumeration representing the supported index management operations."""
NO_VALIDATION = auto()
VALIDATE_INDEX = auto()
CREATE_IF_NOT_EXISTS = auto()
class AzureAISearchVectorStore(VectorStore):
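    """Vector store backed by an Azure AI Search index."""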
stores_text: bool = True
flat_metadata: bool = True
def _normalise_metadata_to_index_fields(
self,
filterable_metadata_field_keys: Union[
List[str],
Dict[str, str],
Dict[str, Tuple[str, MetadataIndexFieldType]],
None,
    ] = None,
) -> Dict[str, Tuple[str, MetadataIndexFieldType]]:
index_field_spec: Dict[str, Tuple[str, MetadataIndexFieldType]] = {}
        if isinstance(filterable_metadata_field_keys, list):
for field in filterable_metadata_field_keys:
# Index field name and the metadata field name are the same
# Use String as the default index field type
index_field_spec[field] = (field, MetadataIndexFieldType.STRING)
        elif isinstance(filterable_metadata_field_keys, dict):
for k, v in filterable_metadata_field_keys.items():
if isinstance(v, tuple):
# Index field name and metadata field name may differ
# The index field type used is as supplied
index_field_spec[k] = v
else:
# Index field name and metadata field name may differ
# Use String as the default index field type
index_field_spec[k] = (v, MetadataIndexFieldType.STRING)
return index_field_spec
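    # A rough sketch of how the three accepted shapes normalise (the field
    # names here are illustrative, not required by the store):
    #
    #   ["author"]               -> {"author": ("author", MetadataIndexFieldType.STRING)}
    #   {"author": "author_idx"} -> {"author": ("author_idx", MetadataIndexFieldType.STRING)}
    #   {"page": ("page_idx", MetadataIndexFieldType.INT32)}
    #                            -> {"page": ("page_idx", MetadataIndexFieldType.INT32)}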
def _create_index_if_not_exists(self, index_name: str) -> None:
if index_name not in self._index_client.list_index_names():
logger.info(
f"Index {index_name} does not exist in Azure AI Search, creating index"
)
self._create_index(index_name)
def _create_metadata_index_fields(self) -> List[Any]:
"""Create a list of index fields for storing metadata values."""
from azure.search.documents.indexes.models import SimpleField
index_fields = []
        # Create filterable search fields, mapping each supported metadata
        # field type to its Edm equivalent; an unknown type raises a KeyError
        # instead of silently leaving the Edm type unset.
        edm_types = {
            MetadataIndexFieldType.STRING: "Edm.String",
            MetadataIndexFieldType.BOOLEAN: "Edm.Boolean",
            MetadataIndexFieldType.INT32: "Edm.Int32",
            MetadataIndexFieldType.INT64: "Edm.Int64",
            MetadataIndexFieldType.DOUBLE: "Edm.Double",
        }
        for field_name, field_type in self._metadata_to_index_field_map.values():
            field = SimpleField(
                name=field_name, type=edm_types[field_type], filterable=True
            )
            index_fields.append(field)
return index_fields
def _create_index(self, index_name: Optional[str]) -> None:
"""
Creates a default index based on the supplied index name, key field names and
metadata filtering keys.
"""
from azure.search.documents.indexes.models import (
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
HnswAlgorithmConfiguration,
HnswParameters,
SearchableField,
SearchField,
SearchFieldDataType,
SearchIndex,
SemanticConfiguration,
SemanticField,
SemanticPrioritizedFields,
SemanticSearch,
SimpleField,
VectorSearch,
VectorSearchAlgorithmKind,
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)
logger.info(f"Configuring {index_name} fields for Azure AI Search")
fields = [
SimpleField(name=self._field_mapping["id"], type="Edm.String", key=True),
SearchableField(
name=self._field_mapping["chunk"],
type="Edm.String",
analyzer_name="en.microsoft",
),
SearchField(
name=self._field_mapping["embedding"],
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True,
vector_search_dimensions=self.embedding_dimensionality,
vector_search_profile_name="default",
),
SimpleField(name=self._field_mapping["metadata"], type="Edm.String"),
SimpleField(
name=self._field_mapping["doc_id"], type="Edm.String", filterable=True
),
]
logger.info(f"Configuring {index_name} metadata fields")
metadata_index_fields = self._create_metadata_index_fields()
fields.extend(metadata_index_fields)
logger.info(f"Configuring {index_name} vector search")
# Configure the vector search algorithms and profiles
vector_search = VectorSearch(
algorithms=[
HnswAlgorithmConfiguration(
name="myHnsw",
kind=VectorSearchAlgorithmKind.HNSW,
                    # For more information on HNSW parameters, visit https://learn.microsoft.com/azure/search/vector-search-ranking#creating-the-hnsw-graph
parameters=HnswParameters(
m=4,
ef_construction=400,
ef_search=500,
metric=VectorSearchAlgorithmMetric.COSINE,
),
),
ExhaustiveKnnAlgorithmConfiguration(
name="myExhaustiveKnn",
kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
parameters=ExhaustiveKnnParameters(
metric=VectorSearchAlgorithmMetric.COSINE,
),
),
],
profiles=[
VectorSearchProfile(
name="myHnswProfile",
algorithm_configuration_name="myHnsw",
),
# Add more profiles if needed
VectorSearchProfile(
name="myExhaustiveKnnProfile",
algorithm_configuration_name="myExhaustiveKnn",
),
# Add more profiles if needed
],
)
logger.info(f"Configuring {index_name} semantic search")
semantic_config = SemanticConfiguration(
name="mySemanticConfig",
prioritized_fields=SemanticPrioritizedFields(
content_fields=[SemanticField(field_name=self._field_mapping["chunk"])],
),
)
semantic_search = SemanticSearch(configurations=[semantic_config])
index = SearchIndex(
name=index_name,
fields=fields,
vector_search=vector_search,
semantic_search=semantic_search,
)
logger.debug(f"Creating {index_name} search index")
self._index_client.create_index(index)
def _validate_index(self, index_name: Optional[str]) -> None:
if self._index_client and index_name:
if index_name not in self._index_client.list_index_names():
raise ValueError(
f"Validation failed, index {index_name} does not exist."
)
def __init__(
self,
search_or_index_client: Any,
id_field_key: str,
chunk_field_key: str,
embedding_field_key: str,
metadata_string_field_key: str,
doc_id_field_key: str,
filterable_metadata_field_keys: Optional[
Union[
List[str],
Dict[str, str],
Dict[str, Tuple[str, MetadataIndexFieldType]],
]
] = None,
index_name: Optional[str] = None,
index_mapping: Optional[
Callable[[Dict[str, str], Dict[str, Any]], Dict[str, str]]
] = None,
index_management: IndexManagement = IndexManagement.NO_VALIDATION,
embedding_dimensionality: int = 1536,
**kwargs: Any,
) -> None:
# ruff: noqa: E501
"""
        Embeddings and documents are stored in an Azure AI Search index,
        using a merge-or-upload approach when adding embeddings.
        When adding multiple embeddings, this vector store updates the index
        in batches of 10 documents; very large nodes may cause a failure if
        the maximum batch byte size is exceeded.
Args:
            search_or_index_client (azure.search.documents.SearchClient or
                azure.search.documents.indexes.SearchIndexClient):
                Client for the index to be populated / queried.
id_field_key (str): Index field storing the id
chunk_field_key (str): Index field storing the node text
embedding_field_key (str): Index field storing the embedding vector
            metadata_string_field_key (str):
                Index field storing node metadata as a JSON string.
                The schema is arbitrary; to filter on metadata values they must
                be stored as separate fields in the index. Use
                filterable_metadata_field_keys to specify which metadata values
                should be stored in these filterable fields.
doc_id_field_key (str): Index field storing doc_id
index_mapping:
Optional function with definition
                (enriched_doc: Dict[str, str], metadata: Dict[str, Any]) -> Dict[str, str]
                used to map document fields to the AI Search index fields
                (the return value of the function).
                If none is specified a default mapping is provided which uses
                the field keys. The keys in the enriched_doc are
                ["id", "chunk", "embedding", "metadata", "doc_id"].
                The default mapping is:
                - "id" to id_field_key
                - "chunk" to chunk_field_key
                - "embedding" to embedding_field_key
                - "metadata" to metadata_string_field_key
                - "doc_id" to doc_id_field_key
            **kwargs (Any): Additional keyword arguments.
Raises:
ImportError: Unable to import `azure-search-documents`
ValueError: If `search_or_index_client` is not provided
ValueError: If `index_name` is not provided and `search_or_index_client`
is of type azure.search.documents.SearchIndexClient
ValueError: If `index_name` is provided and `search_or_index_client`
is of type azure.search.documents.SearchClient
            ValueError: If `index_management` is IndexManagement.CREATE_IF_NOT_EXISTS
                and `search_or_index_client` is of type azure.search.documents.SearchClient
"""
import_err_msg = (
"`azure-search-documents` package not found, please run "
"`pip install azure-search-documents==11.4.0`"
)
try:
import azure.search.documents # noqa
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
except ImportError:
raise ImportError(import_err_msg)
self._index_client: SearchIndexClient = cast(SearchIndexClient, None)
self._search_client: SearchClient = cast(SearchClient, None)
self.embedding_dimensionality = embedding_dimensionality
# Validate search_or_index_client
if search_or_index_client is not None:
if isinstance(search_or_index_client, SearchIndexClient):
# If SearchIndexClient is supplied so must index_name
self._index_client = cast(SearchIndexClient, search_or_index_client)
if not index_name:
raise ValueError(
"index_name must be supplied if search_or_index_client is of "
"type azure.search.documents.SearchIndexClient"
)
self._search_client = self._index_client.get_search_client(
index_name=index_name
)
elif isinstance(search_or_index_client, SearchClient):
self._search_client = cast(SearchClient, search_or_index_client)
# Validate index_name
if index_name:
raise ValueError(
"index_name cannot be supplied if search_or_index_client "
"is of type azure.search.documents.SearchClient"
)
if not self._index_client and not self._search_client:
raise ValueError(
"search_or_index_client must be of type "
"azure.search.documents.SearchClient or "
"azure.search.documents.SearchIndexClient"
)
else:
raise ValueError("search_or_index_client not specified")
if (
index_management == IndexManagement.CREATE_IF_NOT_EXISTS
and not self._index_client
):
raise ValueError(
"index_management has value of IndexManagement.CREATE_IF_NOT_EXISTS "
"but search_or_index_client is not of type "
"azure.search.documents.SearchIndexClient"
)
self._index_management = index_management
# Default field mapping
field_mapping = {
"id": id_field_key,
"chunk": chunk_field_key,
"embedding": embedding_field_key,
"metadata": metadata_string_field_key,
"doc_id": doc_id_field_key,
}
self._field_mapping = field_mapping
self._index_mapping = (
self._default_index_mapping if index_mapping is None else index_mapping
)
self._metadata_to_index_field_map = self._normalise_metadata_to_index_fields(
filterable_metadata_field_keys
)
if self._index_management == IndexManagement.CREATE_IF_NOT_EXISTS:
if index_name:
self._create_index_if_not_exists(index_name)
if self._index_management == IndexManagement.VALIDATE_INDEX:
self._validate_index(index_name)
@property
def client(self) -> Any:
"""Get client."""
return self._search_client
def _default_index_mapping(
self, enriched_doc: Dict[str, str], metadata: Dict[str, Any]
) -> Dict[str, str]:
index_doc: Dict[str, str] = {}
for field in self._field_mapping:
index_doc[self._field_mapping[field]] = enriched_doc[field]
for metadata_field_name, (
index_field_name,
_,
) in self._metadata_to_index_field_map.items():
metadata_value = metadata.get(metadata_field_name)
if metadata_value:
index_doc[index_field_name] = metadata_value
return index_doc
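    # Illustrative sketch of the default mapping: with field keys such as
    # id_field_key="id" and chunk_field_key="content", plus a filterable
    # metadata spec {"author": ("author", MetadataIndexFieldType.STRING)}, an
    # enriched doc {"id": "n1", "chunk": "...", ...} with metadata
    # {"author": "alice"} becomes {"id": "n1", "content": "...", ...,
    # "author": "alice"}. Note that falsy metadata values (e.g. 0 or "") are
    # skipped by the `if metadata_value` check above.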
def add(
self,
nodes: List[BaseNode],
**add_kwargs: Any,
) -> List[str]:
"""Add nodes to index associated with the configured search client.
Args:
nodes: List[BaseNode]: nodes with embeddings
"""
if not self._search_client:
raise ValueError("Search client not initialized")
documents = []
ids = []
for node in nodes:
logger.debug(f"Processing embedding: {node.node_id}")
ids.append(node.node_id)
index_document = self._create_index_document(node)
documents.append(index_document)
if len(documents) >= 10:
logger.info(
f"Uploading batch of size {len(documents)}, "
f"current progress {len(ids)} of {len(nodes)}"
)
self._search_client.merge_or_upload_documents(documents)
documents = []
# Upload remaining batch of less than 10 documents
if len(documents) > 0:
logger.info(
f"Uploading remaining batch of size {len(documents)}, "
f"current progress {len(ids)} of {len(nodes)}"
)
self._search_client.merge_or_upload_documents(documents)
return ids
def _create_index_document(self, node: BaseNode) -> Dict[str, Any]:
"""Create AI Search index document from embedding result."""
doc: Dict[str, Any] = {}
doc["id"] = node.node_id
doc["chunk"] = node.get_content(metadata_mode=MetadataMode.NONE) or ""
doc["embedding"] = node.get_embedding()
doc["doc_id"] = node.ref_doc_id
node_metadata = node_to_metadata_dict(
node,
remove_text=True,
flat_metadata=self.flat_metadata,
)
doc["metadata"] = json.dumps(node_metadata)
return self._index_mapping(doc, node_metadata)
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
"""
Delete documents from the AI Search Index
with doc_id_field_key field equal to ref_doc_id.
"""
        # Locate documents to delete
        filter_expr = f'{self._field_mapping["doc_id"]} eq \'{ref_doc_id}\''
        logger.debug(f"Searching with filter {filter_expr}")
        results = self._search_client.search(search_text="*", filter=filter_expr)
docs_to_delete = []
for result in results:
doc = {}
doc["id"] = result[self._field_mapping["id"]]
logger.debug(f"Found document to delete: {doc}")
docs_to_delete.append(doc)
if len(docs_to_delete) > 0:
logger.debug(f"Deleting {len(docs_to_delete)} documents")
self._search_client.delete_documents(docs_to_delete)
def _create_odata_filter(self, metadata_filters: MetadataFilters) -> str:
"""Generate an OData filter string using supplied metadata filters."""
odata_filter: List[str] = []
for f in metadata_filters.legacy_filters():
if not isinstance(f, ExactMatchFilter):
raise NotImplementedError(
"Only `ExactMatchFilter` filters are supported"
)
# Raise error if filtering on a metadata field that lacks a mapping to
# an index field
metadata_mapping = self._metadata_to_index_field_map.get(f.key)
if not metadata_mapping:
raise ValueError(
f"Metadata field '{f.key}' is missing a mapping to an index field, "
"provide entry in 'filterable_metadata_field_keys' for this "
"vector store"
)
index_field = metadata_mapping[0]
if len(odata_filter) > 0:
odata_filter.append(" and ")
if isinstance(f.value, str):
                # Double single quotes per OData string-literal escaping rules
                escaped_value = f.value.replace("'", "''")
odata_filter.append(f"{index_field} eq '{escaped_value}'")
else:
odata_filter.append(f"{index_field} eq {f.value}")
odata_expr = "".join(odata_filter)
logger.info(f"Odata filter: {odata_expr}")
return odata_expr
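    # Example (illustrative): two ExactMatchFilters on metadata keys mapped to
    # index fields "author" (string) and "year" (int) would produce the OData
    # expression:
    #   author eq 'alice' and year eq 2020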
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
odata_filter = None
if query.filters is not None:
odata_filter = self._create_odata_filter(query.filters)
azure_query_result_search: AzureQueryResultSearchBase = (
AzureQueryResultSearchDefault(
query, self._field_mapping, odata_filter, self._search_client
)
)
if query.mode == VectorStoreQueryMode.SPARSE:
azure_query_result_search = AzureQueryResultSearchSparse(
query, self._field_mapping, odata_filter, self._search_client
)
elif query.mode == VectorStoreQueryMode.HYBRID:
azure_query_result_search = AzureQueryResultSearchHybrid(
query, self._field_mapping, odata_filter, self._search_client
)
elif query.mode == VectorStoreQueryMode.SEMANTIC_HYBRID:
azure_query_result_search = AzureQueryResultSearchSemanticHybrid(
query, self._field_mapping, odata_filter, self._search_client
)
return azure_query_result_search.search()
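# A minimal usage sketch, assuming the `azure-search-documents` 11.4 SDK; the
# endpoint, key, index name and field names below are illustrative only, not
# values required by the store.
def _example_create_store() -> AzureAISearchVectorStore:
    from azure.core.credentials import AzureKeyCredential
    from azure.search.documents.indexes import SearchIndexClient

    index_client = SearchIndexClient(
        endpoint="https://<service>.search.windows.net",  # hypothetical endpoint
        credential=AzureKeyCredential("<admin-key>"),  # hypothetical key
    )
    return AzureAISearchVectorStore(
        search_or_index_client=index_client,
        index_name="llamaindex-demo",  # created at construction if missing
        id_field_key="id",
        chunk_field_key="chunk",
        embedding_field_key="embedding",
        metadata_string_field_key="metadata",
        doc_id_field_key="doc_id",
        filterable_metadata_field_keys={"author": "author"},
        index_management=IndexManagement.CREATE_IF_NOT_EXISTS,
        embedding_dimensionality=1536,
    )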
class AzureQueryResultSearchBase:
def __init__(
self,
query: VectorStoreQuery,
field_mapping: Dict[str, str],
odata_filter: Optional[str],
search_client: Any,
) -> None:
self._query = query
self._field_mapping = field_mapping
self._odata_filter = odata_filter
self._search_client = search_client
@property
def _select_fields(self) -> List[str]:
return [
self._field_mapping["id"],
self._field_mapping["chunk"],
self._field_mapping["metadata"],
self._field_mapping["doc_id"],
]
def _create_search_query(self) -> str:
return "*"
def _create_query_vector(self) -> Optional[List[Any]]:
return None
def _create_query_result(
self, search_query: str, vectors: Optional[List[Any]]
) -> VectorStoreQueryResult:
results = self._search_client.search(
search_text=search_query,
vector_queries=vectors,
top=self._query.similarity_top_k,
select=self._select_fields,
filter=self._odata_filter,
)
id_result = []
node_result = []
score_result = []
for result in results:
node_id = result[self._field_mapping["id"]]
metadata = json.loads(result[self._field_mapping["metadata"]])
score = result["@search.score"]
chunk = result[self._field_mapping["chunk"]]
try:
node = metadata_dict_to_node(metadata)
node.set_content(chunk)
except Exception:
# NOTE: deprecated legacy logic for backward compatibility
metadata, node_info, relationships = legacy_metadata_dict_to_node(
metadata
)
node = TextNode(
text=chunk,
id_=node_id,
metadata=metadata,
start_char_idx=node_info.get("start", None),
end_char_idx=node_info.get("end", None),
relationships=relationships,
)
logger.debug(f"Retrieved node id {node_id} with node data of {node}")
id_result.append(node_id)
node_result.append(node)
score_result.append(score)
logger.debug(
f"Search query '{search_query}' returned {len(id_result)} results."
)
return VectorStoreQueryResult(
nodes=node_result, similarities=score_result, ids=id_result
)
def search(self) -> VectorStoreQueryResult:
search_query = self._create_search_query()
vectors = self._create_query_vector()
return self._create_query_result(search_query, vectors)
class AzureQueryResultSearchDefault(AzureQueryResultSearchBase):
def _create_query_vector(self) -> Optional[List[Any]]:
"""Query vector store."""
from azure.search.documents.models import VectorizedQuery
if not self._query.query_embedding:
raise ValueError("Query missing embedding")
vectorized_query = VectorizedQuery(
vector=self._query.query_embedding,
k_nearest_neighbors=self._query.similarity_top_k,
fields=self._field_mapping["embedding"],
)
vector_queries = [vectorized_query]
logger.info("Vector search with supplied embedding")
return vector_queries
class AzureQueryResultSearchSparse(AzureQueryResultSearchBase):
def _create_search_query(self) -> str:
if self._query.query_str is None:
raise ValueError("Query missing query string")
search_query = self._query.query_str
logger.info(f"Hybrid search with search text: {search_query}")
return search_query
class AzureQueryResultSearchHybrid(
AzureQueryResultSearchDefault, AzureQueryResultSearchSparse
):
def _create_query_vector(self) -> Optional[List[Any]]:
return AzureQueryResultSearchDefault._create_query_vector(self)
def _create_search_query(self) -> str:
return AzureQueryResultSearchSparse._create_search_query(self)
class AzureQueryResultSearchSemanticHybrid(AzureQueryResultSearchHybrid):
def _create_query_vector(self) -> Optional[List[Any]]:
"""Query vector store."""
from azure.search.documents.models import VectorizedQuery
if not self._query.query_embedding:
raise ValueError("Query missing embedding")
        # k is set to 50 to align with the number of documents accepted by the
        # Azure semantic reranking model.
        # https://learn.microsoft.com/azure/search/semantic-search-overview
vectorized_query = VectorizedQuery(
vector=self._query.query_embedding,
k_nearest_neighbors=50,
fields=self._field_mapping["embedding"],
)
vector_queries = [vectorized_query]
logger.info("Vector search with supplied embedding")
return vector_queries
def _create_query_result(
self, search_query: str, vector_queries: Optional[List[Any]]
) -> VectorStoreQueryResult:
results = self._search_client.search(
search_text=search_query,
vector_queries=vector_queries,
top=self._query.similarity_top_k,
select=self._select_fields,
filter=self._odata_filter,
query_type="semantic",
semantic_configuration_name="mySemanticConfig",
)
id_result = []
node_result = []
score_result = []
for result in results:
node_id = result[self._field_mapping["id"]]
metadata = json.loads(result[self._field_mapping["metadata"]])
# use reranker_score instead of score
score = result["@search.reranker_score"]
chunk = result[self._field_mapping["chunk"]]
try:
node = metadata_dict_to_node(metadata)
node.set_content(chunk)
except Exception:
# NOTE: deprecated legacy logic for backward compatibility
metadata, node_info, relationships = legacy_metadata_dict_to_node(
metadata
)
node = TextNode(
text=chunk,
id_=node_id,
metadata=metadata,
start_char_idx=node_info.get("start", None),
end_char_idx=node_info.get("end", None),
relationships=relationships,
)
logger.debug(f"Retrieved node id {node_id} with node data of {node}")
id_result.append(node_id)
node_result.append(node)
score_result.append(score)
logger.debug(
f"Search query '{search_query}' returned {len(id_result)} results."
)
return VectorStoreQueryResult(
nodes=node_result, similarities=score_result, ids=id_result
)
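# Backwards-compatible alias retained from when the service was named
# Azure Cognitive Search.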
CognitiveSearchVectorStore = AzureAISearchVectorStore
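# A hedged sketch of issuing a hybrid query against a store created as in the
# example above; the embedding below is a placeholder and would normally come
# from an embedding model matching the index dimensionality (1536 here).
def _example_hybrid_query(store: AzureAISearchVectorStore) -> VectorStoreQueryResult:
    query = VectorStoreQuery(
        query_str="What is retrieval augmented generation?",
        query_embedding=[0.0] * 1536,  # placeholder vector, not a real embedding
        similarity_top_k=3,
        mode=VectorStoreQueryMode.HYBRID,
    )
    return store.query(query)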