8 changes: 8 additions & 0 deletions paddlenlp/transformers/llm_embed/modeling.py
@@ -326,6 +326,8 @@ def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
input_texts = self.preprocess_sentences_for_llara(input_texts, query_or_doc="query")
if self.model_flag == "bge-en-icl":
input_texts = self.preprocess_sentences_for_bge_en_icl(input_texts, query_or_doc="query")
if self.model_flag == "qwen3":
input_texts = self.preprocess_sentences_for_qwen3(input_texts, query_or_doc="query")

encode_results = self.encode_sentences(sentences=input_texts, model=self.query_model, tokenizer=self.tokenizer)
return encode_results
@@ -356,6 +358,8 @@ def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
if isinstance(corpus[0], dict):
input_texts = [doc["text"] for doc in corpus]
input_titles = [doc.get("title", "") for doc in corpus]
if self.model_flag == "qwen3":
input_texts = self.preprocess_sentences_for_qwen3(input_texts, query_or_doc="doc")

encode_results = self.encode_sentences(
sentences=input_texts, titles=input_titles, model=self.corpus_model, tokenizer=self.tokenizer
@@ -415,3 +419,7 @@ def preprocess_sentences_for_llara(self, sentences: List[str], query_or_doc: str
sentences_after_process = [prefix + " " + sentence + " " + suffix for sentence in sentences_after_process]

return sentences_after_process

def preprocess_sentences_for_qwen3(self, sentences: List[str], query_or_doc: str, **kwargs) -> List[str]:
sentences = [f"{sentence}<|endoftext|>" for sentence in sentences]
return sentences
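
The new `preprocess_sentences_for_qwen3` hook follows the same dispatch pattern as the existing LLARA and bge-en-icl branches: `encode_queries`/`encode_corpus` rewrite the raw texts before tokenization, and for Qwen3 the rewrite simply appends `<|endoftext|>`, the end-of-sequence marker from which last-token-pooling embedding models such as Qwen3-Embedding typically derive the sentence vector. A minimal standalone sketch of this dispatch (function names and the fallback branch are illustrative, not the class's full API):

```python
from typing import List

def preprocess_for_qwen3(sentences: List[str]) -> List[str]:
    # Append the EOS marker; the embedding is typically pooled from the
    # hidden state of this final token.
    return [f"{sentence}<|endoftext|>" for sentence in sentences]

def preprocess(sentences: List[str], model_flag: str) -> List[str]:
    # Mirrors the model_flag dispatch in encode_queries/encode_corpus;
    # the llara and bge-en-icl branches are omitted here.
    if model_flag == "qwen3":
        return preprocess_for_qwen3(sentences)
    return sentences

print(preprocess(["What is dense retrieval?"], "qwen3"))
# ['What is dense retrieval?<|endoftext|>']
```
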
68 changes: 59 additions & 9 deletions slm/pipelines/examples/contrastive_training/README.md
@@ -7,7 +7,7 @@
conda create --name paddle_env python=3.10
conda activate paddle_env

# Install the paddlenlp develop build
pip install --pre --upgrade paddlenlp -f https://www.paddlepaddle.org.cn/whl/paddlenlp.html

# Install the paddlepaddle-gpu nightly build
@@ -206,6 +206,7 @@ python -u evaluation/eval_mteb.py \
| [NV‑Embed‑v1](https://huggingface.co/nvidia/NV-Embed-v1) | 4096 |
| [BGE‑EN‑ICL](https://huggingface.co/BAAI/bge-en-icl) | 4096 |
| [LLARA‑passage](https://huggingface.co/BAAI/LLARA-passage) | 4096 |
| [Qwen3-Embedding-8B](https://huggingface.co/Qwen/Qwen3-Embedding-8B) | 32768 |

Supported configuration parameters:
- `base_model_name_or_path`: model name or path
@@ -231,6 +232,7 @@ MSMARCO-Title dataset, MRR@10, Recall@10, NDCG@10 scores:
| RocketQA v1 | 36.94 | 65.67 | 43.51 |
| RocketQA v2 | 38.88 | 67.06 | 45.28 |
| bge-large-en-v1.5 | 35.30 | 64.24 | 41.96 |
| Qwen3-Embedding-8B | 38.66 | 69.46 | 45.72 |
| repllama-v1-7b-lora-passage | 38.24 | 66.26 | 45.13 |
| NV-Embed-v1 | 38.39 | 68.02 | 45.21 |
| bge-en-icl (zero-shot) | 42.74 | 71.90 | 49.47 |
@@ -246,6 +248,7 @@ MTEB-Retrieval dataset, MRR@10 scores:
| NV‑Embed‑v1 | 65.24 | 60.28 | 45.17 | 48.14 | 80.19 | 86.78 | 69.24 | 88.36 | 39.73 | 59.40 | 66.70 | 88.35 | 34.27 | 75.17 | 42.50 | 94.33 |
| bge‑en‑icl (zero‑shot) | 69.29 | 77.83 | 57.88 | 45.69 | 82.04 | 92.50 | 65.78 | 92.76 | 39.97 | 61.84 | 69.64 | 90.22 | 41.14 | 75.13 | 56.56 | 90.33 |
| LLARA-passage | 60.11 | 38.77 | 34.58 | 36.19 | 75.50 | 81.02 | 51.72 | 86.36 | 38.81 | 57.69 | 56.85 | 80.58 | 30.15 | 73.17 | 67.20 | 93.07 |
| Qwen3-Embedding-8B | 69.79 | 70.12 | 61.52 | 52.23 | 81.29 | 93.54 | 69.72 | 89.46 | 37.44 | 61.46 | 59.63 | 88.01 | 49.32 | 74.06 | 59.09 | 100.00 |

MTEB-Retrieval dataset, Recall@10 scores:
| Model | Average | ArguAna | ClimateFEVER | CQADupstackRetrieval | DBPedia | FEVER | FiQA2018 | HotpotQA | MSMARCO | NFCorpus | NQ | QuoraRetrieval | SCIDOCS | SciFact | Touche2020 | TRECCOVID |
@@ -257,6 +260,7 @@ MTEB-Retrieval dataset, Recall@10 scores:
| NV‑Embed‑v1 | 58.78 | 93.95 | 41.07 | 64.66 | 28.67 | 95.24 | 70.62 | 85.19 | 69.15 | 18.45 | 89.16 | 95.92 | 21.27 | 90.02 | 15.94 | 2.36 |
| bge‑en‑icl (zero‑shot) | 60.62 | 97.08 | 52.19 | 60.38 | 29.81 | 96.92 | 67.42 | 88.33 | 69.53 | 20.42 | 90.96 | 97.02 | 27.33 | 91.05 | 18.81 | 2.11 |
| LLARA-passage | 52.30 | 76.17 | 32.52 | 47.91 | 26.33 | 90.48 | 51.09 | 71.16 | 67.82 | 17.67 | 81.89 | 92.54 | 18.12 | 86.80 | 21.81 | 2.23 |
| Qwen3-Embedding-8B | 60.96 | 97.51 | 52.34 | 67.80 | 28.99 | 96.01 | 71.33 | 79.05 | 67.43 | 19.95 | 84.87 | 96.16 | 34.57 | 93.50 | 22.41 | 2.50 |

MTEB-Retrieval dataset, NDCG@10 scores:
| Model | Average | ArguAna | ClimateFEVER | CQADupstackRetrieval | DBPedia | FEVER | FiQA2018 | HotpotQA | MSMARCO | NFCorpus | NQ | QuoraRetrieval | SCIDOCS | SciFact | Touche2020 | TRECCOVID |
@@ -268,6 +272,7 @@ MTEB-Retrieval dataset, NDCG@10 scores:
| NV‑Embed‑v1 | 58.86 | 68.30 | 34.37 | 50.27 | 48.29 | 86.58 | 62.90 | 79.92 | 46.48 | 37.98 | 71.22 | 89.20 | 20.16 | 78.30 | 23.98 | 84.91 |
| bge‑en‑icl (zero‑shot) | 61.62 | 82.34 | 45.33 | 47.27 | 50.60 | 91.91 | 59.13 | 84.90 | 46.78 | 40.66 | 73.85 | 91.03 | 25.46 | 77.91 | 30.71 | 76.38 |
| LLARA-passage | 52.48 | 47.51 | 26.13 | 37.26 | 44.12 | 81.09 | 43.98 | 69.17 | 45.49 | 37.07 | 61.76 | 82.29 | 17.30 | 76.07 | 36.73 | 81.30 |
| Qwen3-Embedding-8B | 62.36 | 76.63 | 47.13 | 54.01 | 48.86 | 91.82 | 62.14 | 76.28 | 44.29 | 41.31 | 64.63 | 89.03 | 32.22 | 78.48 | 34.99 | 93.54 |


## Compression
@@ -277,9 +282,9 @@ MTEB-Retrieval dataset, NDCG@10 scores:

#### Usage

Run the pruning script with the following command. You can specify the source model, the output path, the number of layers to prune, and the attribute path to the transformer layers inside the model:

Take repllama-v1-7b-lora-passage as an example:
```bash
python shortgpt_prune.py \
--model_name_or_path castorini/repllama-v1-7b-lora-passage \
--output_model_path ./pruned-repllama-v1-7b-lora-passage \
--n_prune_layers 6 \
--layers_path "llama.layers"
```

Take NV-Embed-v1 as an example:
```bash
python shortgpt_prune.py \
--model_name_or_path nvidia/NV-Embed-v1 \
--output_model_path ./pruned-NV-Embed-v1_pruned_26 \
--n_prune_layers 6 \
--layers_path "layers"
```
Configurable parameters:
- `--model_name_or_path`: name or local path of the source model.
- `--output_model_path`: save path for the pruned model.
- `--n_prune_layers`: number of layers to remove; the script automatically identifies the N least important layers (see the sketch below).
- `--layers_path`: dot-separated path to the list of transformer layers inside the model object (e.g. `"llama.layers"` for repllama, `"model.layers"` for llama).
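
The importance criterion itself is not part of this diff. As background, here is a minimal sketch of the Block Influence idea from ShortGPT [10], on which the pruning script is based: a layer whose output is nearly identical (in cosine similarity) to its input contributes little and is a pruning candidate. Function names and the toy scores are illustrative, not the script's actual API:

```python
import numpy as np

def block_influence(hidden_in: np.ndarray, hidden_out: np.ndarray) -> float:
    # Block Influence (BI): 1 - cosine similarity between a layer's input
    # and output hidden states, averaged over tokens. Low BI means the
    # layer barely transforms its input.
    num = (hidden_in * hidden_out).sum(axis=-1)
    den = np.linalg.norm(hidden_in, axis=-1) * np.linalg.norm(hidden_out, axis=-1)
    return float(1.0 - (num / den).mean())

def select_prune_layers(bi_scores: list[float], n_prune_layers: int) -> list[int]:
    # Indices of the n least important layers (lowest BI).
    return sorted(int(i) for i in np.argsort(bi_scores)[:n_prune_layers])

# Toy example: BI measured for 8 layers on calibration data.
bi_scores = [0.42, 0.35, 0.12, 0.08, 0.10, 0.30, 0.05, 0.50]
print(select_prune_layers(bi_scores, n_prune_layers=3))  # [3, 4, 6]
```
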

#### Performance evaluation
After pruning, the new model under `output_model_path` can be run through the [MTEB evaluation](#评估).

We evaluated the performance and inference speed of `RepLLaMA` before and after pruning on several retrieval tasks. All experiments were run on a single 80 GB A100 GPU.


| Model | Metric | MSMARCO-Title<br>(MRR@10) | SciFact<br>(NDCG@10) | FiQA2018<br>(NDCG@10) | QuoraRetrieval<br>(NDCG@10) | NFCorpus<br>(NDCG@10) |
| :--- | :--- | :---: | :---: | :---: | :---: | :---: |
| **RepLLaMA** | Batch size | 7 | 22 | 8 | 320 | 15 |
| | Time (s) | 106863 | 257 | 1381 | 1649 | 198 |
| | Score | 38.33 | 76.19 | 45.95 | 88.25 | 38.02 |
| **+shortgpt** | Batch size | 7 | 22 | 9 | 390 | 15 |
| | Time (s) | 88064 | 211 | 1169 | 1475 | 162 |
| | Score | 36.31 | 73.82 | 44.36 | 88.06 | 38.05 |

### Model quantization
Quantized loading of the embedding model is supported, reducing GPU memory usage and inference latency.

#### Usage
```bash
python -u evaluation/eval_mteb.py \
--base_model_name_or_path castorini/repllama-v1-7b-lora-passage \
--output_folder eval_results/repllama-v1-7b-lora-passage \
--task_name 'SciFact' \
--eval_batch_size 8 \
--max_seq_length 2048 \
--task_split dev \
--quant_type weight_only_int8 \
--kv_cache_reuse 1
```
Configurable parameters:
* `--quant_type`: quantization mode used when loading the model. Options are weight_only_int8, weight_only_int4, and no; the default is no, i.e. no quantization (see the sketch below for the weight-only idea).
* `--kv_cache_reuse`: under quantized loading, whether to pre-allocate only the first layer's kv_cache and reuse it across layers; 0 disables reuse, 1 enables it; the default is 0. Reuse lowers GPU memory usage under quantized loading.
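
As background for `weight_only_int8`, the sketch below shows the general weight-only quantization idea: weights are stored as int8 with a per-channel scale and dequantized on the fly at matmul time, while activations stay in floating point. This illustrates the technique only; PaddleNLP's actual kernels fuse the dequantization and differ in detail:

```python
import numpy as np

def quantize_weight_int8(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # Symmetric per-output-channel quantization of w: [in_features, out_features].
    scale = np.abs(w).max(axis=0) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def linear_weight_only_int8(x: np.ndarray, q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Dequantize on the fly; activations remain float32.
    return x @ (q.astype(np.float32) * scale)

rng = np.random.default_rng(0)
w = rng.normal(size=(16, 8)).astype(np.float32)
x = rng.normal(size=(4, 16)).astype(np.float32)
q, scale = quantize_weight_int8(w)
err = np.abs(x @ w - linear_weight_only_int8(x, q, scale)).max()
print(f"max abs error: {err:.4f}")  # small error for half the weight memory
```
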


#### Performance evaluation
We evaluated the performance and inference speed of `RepLLaMA` with and without quantized loading on several retrieval tasks. All experiments were run on a single 80 GB A100 GPU.

| Model | Metric | MSMARCO-Title<br>(MRR@10) | SciFact<br>(NDCG@10) | FiQA2018<br>(NDCG@10) | QuoraRetrieval<br>(NDCG@10) | NFCorpus<br>(NDCG@10) |
| :--- | :--- | :---: | :---: | :---: | :---: | :---: |
| **RepLLaMA** | Batch size | 7 | 22 | 8 | 320 | 15 |
| | Time (s) | 106863 | 257 | 1381 | 1649 | 198 |
| | Score | 38.33 | 76.19 | 45.95 | 88.25 | 38.02 |
| **+int8 quantization** | Batch size | 50 | 180 | 80 | 180 | 180 |
| | Time (s) | 70888 | 172 | 904 | 1143 | 132 |
| | Score | 37.21 | 75.92 | 45.67 | 88.00 | 37.71 |

## Reference

@@ -324,4 +374,4 @@

[9] Ruiyang Ren, Yingqi Qu, Jing Liu, Wayne Xin Zhao, Qiaoqiao She, Hua Wu, Haifeng Wang, Ji-Rong Wen: RocketQAv2: A Joint Training Method for Dense Passage Retrieval and Passage Re-ranking. EMNLP 2021

[10] Xin Men, Mingyu Xu, Qingyu Zhang, Bingning Wang, Hongyu Lin, Yaojie Lu, Xianpei Han, Weipeng Chen: ShortGPT: Layers in Large Language Models Are More Redundant Than You Expect. ACL Findings 2025
@@ -173,6 +173,7 @@ def get_args():
normalized=False,
sentence_pooling_method=args.pooling_method,
query_instruction=args.query_instruction,
document_instruction=args.document_instruction,
tokenizer=tokenizer,
eval_batch_size=args.eval_batch_size,
max_seq_length=args.max_seq_length,
@@ -199,6 +200,7 @@ def get_args():
normalized=True,
sentence_pooling_method=args.pooling_method,
query_instruction=args.query_instruction,
document_instruction=args.document_instruction,
tokenizer=tokenizer,
eval_batch_size=args.eval_batch_size,
max_seq_length=args.max_seq_length,
@@ -214,6 +216,7 @@ def get_args():
normalized=True,
sentence_pooling_method=args.pooling_method,
query_instruction=args.query_instruction,
document_instruction=args.document_instruction,
tokenizer=tokenizer,
eval_batch_size=args.eval_batch_size,
max_seq_length=args.max_seq_length,
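The three hunks above add the same `document_instruction=args.document_instruction` pass-through at each model-construction site, so the document-side prompt now reaches the encoder just as the query-side prompt already did. A minimal sketch of how such paired instructions are typically applied (the prefixing scheme is an assumption; the real encoder may format them differently):

```python
from typing import List, Optional, Tuple

def apply_instructions(
    queries: List[str],
    docs: List[str],
    query_instruction: Optional[str] = None,
    document_instruction: Optional[str] = None,
) -> Tuple[List[str], List[str]]:
    # Hypothetical mirror of the evaluator's instruction handling: each side
    # gets its own prefix before tokenization.
    if query_instruction:
        queries = [f"{query_instruction}{q}" for q in queries]
    if document_instruction:
        docs = [f"{document_instruction}{d}" for d in docs]
    return queries, docs

qs, ds = apply_instructions(
    ["what is dense retrieval?"],
    ["Dense retrieval encodes text as vectors."],
    query_instruction="query: ",
    document_instruction="passage: ",
)
print(qs, ds)
```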
@@ -215,7 +215,6 @@ def __init__(
self.sentence_pooling_method = sentence_pooling_method
self.query_instruction = query_instruction
self.document_instruction = document_instruction
self.eval_batch_size = eval_batch_size
self.max_seq_length = max_seq_length
self.model_flag = model_flag