2 changes: 2 additions & 0 deletions paddlenlp/experimental/transformers/mistral/modeling.py
@@ -1559,6 +1559,7 @@ def forward(
excess_blocks=None,
draft_tokens=None,
output_padding_offset=None,
**kwargs,
):
outputs = self.Mistral(
input_ids,
@@ -1578,6 +1579,7 @@
excess_blocks=excess_blocks,
draft_tokens=draft_tokens,
output_padding_offset=output_padding_offset,
**kwargs,
)

hidden_states = outputs[0]
4 changes: 4 additions & 0 deletions paddlenlp/transformers/llm_embed/modeling.py
@@ -132,6 +132,10 @@ def sentence_embedding(self, hidden_state, mask):
last_token_indices = sequence_lengths - 1
embeddings = hidden_state[paddle.arange(hidden_state.shape[0]), last_token_indices]
return embeddings
elif self.sentence_pooling_method == "last_8":
    # Mean-pool the hidden states of the last 8 positions; this assumes left padding,
    # so that the final positions hold real tokens rather than pad tokens.
    last_8_embeddings = hidden_state[paddle.arange(hidden_state.shape[0]), -8:]
    embeddings = paddle.mean(last_8_embeddings, axis=1)
    return embeddings
else:
raise ValueError(f"Invalid sentence pooling method: {self.sentence_pooling_method}")

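For reference, a minimal sketch of what the new `last_8` pooling computes (hypothetical shapes; the committed code uses fancy indexing on the batch axis, but for a left-padded `[batch, seq_len, hidden]` tensor a plain slice gives the same result):

```python
import paddle

# Hypothetical batch: 2 sequences, 16 positions, hidden size 4.
hidden_state = paddle.randn([2, 16, 4])

# Take the final 8 positions of each sequence and average them.
last_8 = hidden_state[:, -8:, :]          # shape [2, 8, 4]
embeddings = paddle.mean(last_8, axis=1)  # shape [2, 4]
print(embeddings.shape)                   # [2, 4]
```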
9 changes: 5 additions & 4 deletions slm/pipelines/examples/contrastive_training/README.md
@@ -18,10 +18,11 @@ pip install -r slm/pipelines/examples/contrastive_training/requirements.txt
```


- Download the DuReader-Retrieval Chinese dataset:
+ Download the DuReader-Retrieval and MMarco-Retrieval Chinese datasets:
```
cd data
wget https://paddlenlp.bj.bcebos.com/datasets/dureader_dual.train.jsonl
python download_mmarco.py
```
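`download_mmarco.py` writes `./mmarco.jsonl`, one JSON object per line in the same `query`/`pos`/`neg` layout as `dureader_dual.train.jsonl` (placeholder values shown):

```
{"query": "...", "pos": ["..."], "neg": ["..."]}
```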

## Training
@@ -73,7 +74,7 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py --do_tr
--query_instruction_for_retrieval "query: " \
--passage_instruction_for_retrieval "" \
--model_name_or_path ${model_name} \
- --output_dir ${output_dir}$ \
+ --output_dir ${output_dir} \
--save_steps 100 \
--train_data ./data/dureader_dual.train.jsonl \
--bf16 \
@@ -295,7 +296,7 @@ python shortgpt_prune.py \
--model_name_or_path castorini/repllama-v1-7b-lora-passage \
--output_model_path ./pruned-repllama-v1-7b-lora-passage \
--n_prune_layers 6 \
--layers_path "llama.layers"
--layers_path "layers"
```

Taking NV-Embed-v1 as an example:
@@ -310,7 +311,7 @@
- `--model_name_or_path`: Name or local path of the original model.
- `--output_model_path`: Save path for the pruned model.
- `--n_prune_layers`: Number of layers to remove; the script automatically finds the N least important layers.
- - `--layers_path`: Dot-separated path pointing to the list of transformer layers inside the model object (e.g. `"llama.layers"` for repllama, `"model.layers"` for llama).
+ - `--layers_path`: Dot-separated path pointing to the list of transformer layers inside the model object (see the sketch below).
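
One way such a dot-separated path can be resolved (a sketch; `get_layers` is a hypothetical helper, and the actual logic in `shortgpt_prune.py` may differ):

```python
import functools

def get_layers(model, layers_path: str):
    # Walk a dot-separated attribute path, e.g. "llama.layers" or "model.layers",
    # starting from the model object, and return the layer list it points to.
    return functools.reduce(getattr, layers_path.split("."), model)
```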

#### Performance Evaluation
After pruning, the new model saved under `output_model_path` can be used for [MTEB evaluation](#评估).
36 changes: 36 additions & 0 deletions slm/pipelines/examples/contrastive_training/data/download_mmarco.py
@@ -0,0 +1,36 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import tqdm
from datasets import load_dataset

# Load the Chinese subset of mMARCO from the Hugging Face hub.
dataset = load_dataset("unicamp-dl/mmarco", "chinese")
print(dataset["train"][1])
print(len(dataset["train"]))


# Convert each example into the query/pos/neg JSONL format expected by train.py.
with open("./mmarco.jsonl", "w") as fw:
    for i, data in enumerate(tqdm.tqdm(dataset["train"])):
        example = {"query": data["query"], "pos": [data["positive"]], "neg": [data["negative"]]}
        fw.write(json.dumps(example, ensure_ascii=False) + "\n")
        # Uncomment to cap the number of exported examples:
        # if i > 200000:
        #     break
18 changes: 15 additions & 3 deletions slm/pipelines/examples/contrastive_training/shortgpt_prune.py
@@ -25,7 +25,7 @@
from paddle.io import DataLoader
from tqdm import tqdm

- from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer, NVEncodeModel
+ from paddlenlp.transformers import AutoModel, AutoTokenizer, NVEncodeModel


# =====================================================================================
@@ -73,7 +73,19 @@ def __init__(self, model_name: str, layers_path: str):
model_name, tokenizer_path=model_name, query_instruction="", document_instruction=""
)
else:
- self.model = AutoModelForCausalLM.from_pretrained(model_name, dtype=paddle.float16)
+ self.model = AutoModel.from_pretrained(model_name, dtype=paddle.float16)

# Print the first few lines of the model structure so the user can verify layers_path.
model_str = str(self.model)
print("\n=== Model structure (first 5 lines) ===")
print("\n".join(model_str.splitlines()[:5]))

self.model.eval()
print("Model loaded successfully for importance evaluation.")
@@ -335,7 +347,7 @@ def main():

prune_order = sorted(range(len(short_model.importances)), key=lambda i: short_model.importances[i])
layers_to_delete = set(prune_order[: args.n_prune_layers])

print("\n--- Importance Calculation Complete ---")
print(f"Calculated importances: {[f'{v:.2f}' for v in short_model.importances]}")
print(f"Pruning order (least to most important): {prune_order}")
39 changes: 39 additions & 0 deletions slm/pipelines/examples/contrastive_training/train.py
@@ -99,6 +99,11 @@ def main():
dtype=dtype,
)
else:
if "llara" in model_args.model_name_or_path.lower():
model_flag = "llara"
tokenizer.padding_side = "left"
else:
model_flag = "NA"
model = BiEncoderModel(
model_name_or_path=model_args.model_name_or_path,
normalized=model_args.normalized,
@@ -110,6 +115,7 @@
matryoshka_dims=training_args.matryoshka_dims if training_args.use_matryoshka else None,
matryoshka_loss_weights=training_args.matryoshka_loss_weights if training_args.use_matryoshka else None,
dtype=dtype,
model_flag=model_flag,
)

if training_args.fix_position_embedding:
@@ -138,6 +144,32 @@
".*up_proj$",
".*gate_proj$",
]
elif any([x in model_args.model_name_or_path for x in ["bge"]]):  # no reference config available, so use the simplest setting
target_modules = [
".*q_proj$",
".*k_proj$",
".*v_proj$",
]
elif any([x in model_args.model_name_or_path for x in ["LLARA"]]): # same as llama
target_modules = [
".*q_proj$",
".*k_proj$",
".*v_proj$",
".*o_proj$",
".*down_proj$",
".*up_proj$",
".*gate_proj$",
]
elif any([x in model_args.model_name_or_path for x in ["Qwen3"]]): # copy from qwen2
target_modules = [
".*q_proj.*",
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
".*gate_proj.*",
".*down_proj.*",
".*up_proj.*",
]
else:
raise ValueError("need to specify the target modules for LoRA fine-tuning.")
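For intuition, these patterns are regular expressions matched against sublayer names; a rough sketch of the selection (hypothetical names, and the exact matching logic in `paddlenlp.peft` may differ):

```python
import re

# Hypothetical sublayer names as they might appear in a transformer model.
names = [
    "layers.0.self_attn.q_proj",
    "layers.0.self_attn.o_proj",
    "layers.0.mlp.gate_proj",
    "layers.0.input_layernorm",
]
target_modules = [".*q_proj$", ".*o_proj$", ".*gate_proj$"]

# Keep every name that matches at least one target pattern.
matched = [n for n in names if any(re.match(p, n) for p in target_modules)]
print(matched)  # ['layers.0.self_attn.q_proj', 'layers.0.self_attn.o_proj', 'layers.0.mlp.gate_proj']
```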

@@ -150,6 +182,13 @@
)
if "llama" in model_args.model_name_or_path.lower():
model.config = model.model_config # for NV-Embed, this is no needed, but for repllama, this is needed
if (
("llara" in model_args.model_name_or_path.lower())
or ("bge-large" in model_args.model_name_or_path.lower())
or ("qwen3" in model_args.model_name_or_path.lower())
or ("bge-en-icl" in model_args.model_name_or_path.lower())
):
model.config = model.model_config # for NV-Embed, this is no needed, but for repllama, this is needed
model.config.tensor_parallel_degree = training_args.tensor_parallel_degree
model = LoRAModel(model, lora_config)
model.mark_only_lora_as_trainable()