Text Embeddings for Retrieval and RAG based on transformers

Open-Retrievals is an easy-to-use Python framework for obtaining state-of-the-art text embeddings. It targets information retrieval and LLM retrieval-augmented generation (RAG), and is built on PyTorch and Transformers.

  • Contrastive learning enhanced embeddings
  • LLM embeddings

Installation

Prerequisites

pip install transformers
pip install faiss-cpu
pip install peft

With pip

pip install open-retrievals

Usage

Use pretrained sentence embeddings

from retrievals import AutoModelForEmbedding

sentences = ["Hello world", "How are you?"]
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path, pooling_method="mean", normalize_embeddings=True)
sentence_embeddings = model.encode(sentences, convert_to_tensor=True)
print(sentence_embeddings)
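
To illustrate what the `pooling_method="mean"` and `normalize_embeddings=True` options conceptually refer to, here is a minimal pure-Python sketch of masked mean pooling followed by L2 normalization. The helper names `mean_pool` and `l2_normalize` are hypothetical and are not part of the Open-Retrievals API; the library's internal implementation may differ.

```python
import math


def mean_pool(token_vectors, attention_mask):
    """Average token vectors, ignoring padded positions (mask == 0)."""
    dim = len(token_vectors[0])
    kept = [vec for vec, m in zip(token_vectors, attention_mask) if m]
    if not kept:
        return [0.0] * dim
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


def l2_normalize(vector):
    """Scale a vector to unit length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else list(vector)


# Toy input: three token vectors, the last one is padding and must be ignored.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]

pooled = mean_pool(tokens, mask)      # average of the first two vectors only
embedding = l2_normalize(pooled)      # unit-length sentence vector
print(pooled, embedding)
```

With unit-length embeddings, the dot product of two vectors equals their cosine similarity, which is why normalization is commonly enabled for retrieval.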

Fine-tune transformers with contrastive learning

from transformers import AutoTokenizer
from retrievals import AutoModelForEmbedding, AutoModelForMatch, RetrievalTrainer, PairCollator, TripletCollator
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
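from retrievals.data import RetrievalDataset, RerankDataset

# NOTE: model_args, data_args and training_args (parsed argument objects) and the
# get_optimizer / get_scheduler helpers are not defined in this snippet;
# supply them from your own training script.
train_dataset = RetrievalDataset(args=data_args)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

model = AutoModelForEmbedding(
    model_args.model_name_or_path,
    pooling_method="cls"
)
optimizer = get_optimizer(model, lr=5e-5, weight_decay=1e-3)

# Total training steps; the expression appears to assume a batch size of 2 and 1 epoch.
lr_scheduler = get_scheduler(optimizer, num_train_steps=int(len(train_dataset) / 2 * 1))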

trainer = RetrievalTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=TripletCollator(tokenizer, max_length=data_args.query_max_len),
    loss_fn=TripletLoss(),
)
trainer.optimizer = optimizer
trainer.scheduler = lr_scheduler
trainer.train()
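
The training example above passes `TripletLoss()` as the loss function. As a rough illustration of the general triplet objective (not the library's own implementation, which may use a different distance or margin), the sketch below computes a standard triplet margin loss in pure Python. The function names and the default margin of 1.0 are assumptions made for this example.

```python
import math


def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """max(0, d(anchor, positive) - d(anchor, negative) + margin)."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


query = [0.0, 0.0]
relevant = [0.0, 1.0]     # distance 1 from the query
irrelevant = [3.0, 4.0]   # distance 5 from the query

# The relevant passage is already closer than the irrelevant one by more than the margin.
satisfied = triplet_margin_loss(query, relevant, irrelevant)
# Swapping the roles violates the margin and yields a positive loss.
violated = triplet_margin_loss(query, irrelevant, relevant)
print(satisfied, violated)
```

The loss is zero once the positive is closer to the anchor than the negative by at least the margin, so only triplets that are still confused contribute to the gradient.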

Fine-tune an LLM for embedding with contrastive learning

from retrievals import AutoModelForEmbedding

model = AutoModelForEmbedding(
    "mistralai/Mistral-7B-v0.1",
    pooling_method='cls',
    query_instruction='Instruct: Retrieve semantically similar text\nQuery: '
)

Search by cosine similarity/KNN

from retrievals import AutoModelForEmbedding, AutoModelForMatch

query_texts = ['A dog is chasing car.']
passage_texts = ['A man is playing a guitar.', 'A bee is flying low']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
query_embeddings = model.encode(query_texts, convert_to_tensor=True)
passage_embeddings = model.encode(passage_texts, convert_to_tensor=True)

matcher = AutoModelForMatch(method='cosine')
dists, indices = matcher.similarity_search(query_embeddings, passage_embeddings, top_k=1)
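
For readers who want to see what a cosine top-k search computes, the following is a minimal brute-force sketch in pure Python. It is an illustration under stated assumptions, not the code behind `similarity_search`: the function `top_k_cosine` and its return shape (scores, then indices) are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k_cosine(query, passages, top_k=1):
    """Score every passage against the query and keep the top_k best."""
    scored = [(cosine(query, p), i) for i, p in enumerate(passages)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    best = scored[:top_k]
    return [s for s, _ in best], [i for _, i in best]


query_vec = [1.0, 0.0]
passage_vecs = [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]]

scores, indices = top_k_cosine(query_vec, passage_vecs, top_k=2)
print(scores, indices)
```

Brute-force scoring is linear in the number of passages, which is fine for small collections; an approximate index such as Faiss becomes worthwhile as the collection grows.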

Search by Faiss

from retrievals import AutoModelForEmbedding, AutoModelForMatch

sentences = ['A dog is chasing car.', 'A man is playing a guitar.']
model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
model = AutoModelForEmbedding(model_name_or_path)
model.build_index(sentences)

matcher = AutoModelForMatch()
results = matcher.faiss_search("He plays guitar.")

Rerank

from transformers import AutoTokenizer
from retrievals import RerankCollator, RerankModel, RerankTrainer, RerankDataset

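# NOTE: model_args, data_args and training_args (parsed argument objects) and the
# get_optimizer / get_scheduler helpers are not defined in this snippet;
# supply them from your own training script.
train_dataset = RerankDataset(args=data_args)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

model = RerankModel(
    model_args.model_name_or_path,
    pooling_method="mean"
)
optimizer = get_optimizer(model, lr=5e-5, weight_decay=1e-3)

# Total training steps; the expression appears to assume a batch size of 2 and 1 epoch.
lr_scheduler = get_scheduler(optimizer, num_train_steps=int(len(train_dataset) / 2 * 1))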

trainer = RerankTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=RerankCollator(tokenizer, max_length=data_args.query_max_len),
)
trainer.optimizer = optimizer
trainer.scheduler = lr_scheduler
trainer.train()
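
Once a reranker is trained, it is typically applied as a second stage: a fast retriever proposes candidates, and the reranker re-scores each (query, passage) pair. The sketch below shows only that re-ordering step. `overlap_score` is a hypothetical token-overlap stand-in for a trained model's score, and `rerank` is not an Open-Retrievals function.

```python
def overlap_score(query, passage):
    """Stand-in scorer: fraction of query tokens that also appear in the passage."""
    q_tokens = set(query.lower().split())
    p_tokens = set(passage.lower().split())
    return len(q_tokens & p_tokens) / len(q_tokens) if q_tokens else 0.0


def rerank(query, candidates, score_fn, top_k=None):
    """Sort candidates by score_fn(query, passage), highest first."""
    ranked = sorted(candidates, key=lambda p: score_fn(query, p), reverse=True)
    return ranked if top_k is None else ranked[:top_k]


query = "man plays guitar"
candidates = ["a bee is flying low", "a man plays a guitar", "a dog chases a car"]

reranked = rerank(query, candidates, overlap_score, top_k=2)
print(reranked)
```

In practice `score_fn` would wrap a trained cross-encoder's prediction for the pair; the sorting logic stays the same.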
