This commit is contained in:
parent
696116b78c
commit
2bf8c7c4d3
|
|
@ -27,5 +27,9 @@ class BGEEmbedding:
|
|||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||
|
||||
def get_agg_embedding_from_queries(self, queries):
|
||||
embeddings = self.encode(queries)
|
||||
return np.mean(embeddings, axis=0) # 聚合多个嵌入
|
||||
|
||||
embedder = BGEEmbedding()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue