diff --git a/paddlenlp/transformers/llm_embed/modeling.py b/paddlenlp/transformers/llm_embed/modeling.py
index 22bb1c29e595..febd5673e1f4 100644
--- a/paddlenlp/transformers/llm_embed/modeling.py
+++ b/paddlenlp/transformers/llm_embed/modeling.py
@@ -326,6 +326,8 @@ def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
input_texts = self.preprocess_sentences_for_llara(input_texts, query_or_doc="query")
if self.model_flag == "bge-en-icl":
input_texts = self.preprocess_sentences_for_bge_en_icl(input_texts, query_or_doc="query")
+ if self.model_flag == "qwen3":
+ input_texts = self.preprocess_sentences_for_qwen3(input_texts, query_or_doc="query")
encode_results = self.encode_sentences(sentences=input_texts, model=self.query_model, tokenizer=self.tokenizer)
return encode_results
@@ -356,6 +358,8 @@ def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> n
if isinstance(corpus[0], dict):
input_texts = [doc["text"] for doc in corpus]
input_titles = [doc.get("title", "") for doc in corpus]
+ if self.model_flag == "qwen3":
+ input_texts = self.preprocess_sentences_for_qwen3(input_texts, query_or_doc="doc")
encode_results = self.encode_sentences(
sentences=input_texts, titles=input_titles, model=self.corpus_model, tokenizer=self.tokenizer
@@ -415,3 +419,7 @@ def preprocess_sentences_for_llara(self, sentences: List[str], query_or_doc: str
sentences_after_process = [prefix + " " + sentence + " " + suffix for sentence in sentences_after_process]
return sentences_after_process
+
+ def preprocess_sentences_for_qwen3(self, sentences: List[str], query_or_doc: str, **kwargs) -> List[str]:
+ sentences = [f"{sentence}<|endoftext|>" for sentence in sentences]
+ return sentences
diff --git a/slm/pipelines/examples/contrastive_training/README.md b/slm/pipelines/examples/contrastive_training/README.md
index f89b4c69b209..3b6b06e91a4e 100644
--- a/slm/pipelines/examples/contrastive_training/README.md
+++ b/slm/pipelines/examples/contrastive_training/README.md
@@ -7,7 +7,7 @@
conda create --name paddle_env python=3.10
conda activate paddle_env
-# 安装 paddlenlp develop版本
+# 安装 paddlenlp develop版本
pip install --pre --upgrade paddlenlp -f https://www.paddlepaddle.org.cn/whl/paddlenlp.html
# 安装 paddlepaddle-gpu nightly版本
@@ -206,6 +206,7 @@ python -u evaluation/eval_mteb.py \
| [NV‑Embed‑v1](https://huggingface.co/nvidia/NV-Embed-v1) | 4096 |
| [BGE‑EN‑ICL](https://huggingface.co/BAAI/bge-en-icl) | 4096 |
| [LLARA‑passage](https://huggingface.co/BAAI/LLARA-passage) | 4096 |
+| [Qwen3-Embedding-8B](https://huggingface.co/Qwen/Qwen3-Embedding-8B) | 32K |
可支持配置的参数:
- `base_model_name_or_path`: 模型名称或路径
@@ -231,6 +232,7 @@ MSMARCO-Title 数据集, MRR@10, Recall@10, NDCG@10分数:
| RocketQA v1 | 36.94 | 65.67 | 43.51 |
| RocketQA v2 | 38.88 | 67.06 | 45.28 |
| bge-large-en-v1.5 | 35.30 | 64.24 | 41.96 |
+| Qwen3-Embedding-8B | 38.66 | 69.46 | 45.72 |
| repllama-v1-7b-lora-passage | 38.24 | 66.26 | 45.13 |
| NV-Embed-v1 | 38.39 | 68.02 | 45.21 |
| bge-en-icl (zero-shot) | 42.74 | 71.90 | 49.47 |
@@ -246,6 +248,7 @@ MTEB-Retrieval 数据集, MRR@10分数:
| NV‑Embed‑v1 | 65.24 | 60.28 | 45.17 | 48.14 | 80.19 | 86.78 | 69.24 | 88.36 | 39.73 | 59.40 | 66.70 | 88.35 | 34.27 | 75.17 | 42.50 | 94.33 |
| bge‑en‑icl (zero‑shot) | 69.29 | 77.83 | 57.88 | 45.69 | 82.04 | 92.50 | 65.78 | 92.76 | 39.97 | 61.84 | 69.64 | 90.22 | 41.14 | 75.13 | 56.56 | 90.33 |
| LLARA-passage | 60.11 | 38.77 | 34.58 | 36.19 | 75.50 | 81.02 | 51.72 | 86.36 | 38.81 | 57.69 | 56.85 | 80.58 | 30.15 | 73.17 | 67.20 | 93.07 |
+| Qwen3-Embedding-8B | 69.79 | 70.12 | 61.52 | 52.23 | 81.29 | 93.54 | 69.72 | 89.46 | 37.44 | 61.46 | 59.63 | 88.01 | 49.32 | 74.06 | 59.09 | 100.00 |
MTEB-Retrieval 数据集, Recall@10分数:
| 模型 | 平均分数 | ArguAna | ClimateFEVER | CQADupstackRetrieval | DBPedia | FEVER | FiQA2018 | HotpotQA | MSMARCO | NFCorpus | NQ | QuoraRetrieval | SCIDOCS | SciFact | Touche2020 | TRECCOVID |
@@ -257,6 +260,7 @@ MTEB-Retrieval 数据集, Recall@10分数:
| NV‑Embed‑v1 | 58.78 | 93.95 | 41.07 | 64.66 | 28.67 | 95.24 | 70.62 | 85.19 | 69.15 | 18.45 | 89.16 | 95.92 | 21.27 | 90.02 | 15.94 | 2.36 |
| bge‑en‑icl (zero‑shot) | 60.62 | 97.08 | 52.19 | 60.38 | 29.81 | 96.92 | 67.42 | 88.33 | 69.53 | 20.42 | 90.96 | 97.02 | 27.33 | 91.05 | 18.81 | 2.11 |
| LLARA-passage | 52.30 | 76.17 | 32.52 | 47.91 | 26.33 | 90.48 | 51.09 | 71.16 | 67.82 | 17.67 | 81.89 | 92.54 | 18.12 | 86.80 | 21.81 | 2.23 |
+| Qwen3-Embedding-8B | 60.96 | 97.51 | 52.34 | 67.80 | 28.99 | 96.01 | 71.33 | 79.05 | 67.43 | 19.95 | 84.87 | 96.16 | 34.57 | 93.50 | 22.41 | 2.50 |
MTEB-Retrieval 数据集, NDCG@10分数:
| 模型 | 平均分数 | ArguAna | ClimateFEVER | CQADupstackRetrieval | DBPedia | FEVER | FiQA2018 | HotpotQA | MSMARCO | NFCorpus | NQ | QuoraRetrieval | SCIDOCS | SciFact | Touche2020 | TRECCOVID |
@@ -268,6 +272,7 @@ MTEB-Retrieval 数据集, NDCG@10分数:
| NV‑Embed‑v1 | 58.86 | 68.30 | 34.37 | 50.27 | 48.29 | 86.58 | 62.90 | 79.92 | 46.48 | 37.98 | 71.22 | 89.20 | 20.16 | 78.30 | 23.98 | 84.91 |
| bge‑en‑icl (zero‑shot) | 61.62 | 82.34 | 45.33 | 47.27 | 50.60 | 91.91 | 59.13 | 84.90 | 46.78 | 40.66 | 73.85 | 91.03 | 25.46 | 77.91 | 30.71 | 76.38 |
| LLARA-passage | 52.48 | 47.51 | 26.13 | 37.26 | 44.12 | 81.09 | 43.98 | 69.17 | 45.49 | 37.07 | 61.76 | 82.29 | 17.30 | 76.07 | 36.73 | 81.30 |
+| Qwen3-Embedding-8B | 62.36 | 76.63 | 47.13 | 54.01 | 48.86 | 91.82 | 62.14 | 76.28 | 44.29 | 41.31 | 64.63 | 89.03 | 32.22 | 78.48 | 34.99 | 93.54 |
## 压缩
@@ -277,9 +282,9 @@ MTEB-Retrieval 数据集, NDCG@10分数:
#### 使用方法
-通过以下命令执行剪枝脚本。可指定原始模型、输出路径、要剪枝的层数以及模型中transformer层的路径。
+通过以下命令执行剪枝脚本。可指定原始模型、输出路径、要剪枝的层数以及模型中 transformer 层的路径。
-以repllama-v1-7b-lora-passage为例:
+以 repllama-v1-7b-lora-passage 为例:
```bash
python shortgpt_prune.py \
--model_name_or_path castorini/repllama-v1-7b-lora-passage \
@@ -288,21 +293,66 @@ python shortgpt_prune.py \
--layers_path "llama.layers"
```
-以NV-Embed-v1为例:
+以 NV-Embed-v1为例:
```bash
python shortgpt_prune.py \
--model_name_or_path nvidia/NV-Embed-v1 \
- --output_model_path /pruned-NV-Embed-v1_pruned_26 \
+ --output_model_path ./pruned-NV-Embed-v1_pruned_26 \
--n_prune_layers 6 \
--layers_path "layers"
```
可配置参数包括:
- `--model_name_or_path`: 原始模型的名称或本地路径。
- `--output_model_path`: 剪枝后模型的保存路径。
-- `--n_prune_layers`: 希望移除的层数。脚本会自动找出最不重要的N层。
-- `--layers_path`: 模型对象中指向transformer层列表的点分隔路径(例如repllama为`"llama.layers"`, llama为`"model.layers"`)。
+- `--n_prune_layers`: 希望移除的层数。脚本会自动找出最不重要的 N 层。
+- `--layers_path`: 模型对象中指向 transformer 层列表的点分隔路径(例如 repllama 为`"llama.layers"`, llama 为`"model.layers"`)。
-可用output_model_path路径中的模型跑评估[评估部分的代码](#评估)
+#### 性能评估
+剪枝完成后,可以使用 output_model_path 路径下的新模型进行[MTEB 评估](#评估)。
+
+在多个检索任务上评估了`RepLLaMA`模型剪枝前后的性能和推理速度。所有实验均在单张 80G A100 GPU 上进行。
+
+
+| 模型 | 指标 | MSMARCO-Title
(MRR@10) | SciFact
(NDCG@10) | FiQA2018
(NDCG@10)| QuoraRetrieval
(NDCG@10) | NFCorpus
(NDCG@10) |
+| :--- | :--- | :---: | :---: | :---: | :---: | :---: |
+| **RepLLaMA** | Batchsize | 7 | 22 | 8 | 320 | 15 |
+| | 时间 (s) | 106863 | 257 | 1381 | 1649 | 198 |
+| | 性能 | 38.33 | 76.19 | 45.95 | 88.25 | 38.02 |
+| **+shortgpt** | Batchsize | 7 | 22 | 9 | 390 | 15 |
+| | 时间 (s) | 88064 | 211 | 1169 | 1475 | 162 |
+| | 性能 | 36.31 | 73.82 | 44.36 | 88.06 | 38.05 |
+
+### 模型量化
+支持对向量模型进行量化加载,以降低显存占用和推理延迟。
+
+#### 使用方法
+```bash
+python -u evaluation/eval_mteb.py \
+ --base_model_name_or_path castorini/repllama-v1-7b-lora-passage \
+ --output_folder eval_results/repllama-v1-7b-lora-passage \
+ --task_name 'SciFact' \
+ --eval_batch_size 8 \
+ --max_seq_length 2048 \
+ --task_split dev \
+ --quant_type weight_only_int8 \
+ --kv_cache_reuse 1
+```
+可配置参数包括:
+* `--quant_type`:是否使用量化加载,可选项包括 weight_only_int8,weight_only_int4,no,默认为 no,即不进行量化
+* `--kv_cache_reuse`: 量化加载时,是否仅预分配首层 kv_cache 并重复利用,0 表示不复用,1 表示复用,默认为 0,此策略可降低量化加载时显存占用
+
+
+#### 性能评估
+在多个检索任务上评估了`RepLLaMA`模型量化加载前后的性能和推理速度。所有实验均在单张 80G A100 GPU 上进行。
+
+| 模型 | 指标 | MSMARCO-Title
(MRR@10) | SciFact
(NDCG@10) | FiQA2018
(NDCG@10)| QuoraRetrieval
(NDCG@10) | NFCorpus
(NDCG@10) |
+| :--- | :--- | :---: | :---: | :---: | :---: | :---: |
+| **RepLLaMA** | Batchsize | 7 | 22 | 8 | 320 | 15 |
+| | 时间 (s) | 106863 | 257 | 1381 | 1649 | 198 |
+| | 性能 | 38.33 | 76.19 | 45.95 | 88.25 | 38.02 |
+| **+int8量化** | Batchsize | 50 | 180 | 80 | 180 | 180 |
+| | 时间 (s) | 70888 | 172 | 904 | 1143 | 132 |
+| | 性能 | 37.21 | 75.92 | 45.67 | 88.00 | 37.71 |
## Reference
@@ -324,4 +374,4 @@ python shortgpt_prune.py \
[9] Ruiyang Ren, Yingqi Qu, Jing Liu, Wayne Xin Zhao, Qiaoqiao She, Hua Wu, Haifeng Wang, Ji-Rong Wen: RocketQAv2: A Joint Training Method for Dense Passage Retrieval and Passage Re-ranking. EMNLP 2021
-[10] Xin Men, Mingyu Xu, Qingyu Zhang, Bingning Wang, Hongyu Lin, Yaojie Lu, Xianpei Han, Weipeng Chen: Shortgpt: Layers in large language models are more redundant than you expect. ACL Findings 2025
\ No newline at end of file
+[10] Xin Men, Mingyu Xu, Qingyu Zhang, Bingning Wang, Hongyu Lin, Yaojie Lu, Xianpei Han, Weipeng Chen: Shortgpt: Layers in large language models are more redundant than you expect. ACL Findings 2025
diff --git a/slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py b/slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py
index eb78268958dc..729585adf9a0 100644
--- a/slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py
+++ b/slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py
@@ -173,6 +173,7 @@ def get_args():
normalized=False,
sentence_pooling_method=args.pooling_method,
query_instruction=args.query_instruction,
+ document_instruction=args.document_instruction,
tokenizer=tokenizer,
eval_batch_size=args.eval_batch_size,
max_seq_length=args.max_seq_length,
@@ -199,6 +200,7 @@ def get_args():
normalized=True,
sentence_pooling_method=args.pooling_method,
query_instruction=args.query_instruction,
+ document_instruction=args.document_instruction,
tokenizer=tokenizer,
eval_batch_size=args.eval_batch_size,
max_seq_length=args.max_seq_length,
@@ -214,6 +216,7 @@ def get_args():
normalized=True,
sentence_pooling_method=args.pooling_method,
query_instruction=args.query_instruction,
+ document_instruction=args.document_instruction,
tokenizer=tokenizer,
eval_batch_size=args.eval_batch_size,
max_seq_length=args.max_seq_length,
diff --git a/slm/pipelines/examples/contrastive_training/evaluation/modelling_quant.py b/slm/pipelines/examples/contrastive_training/evaluation/modelling_quant.py
index 438f2b67cd74..fb3ee3571e9b 100644
--- a/slm/pipelines/examples/contrastive_training/evaluation/modelling_quant.py
+++ b/slm/pipelines/examples/contrastive_training/evaluation/modelling_quant.py
@@ -215,7 +215,6 @@ def __init__(
self.sentence_pooling_method = sentence_pooling_method
self.query_instruction = query_instruction
self.document_instruction = document_instruction
- self.document_instruction = document_instruction
self.eval_batch_size = eval_batch_size
self.max_seq_length = max_seq_length
self.model_flag = model_flag