This commit is contained in:
parent
696116b78c
commit
2bf8c7c4d3
|
|
@ -27,5 +27,9 @@ class BGEEmbedding:
|
||||||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||||
|
|
||||||
|
def get_agg_embedding_from_queries(self, queries):
|
||||||
|
embeddings = self.encode(queries)
|
||||||
|
return np.mean(embeddings, axis=0) # 聚合多个嵌入
|
||||||
|
|
||||||
embedder = BGEEmbedding()
|
embedder = BGEEmbedding()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue