750 lines
28 KiB
Python
750 lines
28 KiB
Python
"""Azure AI Search vector store."""
|
|
import enum
|
|
import json
|
|
import logging
|
|
from enum import auto
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
|
|
|
from llama_index.schema import BaseNode, MetadataMode, TextNode
|
|
from llama_index.vector_stores.types import (
|
|
ExactMatchFilter,
|
|
MetadataFilters,
|
|
VectorStore,
|
|
VectorStoreQuery,
|
|
VectorStoreQueryMode,
|
|
VectorStoreQueryResult,
|
|
)
|
|
from llama_index.vector_stores.utils import (
|
|
legacy_metadata_dict_to_node,
|
|
metadata_dict_to_node,
|
|
node_to_metadata_dict,
|
|
)
|
|
|
|
# Module-level logger named after this module; used by every class below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MetadataIndexFieldType(int, enum.Enum):
    """Data types a filterable metadata field may take in the search index.

    Each member stands for one of the primitive value types that can appear
    in a flat metadata dictionary; the trailing comment on each member names
    the index field type it is translated to when the index is created.
    """

    STRING = enum.auto()  # "Edm.String"
    BOOLEAN = enum.auto()  # "Edm.Boolean"
    INT32 = enum.auto()  # "Edm.Int32"
    INT64 = enum.auto()  # "Edm.Int64"
    DOUBLE = enum.auto()  # "Edm.Double"
|
|
|
|
|
|
class IndexManagement(int, enum.Enum):
    """Index management behaviours selectable when building the vector store."""

    # Use the index as-is without checking that it exists.
    NO_VALIDATION = enum.auto()
    # Fail if the named index does not exist.
    VALIDATE_INDEX = enum.auto()
    # Create the named index when it is missing.
    CREATE_IF_NOT_EXISTS = enum.auto()
|
|
|
|
|
|
class AzureAISearchVectorStore(VectorStore):
    """Vector store that keeps node text, embeddings and metadata in an Azure AI Search index.

    Nodes are written with ``merge_or_upload_documents`` in small batches and
    read back through one of the ``AzureQueryResultSearch*`` helpers, chosen
    from the query mode.
    """

    # Node text is stored in the index next to the embedding.
    stores_text: bool = True
    # Metadata is flattened before being serialised into the index document.
    flat_metadata: bool = True

    def _normalise_metadata_to_index_fields(
        self,
        filterable_metadata_field_keys: Union[
            List[str],
            Dict[str, str],
            Dict[str, Tuple[str, MetadataIndexFieldType]],
            None,
        ] = None,
    ) -> Dict[str, Tuple[str, MetadataIndexFieldType]]:
        """Normalise the filterable-field specification into a single shape.

        Args:
            filterable_metadata_field_keys: One of
                - a list of metadata keys (index field uses the same name,
                  typed as string),
                - a dict of metadata key -> index field name (typed as string),
                - a dict of metadata key -> (index field name, field type),
                - ``None`` for no filterable fields.

        Returns:
            Dict mapping each metadata key to ``(index field name, field type)``.
            Any other input type yields an empty dict.
        """
        index_field_spec: Dict[str, Tuple[str, MetadataIndexFieldType]] = {}

        if isinstance(filterable_metadata_field_keys, list):
            for field in filterable_metadata_field_keys:
                # Index field name and the metadata field name are the same;
                # string is the default index field type.
                index_field_spec[field] = (field, MetadataIndexFieldType.STRING)

        elif isinstance(filterable_metadata_field_keys, dict):
            for k, v in filterable_metadata_field_keys.items():
                if isinstance(v, tuple):
                    # Name and type are used exactly as supplied.
                    index_field_spec[k] = v
                else:
                    # Only the index field name was supplied; default to string.
                    index_field_spec[k] = (v, MetadataIndexFieldType.STRING)

        return index_field_spec

    def _create_index_if_not_exists(self, index_name: str) -> None:
        """Create ``index_name`` when the index client does not already list it."""
        if index_name not in self._index_client.list_index_names():
            logger.info(
                f"Index {index_name} does not exist in Azure AI Search, creating index"
            )
            self._create_index(index_name)

    def _create_metadata_index_fields(self) -> List[Any]:
        """Create a list of index fields for storing metadata values.

        Returns:
            One filterable ``SimpleField`` per entry of the metadata-to-index
            field map.

        Raises:
            ValueError: If an entry carries a field type that is not a
                ``MetadataIndexFieldType``.
        """
        from azure.search.documents.indexes.models import SimpleField

        edm_types = {
            MetadataIndexFieldType.STRING: "Edm.String",
            MetadataIndexFieldType.BOOLEAN: "Edm.Boolean",
            MetadataIndexFieldType.INT32: "Edm.Int32",
            MetadataIndexFieldType.INT64: "Edm.Int64",
            MetadataIndexFieldType.DOUBLE: "Edm.Double",
        }

        index_fields = []

        for field_name, field_type in self._metadata_to_index_field_map.values():
            # Previously an unknown type left the Edm type unbound (or reused
            # the one from the previous field); fail loudly instead.
            try:
                index_field_type = edm_types[MetadataIndexFieldType(field_type)]
            except (ValueError, KeyError) as err:
                raise ValueError(
                    f"Unsupported metadata index field type {field_type!r} "
                    f"for index field '{field_name}'"
                ) from err

            index_fields.append(
                SimpleField(name=field_name, type=index_field_type, filterable=True)
            )

        return index_fields

    def _create_index(self, index_name: Optional[str]) -> None:
        """
        Creates a default index based on the supplied index name, key field names and
        metadata filtering keys.

        The index holds the id, chunk, embedding, metadata and doc_id fields,
        one filterable field per configured metadata key, two vector search
        algorithm configurations (HNSW and exhaustive KNN) with a profile each,
        and a semantic configuration that prioritises the chunk field.
        """
        from azure.search.documents.indexes.models import (
            ExhaustiveKnnAlgorithmConfiguration,
            ExhaustiveKnnParameters,
            HnswAlgorithmConfiguration,
            HnswParameters,
            SearchableField,
            SearchField,
            SearchFieldDataType,
            SearchIndex,
            SemanticConfiguration,
            SemanticField,
            SemanticPrioritizedFields,
            SemanticSearch,
            SimpleField,
            VectorSearch,
            VectorSearchAlgorithmKind,
            VectorSearchAlgorithmMetric,
            VectorSearchProfile,
        )

        logger.info(f"Configuring {index_name} fields for Azure AI Search")
        fields = [
            SimpleField(name=self._field_mapping["id"], type="Edm.String", key=True),
            SearchableField(
                name=self._field_mapping["chunk"],
                type="Edm.String",
                analyzer_name="en.microsoft",
            ),
            SearchField(
                name=self._field_mapping["embedding"],
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=self.embedding_dimensionality,
                # Must name one of the profiles declared in `vector_search`
                # below; the former value "default" matched neither of them.
                vector_search_profile_name="myHnswProfile",
            ),
            SimpleField(name=self._field_mapping["metadata"], type="Edm.String"),
            SimpleField(
                name=self._field_mapping["doc_id"], type="Edm.String", filterable=True
            ),
        ]
        logger.info(f"Configuring {index_name} metadata fields")
        fields.extend(self._create_metadata_index_fields())

        logger.info(f"Configuring {index_name} vector search")
        # Configure the vector search algorithms and profiles
        vector_search = VectorSearch(
            algorithms=[
                HnswAlgorithmConfiguration(
                    name="myHnsw",
                    kind=VectorSearchAlgorithmKind.HNSW,
                    # For more information on HNSW parameters, visit
                    # https://learn.microsoft.com/azure/search/vector-search-ranking#creating-the-hnsw-graph
                    parameters=HnswParameters(
                        m=4,
                        ef_construction=400,
                        ef_search=500,
                        metric=VectorSearchAlgorithmMetric.COSINE,
                    ),
                ),
                ExhaustiveKnnAlgorithmConfiguration(
                    name="myExhaustiveKnn",
                    kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
                    parameters=ExhaustiveKnnParameters(
                        metric=VectorSearchAlgorithmMetric.COSINE,
                    ),
                ),
            ],
            profiles=[
                VectorSearchProfile(
                    name="myHnswProfile",
                    algorithm_configuration_name="myHnsw",
                ),
                VectorSearchProfile(
                    name="myExhaustiveKnnProfile",
                    algorithm_configuration_name="myExhaustiveKnn",
                ),
                # Add more profiles if needed
            ],
        )

        logger.info(f"Configuring {index_name} semantic search")
        semantic_config = SemanticConfiguration(
            name="mySemanticConfig",
            prioritized_fields=SemanticPrioritizedFields(
                content_fields=[SemanticField(field_name=self._field_mapping["chunk"])],
            ),
        )
        semantic_search = SemanticSearch(configurations=[semantic_config])

        index = SearchIndex(
            name=index_name,
            fields=fields,
            vector_search=vector_search,
            semantic_search=semantic_search,
        )
        logger.debug(f"Creating {index_name} search index")
        self._index_client.create_index(index)

    def _validate_index(self, index_name: Optional[str]) -> None:
        """Raise ``ValueError`` if ``index_name`` is not listed by the index client.

        Nothing is checked when there is no index client or no index name.
        """
        if self._index_client and index_name:
            if index_name not in self._index_client.list_index_names():
                raise ValueError(
                    f"Validation failed, index {index_name} does not exist."
                )

    def __init__(
        self,
        search_or_index_client: Any,
        id_field_key: str,
        chunk_field_key: str,
        embedding_field_key: str,
        metadata_string_field_key: str,
        doc_id_field_key: str,
        filterable_metadata_field_keys: Optional[
            Union[
                List[str],
                Dict[str, str],
                Dict[str, Tuple[str, MetadataIndexFieldType]],
            ]
        ] = None,
        index_name: Optional[str] = None,
        index_mapping: Optional[
            Callable[[Dict[str, str], Dict[str, Any]], Dict[str, str]]
        ] = None,
        index_management: IndexManagement = IndexManagement.NO_VALIDATION,
        embedding_dimensionality: int = 1536,
        **kwargs: Any,
    ) -> None:
        # ruff: noqa: E501
        """
        Embeddings and documents are stored in an Azure AI Search index,
        a merge or upload approach is used when adding embeddings.
        When adding multiple embeddings the index is updated by this vector store
        in batches of 10 documents, very large nodes may result in failure due to
        the batch byte size being exceeded.

        Args:
            search_or_index_client (azure.search.documents.SearchClient or
                azure.search.documents.indexes.SearchIndexClient):
                Client for the index to be populated / queried. When a
                SearchIndexClient is supplied, a SearchClient for
                `index_name` is obtained from it.
            id_field_key (str): Index field storing the id
            chunk_field_key (str): Index field storing the node text
            embedding_field_key (str): Index field storing the embedding vector
            metadata_string_field_key (str):
                Index field storing node metadata as a json string.
                Schema is arbitrary, to filter on metadata values they must be stored
                as separate fields in the index, use filterable_metadata_field_keys
                to specify the metadata values that should be stored in these filterable fields
            doc_id_field_key (str): Index field storing doc_id
            filterable_metadata_field_keys:
                Metadata keys to store in their own filterable index fields.
                Either a list of metadata keys, a dict of metadata key to
                index field name, or a dict of metadata key to
                (index field name, MetadataIndexFieldType).
            index_name (Optional[str]):
                Name of the index; required with a SearchIndexClient and
                not allowed with a SearchClient.
            index_mapping:
                Optional function with definition
                (enriched_doc: Dict[str, str], metadata: Dict[str, Any]): Dict[str,str]
                used to map document fields to the AI search index fields
                (return value of function).
                If none is specified a default mapping is provided which uses
                the field keys. The keys in the enriched_doc are
                ["id", "chunk", "embedding", "metadata", "doc_id"]
                The default mapping is:
                    - "id" to id_field_key
                    - "chunk" to chunk_field_key
                    - "embedding" to embedding_field_key
                    - "metadata" to metadata_string_field_key
                    - "doc_id" to doc_id_field_key
            index_management (IndexManagement):
                Whether to skip validation, validate that the index exists,
                or create it if it does not exist.
            embedding_dimensionality (int):
                Vector dimensions used for the embedding field when the
                index is created.
            **kwargs (Any): Additional keyword arguments (unused here).

        Raises:
            ImportError: Unable to import `azure-search-documents`
            ValueError: If `search_or_index_client` is not provided
            ValueError: If `index_name` is not provided and `search_or_index_client`
                is of type azure.search.documents.SearchIndexClient
            ValueError: If `index_name` is provided and `search_or_index_client`
                is of type azure.search.documents.SearchClient
            ValueError: If `index_management` is
                IndexManagement.CREATE_IF_NOT_EXISTS and
                `search_or_index_client` is of type azure.search.documents.SearchClient
        """
        import_err_msg = (
            "`azure-search-documents` package not found, please run "
            "`pip install azure-search-documents==11.4.0`"
        )

        try:
            import azure.search.documents  # noqa
            from azure.search.documents import SearchClient
            from azure.search.documents.indexes import SearchIndexClient
        except ImportError as err:
            raise ImportError(import_err_msg) from err

        self._index_client: SearchIndexClient = cast(SearchIndexClient, None)
        self._search_client: SearchClient = cast(SearchClient, None)
        self.embedding_dimensionality = embedding_dimensionality

        # Validate search_or_index_client
        if search_or_index_client is None:
            raise ValueError("search_or_index_client not specified")

        if isinstance(search_or_index_client, SearchIndexClient):
            # If SearchIndexClient is supplied so must index_name
            self._index_client = cast(SearchIndexClient, search_or_index_client)

            if not index_name:
                raise ValueError(
                    "index_name must be supplied if search_or_index_client is of "
                    "type azure.search.documents.SearchIndexClient"
                )

            self._search_client = self._index_client.get_search_client(
                index_name=index_name
            )

        elif isinstance(search_or_index_client, SearchClient):
            self._search_client = cast(SearchClient, search_or_index_client)

            # A SearchClient is already bound to an index
            if index_name:
                raise ValueError(
                    "index_name cannot be supplied if search_or_index_client "
                    "is of type azure.search.documents.SearchClient"
                )

        if not self._index_client and not self._search_client:
            raise ValueError(
                "search_or_index_client must be of type "
                "azure.search.documents.SearchClient or "
                "azure.search.documents.SearchIndexClient"
            )

        if (
            index_management == IndexManagement.CREATE_IF_NOT_EXISTS
            and not self._index_client
        ):
            raise ValueError(
                "index_management has value of IndexManagement.CREATE_IF_NOT_EXISTS "
                "but search_or_index_client is not of type "
                "azure.search.documents.SearchIndexClient"
            )

        self._index_management = index_management

        # Default field mapping: enriched document key -> index field name
        self._field_mapping = {
            "id": id_field_key,
            "chunk": chunk_field_key,
            "embedding": embedding_field_key,
            "metadata": metadata_string_field_key,
            "doc_id": doc_id_field_key,
        }

        self._index_mapping = (
            self._default_index_mapping if index_mapping is None else index_mapping
        )

        self._metadata_to_index_field_map = self._normalise_metadata_to_index_fields(
            filterable_metadata_field_keys
        )

        if self._index_management == IndexManagement.CREATE_IF_NOT_EXISTS:
            if index_name:
                self._create_index_if_not_exists(index_name)

        if self._index_management == IndexManagement.VALIDATE_INDEX:
            self._validate_index(index_name)

    @property
    def client(self) -> Any:
        """Get client."""
        return self._search_client

    def _default_index_mapping(
        self, enriched_doc: Dict[str, str], metadata: Dict[str, Any]
    ) -> Dict[str, str]:
        """Map an enriched document onto index field names.

        Args:
            enriched_doc: Document keyed by the logical names of the field
                mapping ("id", "chunk", "embedding", "metadata", "doc_id").
            metadata: Node metadata; values for configured filterable keys
                are copied into their own index fields.

        Returns:
            Document keyed by index field names.
        """
        index_doc: Dict[str, str] = {}

        for field, index_field in self._field_mapping.items():
            index_doc[index_field] = enriched_doc[field]

        for metadata_field_name, (
            index_field_name,
            _,
        ) in self._metadata_to_index_field_map.items():
            metadata_value = metadata.get(metadata_field_name)
            # Only skip absent values: a truthiness test would also drop
            # legitimate values such as 0, False and "".
            if metadata_value is not None:
                index_doc[index_field_name] = metadata_value

        return index_doc

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index associated with the configured search client.

        Args:
            nodes: List[BaseNode]: nodes with embeddings

        Returns:
            The node ids, in the order the nodes were supplied.

        Raises:
            ValueError: If the search client is not initialized.
        """
        if not self._search_client:
            raise ValueError("Search client not initialized")

        # Documents are sent in batches of this many.
        batch_size = 10

        documents: List[Dict[str, Any]] = []
        ids: List[str] = []

        for node in nodes:
            logger.debug(f"Processing embedding: {node.node_id}")
            ids.append(node.node_id)
            documents.append(self._create_index_document(node))

            if len(documents) >= batch_size:
                logger.info(
                    f"Uploading batch of size {len(documents)}, "
                    f"current progress {len(ids)} of {len(nodes)}"
                )
                self._search_client.merge_or_upload_documents(documents)
                documents = []

        # Upload remaining batch of less than `batch_size` documents
        if documents:
            logger.info(
                f"Uploading remaining batch of size {len(documents)}, "
                f"current progress {len(ids)} of {len(nodes)}"
            )
            self._search_client.merge_or_upload_documents(documents)

        return ids

    def _create_index_document(self, node: BaseNode) -> Dict[str, Any]:
        """Create AI Search index document from embedding result."""
        node_metadata = node_to_metadata_dict(
            node,
            remove_text=True,
            flat_metadata=self.flat_metadata,
        )

        doc: Dict[str, Any] = {
            "id": node.node_id,
            "chunk": node.get_content(metadata_mode=MetadataMode.NONE) or "",
            "embedding": node.get_embedding(),
            "doc_id": node.ref_doc_id,
            # Full metadata is kept as a JSON string in a single field
            "metadata": json.dumps(node_metadata),
        }

        return self._index_mapping(doc, node_metadata)

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete documents from the AI Search Index
        with doc_id_field_key field equal to ref_doc_id.
        """
        # OData string literals escape a single quote by doubling it, the same
        # rule `_create_odata_filter` applies to string filter values.
        escaped_doc_id = ref_doc_id.replace("'", "''")
        odata_filter = f"{self._field_mapping['doc_id']} eq '{escaped_doc_id}'"

        # Locate documents to delete
        logger.debug(f"Searching with filter {odata_filter}")
        results = self._search_client.search(search_text="*", filter=odata_filter)

        # The delete payload is keyed by the configured key field name, which
        # is not necessarily the literal "id".
        id_field = self._field_mapping["id"]

        docs_to_delete = []
        for result in results:
            doc = {id_field: result[id_field]}
            logger.debug(f"Found document to delete: {doc}")
            docs_to_delete.append(doc)

        if docs_to_delete:
            logger.debug(f"Deleting {len(docs_to_delete)} documents")
            self._search_client.delete_documents(docs_to_delete)

    def _create_odata_filter(self, metadata_filters: MetadataFilters) -> str:
        """Generate an OData filter string using supplied metadata filters.

        Every filter becomes an ``eq`` clause on its mapped index field and
        the clauses are joined with ``and``.

        Raises:
            NotImplementedError: If a filter is not an ``ExactMatchFilter``.
            ValueError: If a filter key has no mapping to an index field.
        """
        clauses: List[str] = []
        for f in metadata_filters.legacy_filters():
            if not isinstance(f, ExactMatchFilter):
                raise NotImplementedError(
                    "Only `ExactMatchFilter` filters are supported"
                )

            # Raise error if filtering on a metadata field that lacks a mapping to
            # an index field
            metadata_mapping = self._metadata_to_index_field_map.get(f.key)

            if not metadata_mapping:
                raise ValueError(
                    f"Metadata field '{f.key}' is missing a mapping to an index field, "
                    "provide entry in 'filterable_metadata_field_keys' for this "
                    "vector store"
                )

            index_field = metadata_mapping[0]

            if isinstance(f.value, str):
                # Single quotes inside a string literal are escaped by doubling
                escaped_value = f.value.replace("'", "''")
                clauses.append(f"{index_field} eq '{escaped_value}'")
            elif isinstance(f.value, bool):
                # OData boolean literals are lowercase, unlike Python's repr
                clauses.append(f"{index_field} eq {str(f.value).lower()}")
            else:
                clauses.append(f"{index_field} eq {f.value}")

        odata_expr = " and ".join(clauses)

        logger.info(f"Odata filter: {odata_expr}")

        return odata_expr

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Run ``query`` against the index using the search strategy for its mode.

        SPARSE, HYBRID and SEMANTIC_HYBRID modes use their dedicated search
        classes; every other mode uses the default vector search.
        """
        odata_filter = None
        if query.filters is not None:
            # An empty expression carries no constraint, so send no filter at all
            odata_filter = self._create_odata_filter(query.filters) or None

        if query.mode == VectorStoreQueryMode.SPARSE:
            search_cls = AzureQueryResultSearchSparse
        elif query.mode == VectorStoreQueryMode.HYBRID:
            search_cls = AzureQueryResultSearchHybrid
        elif query.mode == VectorStoreQueryMode.SEMANTIC_HYBRID:
            search_cls = AzureQueryResultSearchSemanticHybrid
        else:
            search_cls = AzureQueryResultSearchDefault

        azure_query_result_search: AzureQueryResultSearchBase = search_cls(
            query, self._field_mapping, odata_filter, self._search_client
        )
        return azure_query_result_search.search()
|
|
|
|
|
|
class AzureQueryResultSearchBase:
    """Template for running one query against the search client.

    ``search`` asks the subclass hooks for the search text and the vector
    queries, issues the search and converts every hit into a node. This base
    searches for ``"*"`` with no vector queries.
    """

    def __init__(
        self,
        query: VectorStoreQuery,
        field_mapping: Dict[str, str],
        odata_filter: Optional[str],
        search_client: Any,
    ) -> None:
        """Store the query, field mapping, optional OData filter and client."""
        self._query = query
        self._field_mapping = field_mapping
        self._odata_filter = odata_filter
        self._search_client = search_client

    @property
    def _select_fields(self) -> List[str]:
        """Index field names requested back from the search call."""
        return [
            self._field_mapping[key] for key in ("id", "chunk", "metadata", "doc_id")
        ]

    def _create_search_query(self) -> str:
        """Return the search text; the base matches everything."""
        return "*"

    def _create_query_vector(self) -> Optional[List[Any]]:
        """Return the vector queries; the base supplies none."""
        return None

    def _create_query_result(
        self, search_query: str, vectors: Optional[List[Any]]
    ) -> VectorStoreQueryResult:
        """Execute the search and turn the hits into a query result.

        Each hit's metadata JSON is decoded and converted to a node; if that
        conversion raises, the legacy metadata layout is used to build a
        ``TextNode`` instead. The similarity is the hit's ``@search.score``.
        """
        hits = self._search_client.search(
            search_text=search_query,
            vector_queries=vectors,
            top=self._query.similarity_top_k,
            select=self._select_fields,
            filter=self._odata_filter,
        )

        id_field = self._field_mapping["id"]
        metadata_field = self._field_mapping["metadata"]
        chunk_field = self._field_mapping["chunk"]

        found_ids = []
        found_nodes = []
        found_scores = []
        for hit in hits:
            node_id = hit[id_field]
            metadata = json.loads(hit[metadata_field])
            score = hit["@search.score"]
            chunk = hit[chunk_field]

            try:
                node = metadata_dict_to_node(metadata)
                node.set_content(chunk)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(
                    metadata
                )
                node = TextNode(
                    text=chunk,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )

            logger.debug(f"Retrieved node id {node_id} with node data of {node}")

            found_ids.append(node_id)
            found_nodes.append(node)
            found_scores.append(score)

        logger.debug(
            f"Search query '{search_query}' returned {len(found_ids)} results."
        )

        return VectorStoreQueryResult(
            nodes=found_nodes, similarities=found_scores, ids=found_ids
        )

    def search(self) -> VectorStoreQueryResult:
        """Build the search text and vector queries, then run the search."""
        return self._create_query_result(
            self._create_search_query(), self._create_query_vector()
        )
|
|
|
|
|
|
class AzureQueryResultSearchDefault(AzureQueryResultSearchBase):
    """Pure vector search driven by the query embedding."""

    def _create_query_vector(self) -> Optional[List[Any]]:
        """Build the single vector query for the embedding field.

        Raises:
            ValueError: If the query carries no embedding.
        """
        from azure.search.documents.models import VectorizedQuery

        embedding = self._query.query_embedding
        if not embedding:
            raise ValueError("Query missing embedding")

        vector_queries = [
            VectorizedQuery(
                vector=embedding,
                k_nearest_neighbors=self._query.similarity_top_k,
                fields=self._field_mapping["embedding"],
            )
        ]
        logger.info("Vector search with supplied embedding")
        return vector_queries
|
|
|
|
|
|
class AzureQueryResultSearchSparse(AzureQueryResultSearchBase):
    """Text search driven by the query string."""

    def _create_search_query(self) -> str:
        """Return the query string as the search text.

        Raises:
            ValueError: If the query has no query string.
        """
        query_text = self._query.query_str
        if query_text is None:
            raise ValueError("Query missing query string")

        logger.info(f"Hybrid search with search text: {query_text}")
        return query_text
|
|
|
|
|
|
class AzureQueryResultSearchHybrid(
    AzureQueryResultSearchDefault, AzureQueryResultSearchSparse
):
    """Combined search: vector query from the default search, text from the sparse one."""

    def _create_query_vector(self) -> Optional[List[Any]]:
        """Delegate to the default (vector) implementation."""
        vector_queries = AzureQueryResultSearchDefault._create_query_vector(self)
        return vector_queries

    def _create_search_query(self) -> str:
        """Delegate to the sparse (text) implementation."""
        search_text = AzureQueryResultSearchSparse._create_search_query(self)
        return search_text
|
|
|
|
|
|
class AzureQueryResultSearchSemanticHybrid(AzureQueryResultSearchHybrid):
    """Hybrid search whose hits are scored by the semantic reranker."""

    def _create_query_vector(self) -> Optional[List[Any]]:
        """Build the vector query with a fixed neighbour count of 50.

        Raises:
            ValueError: If the query carries no embedding.
        """
        from azure.search.documents.models import VectorizedQuery

        embedding = self._query.query_embedding
        if not embedding:
            raise ValueError("Query missing embedding")

        # k is set to 50 to align with the number of documents accepted by the
        # Azure semantic reranking model.
        # https://learn.microsoft.com/azure/search/semantic-search-overview
        vector_queries = [
            VectorizedQuery(
                vector=embedding,
                k_nearest_neighbors=50,
                fields=self._field_mapping["embedding"],
            )
        ]
        logger.info("Vector search with supplied embedding")
        return vector_queries

    def _create_query_result(
        self, search_query: str, vector_queries: Optional[List[Any]]
    ) -> VectorStoreQueryResult:
        """Execute a semantic search and turn the hits into a query result.

        The search is issued with ``query_type="semantic"`` and the
        ``mySemanticConfig`` configuration. Each hit's metadata JSON is
        converted to a node, falling back to the legacy metadata layout if
        that conversion raises. The similarity is ``@search.reranker_score``.
        """
        hits = self._search_client.search(
            search_text=search_query,
            vector_queries=vector_queries,
            top=self._query.similarity_top_k,
            select=self._select_fields,
            filter=self._odata_filter,
            query_type="semantic",
            semantic_configuration_name="mySemanticConfig",
        )

        id_field = self._field_mapping["id"]
        metadata_field = self._field_mapping["metadata"]
        chunk_field = self._field_mapping["chunk"]

        found_ids = []
        found_nodes = []
        found_scores = []
        for hit in hits:
            node_id = hit[id_field]
            metadata = json.loads(hit[metadata_field])
            # use reranker_score instead of score
            score = hit["@search.reranker_score"]
            chunk = hit[chunk_field]

            try:
                node = metadata_dict_to_node(metadata)
                node.set_content(chunk)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(
                    metadata
                )
                node = TextNode(
                    text=chunk,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )

            logger.debug(f"Retrieved node id {node_id} with node data of {node}")

            found_ids.append(node_id)
            found_nodes.append(node)
            found_scores.append(score)

        logger.debug(
            f"Search query '{search_query}' returned {len(found_ids)} results."
        )

        return VectorStoreQueryResult(
            nodes=found_nodes, similarities=found_scores, ids=found_ids
        )
|
|
|
|
|
|
# Second name bound to the same class; presumably kept so code importing the
# older "CognitiveSearch" name keeps working (confirm against callers).
CognitiveSearchVectorStore = AzureAISearchVectorStore
|