2 changes: 1 addition & 1 deletion FlagEmbedding/__init__.py
@@ -1,2 +1,2 @@
from .abc.inference import *
from .inference import *
from .evaluation import *
8 changes: 4 additions & 4 deletions FlagEmbedding/abc/evaluation/runner.py
@@ -3,7 +3,7 @@
import logging
from typing import List, Union, Tuple

from FlagEmbedding import FlagAutoModel, FlagAutoReranker
from FlagEmbedding import FlagAutoModel, FlagAutoReranker, AbsEmbedder, AbsReranker

from .arguments import AbsEvalArgs, AbsEvalModelArgs
from .evaluator import AbsEvaluator
@@ -34,15 +34,15 @@ def __init__(
self.evaluator = self.load_evaluator()

@staticmethod
def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]:
def get_models(model_args: AbsEvalModelArgs) -> Tuple[AbsEmbedder, Union[AbsReranker, None]]:
"""Get the embedding and reranker model

Args:
model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments.

Returns:
Tuple[FlagAutoModel, Union[FlagAutoReranker, None]]: A :class:FlagAutoModel object of embedding model, and
:class:FlagAutoReranker object of reranker model if path provided.
Tuple[AbsEmbedder, Union[AbsReranker, None]]: A :class:AbsEmbedder object of embedding model, and
:class:AbsReranker object of reranker model if path provided.
"""
embedder = FlagAutoModel.from_finetuned(
model_name_or_path=model_args.embedder_name_or_path,
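The hunk above widens the return annotation of `get_models` from the concrete auto-classes to the abstract bases. A minimal sketch of the idea (all class names below are hypothetical stand-ins, not from the repository): a factory such as `FlagAutoModel.from_finetuned` may hand back any concrete embedder subclass, so the abstract types are what actually describe its contract.

```python
# Sketch only: hypothetical classes illustrating why the abstract bases are the
# natural return annotation for a model factory.
from typing import Optional, Tuple


class AbsEmbedder: ...                        # abstract base (stand-in)
class AbsReranker: ...                        # abstract base (stand-in)
class SomeConcreteEmbedder(AbsEmbedder): ...  # whatever the factory decides to build


def get_models(reranker_path: Optional[str]) -> Tuple[AbsEmbedder, Optional[AbsReranker]]:
    embedder: AbsEmbedder = SomeConcreteEmbedder()   # any subclass satisfies the annotation
    reranker: Optional[AbsReranker] = AbsReranker() if reranker_path else None
    return embedder, reranker
```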
8 changes: 4 additions & 4 deletions FlagEmbedding/abc/finetune/embedder/AbsModeling.py
@@ -2,7 +2,7 @@
from torch import nn, Tensor
import torch.nn.functional as F
import torch.distributed as dist
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer
from transformers.file_utils import ModelOutput

import logging
@@ -29,7 +29,7 @@ class AbsEmbedderModel(ABC, nn.Module):

Args:
base_model: The base model to train on.
tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
tokenizer (PreTrainedTokenizer, optional): The tokenizer to use. Defaults to ``None``.
negatives_cross_device (bool, optional): If True, will compute cross devices negative loss. Defaults to ``False``.
temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
@@ -39,13 +39,13 @@ class AbsEmbedderModel(ABC, nn.Module):
def __init__(
self,
base_model,
tokenizer: AutoTokenizer = None,
tokenizer: PreTrainedTokenizer = None,
negatives_cross_device: bool = False,
temperature: float = 1.0,
sub_batch_size: int = -1,
kd_loss_type: str = 'kl_div',
):
super().__init__()
nn.Module.__init__(self)
self.model = base_model
self.tokenizer = tokenizer

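The constructor change above replaces `super().__init__()` with an explicit `nn.Module.__init__(self)`. A minimal sketch of the underlying constraint, with a hypothetical class name: on a class inheriting from both `ABC` and `nn.Module`, the module registries must exist before any submodule is assigned, and the explicit call makes that guarantee independent of how the MRO resolves `super().__init__()`.

```python
# Sketch only: hypothetical class showing that nn.Module.__init__ must run before
# a submodule such as self.model is assigned on an ABC + nn.Module subclass.
from abc import ABC

from torch import nn


class HypotheticalEmbedderModel(ABC, nn.Module):
    def __init__(self, base_model: nn.Module):
        nn.Module.__init__(self)  # creates the _parameters/_modules registries
        self.model = base_model   # assigning this first would raise
                                  # "cannot assign module before Module.__init__() call"


model = HypotheticalEmbedderModel(nn.Linear(4, 2))
print(sum(p.numel() for p in model.parameters()))  # 10 -> the submodule was registered
```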
8 changes: 4 additions & 4 deletions FlagEmbedding/abc/finetune/reranker/AbsModeling.py
@@ -1,6 +1,6 @@
import torch
from torch import nn, Tensor
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer
from transformers.file_utils import ModelOutput

import logging
@@ -22,16 +22,16 @@ class AbsRerankerModel(ABC, nn.Module):

Args:
base_model: The base model to train on.
tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
tokenizer (PreTrainedTokenizer, optional): The tokenizer to use. Defaults to ``None``.
train_batch_size (int, optional): Batch size used for training. Defaults to ``4``.
"""
def __init__(
self,
base_model: None,
tokenizer: AutoTokenizer = None,
tokenizer: PreTrainedTokenizer = None,
train_batch_size: int = 4,
):
super().__init__()
nn.Module.__init__(self)
self.model = base_model
self.tokenizer = tokenizer
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
10 changes: 5 additions & 5 deletions FlagEmbedding/finetune/embedder/decoder_only/base/modeling.py
@@ -1,7 +1,7 @@
import logging

import torch
from transformers import AutoModel, AutoTokenizer
from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer

from FlagEmbedding.abc.finetune.embedder import AbsEmbedderModel

@@ -12,8 +12,8 @@ class BiDecoderOnlyEmbedderModel(AbsEmbedderModel):
"""Embedder model class for decoder only model.

Args:
base_model (AutoModel): The base model to train on.
tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
base_model (PreTrainedModel): The base model to train on.
tokenizer (PreTrainedTokenizer, optional): The tokenizer to use. Defaults to ``None``.
negatives_cross_device (bool, optional): If True, will compute cross devices negative loss. Defaults to ``False``.
temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
@@ -26,8 +26,8 @@ class BiDecoderOnlyEmbedderModel(AbsEmbedderModel):

def __init__(
self,
base_model: AutoModel,
tokenizer: AutoTokenizer = None,
base_model: PreTrainedModel,
tokenizer: PreTrainedTokenizer = None,
negatives_cross_device: bool = False,
temperature: float = 1.0,
sub_batch_size: int = -1,
10 changes: 5 additions & 5 deletions FlagEmbedding/finetune/embedder/decoder_only/icl/modeling.py
@@ -1,7 +1,7 @@
import logging

import torch
from transformers import AutoModel, AutoTokenizer
from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer

from FlagEmbedding.abc.finetune.embedder import AbsEmbedderModel

@@ -12,8 +12,8 @@ class BiDecoderOnlyEmbedderICLModel(AbsEmbedderModel):
"""Embedder model class for decoder only model.

Args:
base_model (AutoModel): The base model to train on.
tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
base_model (PreTrainedModel): The base model to train on.
tokenizer (PreTrainedTokenizer, optional): The tokenizer to use. Defaults to ``None``.
negatives_cross_device (bool, optional): If True, will compute cross devices negative loss. Defaults to ``False``.
temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
@@ -26,8 +26,8 @@ class BiDecoderOnlyEmbedderICLModel(AbsEmbedderModel):

def __init__(
self,
base_model: AutoModel,
tokenizer: AutoTokenizer = None,
base_model: PreTrainedModel,
tokenizer: PreTrainedTokenizer = None,
negatives_cross_device: bool = False,
temperature: float = 1.0,
sub_batch_size: int = -1,
10 changes: 5 additions & 5 deletions FlagEmbedding/finetune/embedder/encoder_only/base/modeling.py
@@ -1,7 +1,7 @@
import logging

import torch
from transformers import AutoModel, AutoTokenizer
from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer

from FlagEmbedding.abc.finetune.embedder import AbsEmbedderModel

@@ -12,8 +12,8 @@ class BiEncoderOnlyEmbedderModel(AbsEmbedderModel):
"""Embedder class for encoder only model.

Args:
base_model (AutoModel): The base model to train on.
tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
base_model (PreTrainedModel): The base model to train on.
tokenizer (PreTrainedTokenizer, optional): The tokenizer to use. Defaults to ``None``.
negatives_cross_device (bool, optional): If True, will compute cross devices negative loss. Defaults to ``False``.
temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
@@ -26,8 +26,8 @@ class BiEncoderOnlyEmbedderModel(AbsEmbedderModel):

def __init__(
self,
base_model: AutoModel,
tokenizer: AutoTokenizer = None,
base_model: PreTrainedModel,
tokenizer: PreTrainedTokenizer = None,
negatives_cross_device: bool = False,
temperature: float = 1.0,
sub_batch_size: int = -1,
40 changes: 28 additions & 12 deletions FlagEmbedding/finetune/embedder/encoder_only/m3/modeling.py
@@ -5,7 +5,7 @@
import torch
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer

from FlagEmbedding.abc.finetune.embedder import AbsEmbedderModel, EmbedderOutput

@@ -16,8 +16,8 @@ class EncoderOnlyEmbedderM3Model(AbsEmbedderModel):
"""Embedder class for M3 model.

Args:
base_model (AutoModel): The base model to train on.
tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
base_model (dict[str, Any]): The base model to train on.
tokenizer (PreTrainedTokenizer, optional): The tokenizer to use. Defaults to ``None``.
negatives_cross_device (bool, optional): If True, will compute cross devices negative loss. Defaults to ``False``.
temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
@@ -32,7 +32,7 @@ class EncoderOnlyEmbedderM3Model(AbsEmbedderModel):
def __init__(
self,
base_model: Dict[str, Any],
tokenizer: AutoTokenizer = None,
tokenizer: PreTrainedTokenizer = None,
negatives_cross_device: bool = False,
temperature: float = 1,
sub_batch_size: int = -1,
@@ -122,14 +122,26 @@ def _sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = Tr
token_weights = torch.relu(self.sparse_linear(hidden_state))
if not return_embedding: return token_weights

sparse_embedding = torch.zeros(
input_ids.size(0), self.vocab_size,
dtype=token_weights.dtype,
device=token_weights.device
)
sparse_embedding = sparse_embedding.scatter_reduce(
dim=-1, index=input_ids, src=token_weights.squeeze(-1), reduce="amax"
)
if self.training:
sparse_embedding = torch.zeros(
input_ids.size(0), input_ids.size(1), self.vocab_size,
dtype=token_weights.dtype,
device=token_weights.device
)
sparse_embedding = torch.scatter(sparse_embedding, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights)
sparse_embedding = torch.max(sparse_embedding, dim=1).values
else:
# Optimize suggestion from issue #1364: https://github.com/FlagOpen/FlagEmbedding/issues/1364
# Disable when self.training = True, otherwise will cause:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
sparse_embedding = torch.zeros(
input_ids.size(0), self.vocab_size,
dtype=token_weights.dtype,
device=token_weights.device
)
sparse_embedding = sparse_embedding.scatter_reduce(
dim=-1, index=input_ids, src=token_weights.squeeze(-1), reduce="amax"
)

unused_tokens = [
self.tokenizer.cls_token_id, self.tokenizer.eos_token_id,
@@ -528,6 +540,10 @@ def forward(self,
"""
assert return_dense or return_sparse or return_colbert_vecs, 'Must choose one or more from `return_colbert_vecs`, `return_sparse`, `return_dense` to set `True`!'

# this is for sparse embedding computation: using optimization suggestion from
# issue #1364: https://github.com/FlagOpen/FlagEmbedding/issues/1364
self.training = False

last_hidden_state = self.model(**text_input, return_dict=True).last_hidden_state

output = {}
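A self-contained sketch of the two branches introduced in `_sparse_embedding` above (shapes and values below are made up): the training branch scatters the token weights out-of-place into a `(batch, seq_len, vocab)` tensor and max-pools over the sequence, which keeps autograd happy, while the inference branch scatter-reduces with `"amax"` straight into a `(batch, vocab)` tensor, following the suggestion in issue #1364. Both yield the per-vocabulary-id maximum token weight.

```python
# Sketch only: checks that the training path (out-of-place scatter + max) and the
# inference path (in-place scatter_reduce with "amax") agree on toy data.
import torch

batch, seq_len, vocab_size = 2, 4, 10
token_weights = torch.rand(batch, seq_len, 1)               # stands in for relu(sparse_linear(hidden_state))
input_ids = torch.randint(0, vocab_size, (batch, seq_len))

# Training branch: (batch, seq_len, vocab) scatter, then max over the sequence dim.
dense_3d = torch.zeros(batch, seq_len, vocab_size, dtype=token_weights.dtype)
dense_3d = torch.scatter(dense_3d, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights)
train_path = torch.max(dense_3d, dim=1).values

# Inference branch: scatter_reduce("amax") directly into (batch, vocab); far less memory
# for a large vocabulary, but its in-place update is unsafe under autograd.
infer_path = torch.zeros(batch, vocab_size, dtype=token_weights.dtype)
infer_path = infer_path.scatter_reduce(
    dim=-1, index=input_ids, src=token_weights.squeeze(-1), reduce="amax"
)

assert torch.allclose(train_path, infer_path)
```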
2 changes: 1 addition & 1 deletion Tutorials/1_Embedding/1.2.2_Auto_Embedder.ipynb
@@ -56,7 +56,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Then use the model exactly same to `FlagModel` (`FlagM3Model` if using BGE M3, `FlagLLMModel` if using BGE Multilingual Gemma2, `FlagICLModel` if using BGE ICL)"
"Then use the model exactly same to `FlagModel` (`BGEM3FlagModel` if using BGE M3, `FlagLLMModel` if using BGE Multilingual Gemma2, `FlagICLModel` if using BGE ICL)"
]
},
{
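A short usage sketch matching the corrected tutorial text: BGE M3 is loaded through `BGEM3FlagModel` rather than the old `FlagM3Model` name. The keyword names below mirror the `return_dense`/`return_sparse`/`return_colbert_vecs` flags visible in the diff; treat the exact output keys as assumptions and check them against your installed FlagEmbedding version.

```python
# Sketch only: BGEM3FlagModel usage as referenced by the corrected notebook cell.
from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)
sentences = [
    "What is BGE M3?",
    "BGE M3 is a multi-functional embedding model supporting dense, sparse and multi-vector retrieval.",
]
output = model.encode(
    sentences,
    return_dense=True,
    return_sparse=True,
    return_colbert_vecs=True,
)
print(output["dense_vecs"].shape)    # dense embeddings, one row per sentence
print(output["lexical_weights"][0])  # sparse (lexical) weights for the first sentence
```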
30 changes: 24 additions & 6 deletions research/BGE_M3/modeling.py
@@ -103,12 +103,26 @@ def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = Tru
token_weights = torch.relu(self.sparse_linear(hidden_state))
if not return_embedding: return token_weights

sparse_embedding = torch.zeros(input_ids.size(0), self.vocab_size,
dtype=token_weights.dtype,
device=token_weights.device)
sparse_embedding = sparse_embedding.scatter_reduce(
dim=-1, index=input_ids, src=token_weights.squeeze(-1), reduce="amax"
)
if self.training:
sparse_embedding = torch.zeros(
input_ids.size(0), input_ids.size(1), self.vocab_size,
dtype=token_weights.dtype,
device=token_weights.device
)
sparse_embedding = torch.scatter(sparse_embedding, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights)
sparse_embedding = torch.max(sparse_embedding, dim=1).values
else:
# Optimize suggestion from issue #1364: https://github.com/FlagOpen/FlagEmbedding/issues/1364
# Disable when self.training = True, otherwise will cause:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
sparse_embedding = torch.zeros(
input_ids.size(0), self.vocab_size,
dtype=token_weights.dtype,
device=token_weights.device
)
sparse_embedding = sparse_embedding.scatter_reduce(
dim=-1, index=input_ids, src=token_weights.squeeze(-1), reduce="amax"
)

unused_tokens = [self.tokenizer.cls_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id,
self.tokenizer.unk_token_id]
@@ -349,6 +363,10 @@ def forward(self,
return_sparse_embedding: bool = False):
assert return_dense or return_sparse or return_colbert, 'Must choose one or more from `return_colbert`, `return_sparse`, `return_dense` to set `True`!'

# this is for sparse embedding computation: using optimization suggestion from
# issue #1364: https://github.com/FlagOpen/FlagEmbedding/issues/1364
self.training = False

last_hidden_state = self.model(**text_input, return_dict=True).last_hidden_state

output = {}