Text Embeddings for Retrieval and RAG, based on Transformers
Documentation | Tutorials | 中文 (Chinese)
Open-Retrievals is an easy-to-use Python framework for state-of-the-art text embeddings, oriented toward information retrieval and LLM retrieval-augmented generation (RAG), built on PyTorch and Transformers.
- Contrastive learning enhanced embeddings
- LLM embeddings
Installation
Prerequisites
pip install transformers
pip install faiss-cpu  # faiss ships on PyPI as faiss-cpu (or faiss-gpu for CUDA builds)
pip install peft
With pip
pip install open-retrievals
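A quick smoke test (the package imports as retrievals, as in the examples below):

python -c "import retrievals"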
Usage
Use a pretrained sentence embedding model
from retrievals import AutoModelForEmbedding
sentences = ["Hello world", "How are you?"]
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path, pooling_method="mean", normalize_embeddings=True)
sentence_embeddings = model.encode(sentences, convert_to_tensor=True)
print(sentence_embeddings)
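Because normalize_embeddings=True returns unit-length vectors, cosine similarity between sentences is a plain matrix product. A minimal check on the tensor returned above:

import torch  # already installed as a dependency of transformers

# (n, d) @ (d, n) -> (n, n) cosine-similarity matrix between the sentences
similarity = sentence_embeddings @ sentence_embeddings.T
print(similarity)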
Fine-tune transformers with contrastive learning
from transformers import AutoTokenizer
from retrievals import AutoModelForEmbedding, RetrievalTrainer, TripletCollator
from retrievals.data import RetrievalDataset
# several losses are available (e.g. ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE); TripletLoss is used below
from retrievals.losses import TripletLoss
# data_args, model_args and training_args are the usual argument dataclasses
# (e.g. parsed with transformers.HfArgumentParser) and are assumed to be defined
train_dataset = RetrievalDataset(args=data_args)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

model = AutoModelForEmbedding(
    model_args.model_name_or_path,
    pooling_method="cls",
)

# get_optimizer and get_scheduler are helper functions assumed to be in scope
optimizer = get_optimizer(model, lr=5e-5, weight_decay=1e-3)
# num_train_steps ~= len(train_dataset) / batch_size * num_epochs (here: batch size 2, 1 epoch)
lr_scheduler = get_scheduler(optimizer, num_train_steps=int(len(train_dataset) / 2 * 1))
trainer = RetrievalTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=TripletCollator(tokenizer, max_length=data_args.query_max_len),
loss_fn=TripletLoss(),
)
trainer.optimizer = optimizer
trainer.scheduler = lr_scheduler
trainer.train()
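TripletCollator and TripletLoss train on (query, positive, negative) triples. In the standard triplet formulation (the library's exact implementation may differ), the loss pulls the query toward the positive passage and pushes it away from the negative one by a margin m:

loss(q, p+, p-) = max(0, d(q, p+) - d(q, p-) + m)

where d is the distance between the embeddings.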
Fine-tune an LLM for embeddings with contrastive learning
from retrievals import AutoModelForEmbedding
model = AutoModelForEmbedding(
    "mistralai/Mistral-7B-v0.1",
    pooling_method='cls',
    query_instruction='Instruct: Retrieve semantically similar text\nQuery: ',
)
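Encoding then works the same as with the smaller models. A minimal sketch, assuming encode() prepends query_instruction to each query (how the instruction is applied is not documented here):

queries = ['How to fine-tune an LLM for retrieval?']
query_embeddings = model.encode(queries, convert_to_tensor=True)
print(query_embeddings.shape)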
Search by cosine similarity/kNN
from retrievals import AutoModelForEmbedding, AutoModelForMatch
query_texts = ['A dog is chasing a car.']
passage_texts = ['A man is playing a guitar.', 'A bee is flying low.']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
query_embeddings = model.encode(query_texts, convert_to_tensor=True)
passage_embeddings = model.encode(passage_texts, convert_to_tensor=True)
matcher = AutoModelForMatch(method='cosine')
dists, indices = matcher.similarity_search(query_embeddings, passage_embeddings, top_k=1)
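A minimal lookup of the result, assuming dists and indices are (num_queries, top_k)-shaped and aligned with query_texts:

# best-matching passage for the first (and only) query
best_idx = int(indices[0][0])
print(passage_texts[best_idx], float(dists[0][0]))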
Search by Faiss
from retrievals import AutoModelForEmbedding, AutoModelForMatch
sentences = ['A dog is chasing a car.', 'A man is playing a guitar.']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
model.build_index(sentences)  # encode the corpus and build a faiss index over it

matcher = AutoModelForMatch()
results = matcher.faiss_search("He plays guitar.")
Rerank
from transformers import AutoTokenizer
from retrievals import RerankCollator, RerankModel, RerankTrainer, RerankDataset
# data_args, model_args and training_args are assumed to be defined, as above
train_dataset = RerankDataset(args=data_args)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

model = RerankModel(
    model_args.model_name_or_path,
    pooling_method="mean",
)

# get_optimizer and get_scheduler are the same helpers as in the embedding example
optimizer = get_optimizer(model, lr=5e-5, weight_decay=1e-3)
lr_scheduler = get_scheduler(optimizer, num_train_steps=int(len(train_dataset) / 2 * 1))
trainer = RerankTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=RerankCollator(tokenizer, max_length=data_args.query_max_len),
)
trainer.optimizer = optimizer
trainer.scheduler = lr_scheduler
trainer.train()
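The inference API of the trained reranker is not shown in this README. As a rough illustration of what a cross-encoder reranker does (scoring query-passage pairs jointly), here is the same idea with plain Transformers; the checkpoint below is a public cross-encoder used purely for illustration, not part of open-retrievals:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

ce_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
ce_tokenizer = AutoTokenizer.from_pretrained(ce_name)
ce_model = AutoModelForSequenceClassification.from_pretrained(ce_name)

query = "A dog is chasing a car."
passages = ["A man is playing a guitar.", "A bee is flying low."]
inputs = ce_tokenizer([query] * len(passages), passages, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = ce_model(**inputs).logits.squeeze(-1)  # higher score = more relevant
print(scores)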
References & Acknowledgements