Skip to content

Conversation

TheTahaaa
Copy link

Hi there! 👋

This PR adds support for Arrow, a modular routing mechanism for LoRA experts introduced here, as well as our refinement method GenKnowSub, proposed in our ACL 2025 Main Conference paper. GenKnowSub enhances Arrow by subtracting a general-domain LoRA from task-specific ones prior to routing, leading to improved generalisation and modularity.

The integration is implemented through a new ArrowLoraLinearLayer, registered as a LoraVariant (similar to DoRA).

🔧 Code Changes

Modified files under peft/tuners/lora/:

  • config.py, layer.py, model.py, variants.py
  • bnb.py (adds support for Arrow with bitsandbytes quantisation)

Added a new file:

  • arrow.py – contains the core logic for Arrow and GenKnowSub

✅ Testing & Quality

Confirmed that tests/test_lora_variants.py passes.

All formatting and style checks pass (make quality successful).

📚 Optional Follow-up

I’ve also developed a separate GitHub repository demonstrating how to use Arrow and GenKnowSub on benchmark datasets (e.g., BoolQ, PIQA, ARC-Challenge). I’d be happy to contribute this as an example under peft/examples/ if it’s of interest.

Thank you in advance for reviewing this PR!

@githubnemo
Copy link
Collaborator

Hey @TheTahaaa, thank you for the PR!

I'm unsure about the use of a LoRA variant here. The concept of Arrow is, as I understand it, more similar to X-Lora (i.e. a routing of multiple adapters) than a single adapter. Therefore I think it would make more sense to implement it as its own method (but still being able to ingest LoRA adapters, like X-LoRA does).

Secondly, I think that while a direct integration of GenKnowSub via a method can have certain utility, it seems more like a recipe or pre-processing step rather than something that should be available all the time. Maybe it would be more informative to have a full-fledged example how to create a general knowledge LoRA and how to subtract it?

Let me know what you think.

@TheTahaaa
Copy link
Author

TheTahaaa commented Jul 17, 2025

Thanks for the feedback @githubnemo !

Regarding the use of a LoraVariant for Arrow: while Arrow isn’t a new adapter type, it’s fundamentally an inference-time, train-free routing mechanism over already trained LoRAs within the same layer, as emphasised in the paper. That’s why I implemented it as a variant—so it can stay tightly coupled to the existing LoRA modules, plug cleanly into PEFT’s adapter system, and require no changes to loading logic or user workflows. Users can simply add a dummy “router” adapter, activate it, and Arrow handles token-level routing over the loaded experts in place. It’s lightweight, modular, and leverages the LoraVariant interface for minimal disruption.

Regarding GenKnowSub: in our paper, we show that subtracting a general-domain LoRA from each task-specific LoRA before applying Arrow significantly improves modularity and routing quality. To obtain these general-domain LoRAs, we simply fine-tune the base model on a general corpus (e.g., Wikipedia) using LoRA with a causal language modeling objective. These general LoRAs—such as ones trained on different languages—are then loaded into the model alongside the task-specific LoRAs.

Because GenKnowSub is a one-time, inference-only preprocessing step, we’ve implemented it as a toggle in the Arrow router config via use_gks=True. When enabled, the router adapter will subtract the average of the general LoRAs (e.g., language experts) from each task-specific LoRA prior to routing.

Here’s a concise example showing how a user would load LoRA experts, apply GenKnowSub, and use Arrow routing:

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained(...)
base_model = AutoModelForCausalLM.from_pretrained(...)

# Load task-specific LoRA adapters
model.load_adapter(..., adapter_name="cluster0")
...
model.load_adapter(..., adapter_name="cluster9")

# Load general-domain LoRA adapters (e.g., English, French, German)
model.load_adapter(..., adapter_name="le_en")
model.load_adapter(..., adapter_name="le_fr")
model.load_adapter(..., adapter_name="le_de")

# Now the model contains all the task-specific and general-domain LoRAs

# Define Arrow + GenKnowSub config as a dummy LoRA router
router_cfg = LoraConfig(
    r=2, # dummy rank since A and B won't be used!
    use_arrow=True, # This will turn this LoRA to the ArrowLoraVariant
    arrow_expert_num=10, # Number of task-specific modules in each LoRA layer
    arrow_top_k=3, # Number of selected task-specific LoRAs for each token in each layer 
    arrow_router_temperature=1.0,
    use_gks=True,  # ← enable GenKnowSub!
    le_names=..., # name of loaded general-domain LoRAs
    ts_names=..., # name of loaded task-specific LoRAs
    target_modules=["qkv_proj", "o_proj"]
)

# Add a dummy router adapter to trigger Arrow + GenKnowSub logic
model.add_adapter(adapter_name="router", peft_config=router_cfg)
model.set_adapter("router")

# Done! Just call `model.generate()` or `model(...)` as usual.

This flow is designed to be non-invasive and intuitive. Once the dummy router adapter is added and activated, it internally handles everything through its custom forward() path, including:

  • Performing GenKnowSub (subtracting the average of general-domain LoRAs from task-specific ones),

  • Routing tokens using the Arrow mechanism across the loaded task-specific experts.

No special hooks, overrides, or custom code are needed beyond activating the adapter — the standard .forward() or .generate() calls will work seamlessly.

If helpful, I’d be happy to walk through how this is handled inside the forward() logic of the ArrowLoraLinearLayer.

@githubnemo
Copy link
Collaborator

Thanks for taking the time to give a thorough answer.

Regarding the use of a LoraVariant for Arrow: while Arrow isn’t a new adapter type, it’s fundamentally an inference-time, train-free routing mechanism over already trained LoRAs within the same layer, as emphasised in the paper. That’s why I implemented it as a variant—so it can stay tightly coupled to the existing LoRA modules, plug cleanly into PEFT’s adapter system, and require no changes to loading logic or user workflows. Users can simply add a dummy “router” adapter, activate it, and Arrow handles token-level routing over the loaded experts in place. It’s lightweight, modular, and leverages the LoraVariant interface for minimal disruption.

Yes, I thought about this again and I'd still implement it as its own method for better separation of concerns and a clearer interface for the user. However, I can also see how this would involve a lot of code duplication when we could re-use all the infrastructure that is already there in form of LoRA. So, I suggest a compromise:

  1. Introduce an ArrowConfig class in src/tuners/lora/config.py similar to EvaConfig and friends to group the configuration in one place and prevent cluttering future LoRA adapter configs (side question: do we really need arrow_expert_num when we have the number of task-specific LoRAs at hand?)
  2. Introduce a function like create_arrow_model which will be user-facing that implements the process you showed in your example but also checks that the supplied arguments can work together. These checks would include things like a) are the supplied adapters all targeting the same modules, b) do the adapters have compatible shapes (e.g., when subtracting in case of GenKnowSub). There are probably more checks that are helpful to the user which we'll uncover on the way.

For testing we should implement stand-alone tests similar to tests/test_xlora.py that test arrow basic functionality and, additionally, GenKnowSub. While precise testing is difficult we can at least test whether a) Arrow with different sets of (dummy) adapters produces different outputs and b) the same is true for GenKnowSub compared to vanilla Arrow. Normally I'd suggest to add tests in test_custom_models and others but I see little value for Arrow in this case.

Do you think that would make sense?

@TheTahaaa
Copy link
Author

Sounds good to me!

I'll start working on the changes and add the necessary test file(s) soon. My understanding is that most of the work will involve refactoring existing code into a new class for the standalone Arrow method – is that correct?

@githubnemo
Copy link
Collaborator

My understanding is that most of the work will involve refactoring existing code into a new class for the standalone Arrow method – is that correct?

Not quite what I meant. My suggestion is to keep the variant solution that you already have because it re-uses a lot of code from LoRA and making it its own method would duplicate a lot of code. To minimize the impact on future configs I suggested to use a breakout config (point (1) in my post). Since the variant implementation will not have a PeftModel class that checks user inputs related to Arrow/GenKnowSub I suggested (point (2) in my post) to implement a helper function that takes a model, adapters and whatever is needed to compile an arrow routing setup (possibly with GenKnowSub) and builds a model similar to what you provided in your example.

So instead of this:

# Load task-specific LoRA adapters
model.load_adapter(..., adapter_name="cluster0")
...
model.load_adapter(..., adapter_name="cluster9")

# Load general-domain LoRA adapters (e.g., English, French, German)
model.load_adapter(..., adapter_name="le_en")
model.load_adapter(..., adapter_name="le_fr")
model.load_adapter(..., adapter_name="le_de")

# Now the model contains all the task-specific and general-domain LoRAs

# Define Arrow + GenKnowSub config as a dummy LoRA router
router_cfg = LoraConfig(
    r=2, # dummy rank since A and B won't be used!
    use_arrow=True, # This will turn this LoRA to the ArrowLoraVariant
    arrow_expert_num=10, # Number of task-specific modules in each LoRA layer
    arrow_top_k=3, # Number of selected task-specific LoRAs for each token in each layer 
    arrow_router_temperature=1.0,
    use_gks=True,  # ← enable GenKnowSub!
    le_names=..., # name of loaded general-domain LoRAs
    ts_names=..., # name of loaded task-specific LoRAs
    target_modules=["qkv_proj", "o_proj"]
)

# Add a dummy router adapter to trigger Arrow + GenKnowSub logic
model.add_adapter(adapter_name="router", peft_config=router_cfg)
model.set_adapter("router")

# Done! Just call `model.generate()` or `model(...)` as usual.

there will be something like

from peft import create_arrow_model, ArrowConfig

peft_model = create_arrow_model(
    base_model=base_model,
    task_specific_adapters=[path0, path1, path2],
    general_knowledge_adapters=[path3, path4, path5],
    arrow_config=ArrowConfig(...)
)

This function would live in the lora/variants.py file and could check if the adapters are compatible (e.g., equally ranked).

@TheTahaaa
Copy link
Author

That would be much more straightforward to implement, working on it! @githubnemo

@TheTahaaa
Copy link
Author

TheTahaaa commented Jul 31, 2025

✅ Implementation complete! @githubnemo

I’ve added just two attributes to the LoraConfig class: use_arrow and arrow_config.
I also refined the Arrow forward pass to reduce VRAM usage.

Currently working on the test files—will push the final version soon.

P.S. The final version with create_arrow_model and tests/test_arrow is committed and pushed. 🎯

Copy link
Collaborator

@githubnemo githubnemo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This is already looking good, I think that this design works quite well.

Lots of smaller nit picks which need to be resolved first before doing a more thorough review. One bigger point is the precomputation of prototypes / gks values which disallows adding / removing adapters. It also prevents us from switching to an expert adapter, do fine-tuning and switching back to the router (the prototypes would be outdated). While I think that supporting adding / removing adapters is a smaller change, allowing for fine-tuning while arrow layers are present is a bit more complicated. I expect GKS to be OK with this since we only subtract the weights once at the beginning and could do fine-tuning afterwards. I'm not sure if it is possible to solve this without removing prototype computation. Maybe we should make prototype computation optional. WDYT?


gen_names: list[str] = field(default=None, metadata={"help": "list of general LoRA names."})

def __post_init__(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a check for ts_names as well since it cannot be empty as far as I understand

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ts_names and gen_names are actually adapter names, not their paths.
Within create_arrow_model(), I standardise these names to ts_expert_i and gen_expert_i, and set ts_names / gen_names accordingly

# Adding task-specific adapter names to the arrow_config
arrow_config.ts_names = [f"ts_expert_{i}" for i in range(len(task_specific_adapter_paths))]

# Adding general adapter names to the arrow_config
arrow_config.gen_names = [f"gen_expert_{i}" for i in range(len(general_adapter_paths))]

This means that regardless of what the user specifies (or doesn’t specify) in arrow_config, the adapters will be named in this standardised format inside create_arrow_model().

Screenshot 2025-08-11 at 5 54 52 PM

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would still be worthwhile to have a sanity check in case people instantiate ArrowConfig by themselves and to supply ts_names.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, agree.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, I think I misunderstood. I was under the impression that we instantiate the ArrowConfig in create_arrow_model fully after knowing the adapter paths and therefore in both cases (1: user instantiates arrow config with adapter names, 2: create_arrow_config loads adapters into model and creates arrow config with adapter names) the adapter names must NOT be None.

Now I understood that create_arrow_model takes the arrow config and modifies it subsequently after loading the adapters into the model. I'm fine with that but then I'd opt to revert what I suggested and remove the post init checks so that users can instantiate the arrow config without using create_arrow_model. Sorry for that.

A side note: Reading the code with the naming scheme in place, I'm not sure if I like calling generic adapters 'experts' (gen_expert_[0-9]+). Isn't a generic adapter the opposite of an expert? Maybe it'd make sense to call the adapters explicitly task_ and gks_. To make such changes minimal and avoid typos, let's define such strings as global constants (e.g., TASK_ADAPTER_PREFIX = "task_"). WDYT?

@TheTahaaa
Copy link
Author

Nice! This is already looking good, I think that this design works quite well.

Lots of smaller nit picks which need to be resolved first before doing a more thorough review. One bigger point is the precomputation of prototypes / gks values which disallows adding / removing adapters. It also prevents us from switching to an expert adapter, do fine-tuning and switching back to the router (the prototypes would be outdated). While I think that supporting adding / removing adapters is a smaller change, allowing for fine-tuning while arrow layers are present is a bit more complicated. I expect GKS to be OK with this since we only subtract the weights once at the beginning and could do fine-tuning afterwards. I'm not sure if it is possible to solve this without removing prototype computation. Maybe we should make prototype computation optional. WDYT?

The precomputation of prototypes (and the averaging of general experts for GKS) is lazy—it happens in the forward() of ArrowLoraLinearLayer only the first time the "arrow_router" adapter is active.

This means fine-tuning with multiple experts is fully supported as long as "arrow_router" is not active during training. You can activate other adapters, add/remove experts, fine-tune as needed, and then switch back to "arrow_router". At that moment, prototypes (and GKS values) will be computed based on the currently loaded experts.

However, in the case of adding or removing LoRA modules to/from the arrow_model, we need to ensure that the task-specific and general LoRA module lists are kept in sync. This can be handled via a utility function (similar to .update_layer()), which would reorganise and refresh the module references before the forward pass when "arrow_router" is active

@githubnemo
Copy link
Collaborator

This means fine-tuning with multiple experts is fully supported as long as "arrow_router" is not active during training. You can activate other adapters, add/remove experts, fine-tune as needed, and then switch back to "arrow_router". At that moment, prototypes (and GKS values) will be computed based on the currently loaded experts.

I think there are also training scenarios that won't work. For example if I want to test the setup's performance end-to-end during training. Let's say I have a tool usage benchmark (e.g., a multi-turn conversation with intermittent tool calls) and I have several adapters, one for each tool. It would be hard to separate the individual steps of the benchmark for each tool so I have to test the whole arrow + adapter setup. This means that I have to train a single adapter, return to the arrow adapter, evaluate, return to the single adapter, etc. This would inherently be incompatible with caching the prototypes.

However, in the case of adding or removing LoRA modules to/from the arrow_model, we need to ensure that the task-specific and general LoRA module lists are kept in sync. This can be handled via a utility function (similar to .update_layer()), which would reorganise and refresh the module references before the forward pass when "arrow_router" is active

Yes, something like update_layer would work. We don't support this in variants right now but I think adding support shouldn't be doable. For now, let's just add an arrow-specific exception to update_layers, something like

def update_layer(self, adapter_name, ...):
    [...]

    if hasattr(self, "lora_arrow"):
        for adapter in self.lora_variant:
            if adapter in self.lora_arrow:
                self.lora_arrow[adapter].on_adapter_change()

@TheTahaaa
Copy link
Author

TheTahaaa commented Aug 11, 2025

However, in the case of adding or removing LoRA modules to/from the arrow_model, we need to ensure that the task-specific and general LoRA module lists are kept in sync. This can be handled via a utility function (similar to .update_layer()), which would reorganise and refresh the module references before the forward pass when "arrow_router" is active

Yes, something like update_layer would work. We don't support this in variants right now but I think adding support shouldn't be doable. For now, let's just add an arrow-specific exception to update_layers, something like

def update_layer(self, adapter_name, ...):
    [...]

    if hasattr(self, "lora_arrow"):
        for adapter in self.lora_variant:
            if adapter in self.lora_arrow:
                self.lora_arrow[adapter].on_adapter_change()

Should I go ahead and implement on_adapter_change() now, or just raise a ValueError for the time being?

The point is that if an adapter is loaded or removed from the model, the user should explicitly indicate which adapters are task-specific and which are general adapters used in GKS.

One possible approach is to add an attribute in ArrowLoraLinearLayer, e.g., last_adapter_nums, initialised with the number of task-specific and general LoRAs. Then, when on_adapter_change() is called, we can easily check whether the loaded LoRAs have changed.

That said, I think that whenever a user adds or removes a LoRA from the arrow_model, they should also indicate the final set of task-specific and general LoRAs so we can properly organise things.

What do you think?

@githubnemo
Copy link
Collaborator

However, in the case of adding or removing LoRA modules to/from the arrow_model, we need to ensure that the task-specific and general LoRA module lists are kept in sync. This can be handled via a utility function (similar to .update_layer()), which would reorganise and refresh the module references before the forward pass when "arrow_router" is active

Yes, something like update_layer would work. We don't support this in variants right now but I think adding support shouldn't be doable. For now, let's just add an arrow-specific exception to update_layers, something like

def update_layer(self, adapter_name, ...):
    [...]

    if hasattr(self, "lora_arrow"):
        for adapter in self.lora_variant:
            if adapter in self.lora_arrow:
                self.lora_arrow[adapter].on_adapter_change()

Should I go ahead and implement on_adapter_change() now, or just raise a ValueError for the time being?

The point is that if an adapter is loaded or removed from the model, the user should explicitly indicate which adapters are task-specific and which are general adapters used in GKS.

One possible approach is to add an attribute in ArrowLoraLinearLayer, e.g., last_adapter_nums, initialised with the number of task-specific and general LoRAs. Then, when on_adapter_change() is called, we can easily check whether the loaded LoRAs have changed.

That said, I think that whenever a user adds or removes a LoRA from the arrow_model, they should also indicate the final set of task-specific and general LoRAs so we can properly organise things.

What do you think?

Yeah, let's go ahead and implement on_adapter_change. I think we can assume the general knowledge adapters to be fixed which is a reasonable assumption since they are by definition task-unspecific and therefore unlikely to be trained or swapped in such a setup. Therefore I don't think we'll need to worry about the user discerning between task-specific/general and assume that every adapter added to the model is task-specific. So .on_adapter_change would only need to do GKS and computing prototypes.

@TheTahaaa
Copy link
Author

TheTahaaa commented Aug 12, 2025

However, in the case of adding or removing LoRA modules to/from the arrow_model, we need to ensure that the task-specific and general LoRA module lists are kept in sync. This can be handled via a utility function (similar to .update_layer()), which would reorganise and refresh the module references before the forward pass when "arrow_router" is active

Yes, something like update_layer would work. We don't support this in variants right now but I think adding support shouldn't be doable. For now, let's just add an arrow-specific exception to update_layers, something like

def update_layer(self, adapter_name, ...):
    [...]

    if hasattr(self, "lora_arrow"):
        for adapter in self.lora_variant:
            if adapter in self.lora_arrow:
                self.lora_arrow[adapter].on_adapter_change()

Should I go ahead and implement on_adapter_change() now, or just raise a ValueError for the time being?
The point is that if an adapter is loaded or removed from the model, the user should explicitly indicate which adapters are task-specific and which are general adapters used in GKS.
One possible approach is to add an attribute in ArrowLoraLinearLayer, e.g., last_adapter_nums, initialised with the number of task-specific and general LoRAs. Then, when on_adapter_change() is called, we can easily check whether the loaded LoRAs have changed.
That said, I think that whenever a user adds or removes a LoRA from the arrow_model, they should also indicate the final set of task-specific and general LoRAs so we can properly organise things.
What do you think?

Yeah, let's go ahead and implement on_adapter_change. I think we can assume the general knowledge adapters to be fixed which is a reasonable assumption since they are by definition task-unspecific and therefore unlikely to be trained or swapped in such a setup. Therefore I don't think we'll need to worry about the user discerning between task-specific/general and assume that every adapter added to the model is task-specific. So .on_adapter_change would only need to do GKS and computing prototypes.

✅ Implemented!

on_adapter_change() now checks all adapters currently loaded in the model, ensuring the model recomputes gks and prototyp once in the forward pass.

@githubnemo
Copy link
Collaborator

Give me a ping when I can review this again. Regarding review procedure it would be best to keep the comments open instead of marking them as resolved.

Make sure to merge main into your branch regularly to keep the PR mergable! :)

@TheTahaaa
Copy link
Author

TheTahaaa commented Aug 15, 2025

Hey @githubnemo ,

I’ve merged the latest main branch into my PR and re-ran all the tests — all 7 test cases passed successfully. ✅

There are still some style/formatting issues coming from the merged main branch (see screenshot below), so I haven’t run make quality on the merged branch yet:

Screenshot 2025-08-15 at 10 55 14 AM

I had also marked all previously open comments as resolved — just so I can address them systematically one by one. They are all marked as unresolved again now.

The PR is now ready for review!

@TheTahaaa
Copy link
Author

Hey @githubnemo ,

I’ve updated the prototype computation logic and merged the latest changes from main.
It’s ready for your review! 😁

@TheTahaaa
Copy link
Author

TheTahaaa commented Aug 17, 2025

A minor while important update in variants/create_arrow_model() ensures that all loaded adapters are first checked for consistency in their r and alpha values. Once validated, these shared values are then applied to the arrow_router.

@githubnemo
Copy link
Collaborator

I think there was an issue with rebasing and formatting. src/peft/tuners/__init__.py is missing the miss and shira changes (there are other places as well where changes from main are reverted, this was just an example). There seems to be a rogue auto-formatter going around as well (there are lots of changes in method_comparison/ for example).

There are still some style/formatting issues coming from the merged main branch (see screenshot below), so I haven’t run make quality on the merged branch yet:

We've recently updated the ruff version, maybe yours is now outdated. pip install -e '.[quality] should update the quality checkers and hopefully resolve the issue.

Could you go over the PR once and try to make sure that the PR aligns with main and doesn't contain irrelevant changes (e.g., auto-formatting of unrelated things). That would help quite a lot, thanks!

@TheTahaaa
Copy link
Author

TheTahaaa commented Aug 19, 2025

Hey @githubnemo ,

Rebased on the latest main, removed unrelated changes, restored upstream entries (MiSS/SHiRA), and updated quality hooks.
The PR now contains a single tidy commit touching only:

src/peft/__init__.py
src/peft/tuners/__init__.py
src/peft/tuners/lora/{__init__,arrow,bnb,config,layer,model,variants}.py
tests/test_arrow.py

Ready for review ✅

Copy link
Collaborator

@githubnemo githubnemo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for dealing with the rebase, everything looks fine now!

There are still some aspects that deserve testing like the hyper-param compatibility check and possibly GPU/bnb related code but in general it looks quite good already.

Since the implementation seems reaching maturity I'd suggest that we focus on adding documentation and examples. You can document the new method in a new section in docs/source/developer_guides/lora.md explaining what the method does, that it can be enhanced with GKS and provide brief code examples, similar to the rest of the document. Providing a bit of insight how to manually instantiate the arrow layer for advanced users is a plus. This should include that the naming scheme for loaded adapters is documented (in the code as well) since that is now a major part of how arrow ingests adapters.

In your original post of this PR you mentioned that you already have some examples for common benchmark datasets. Do you think you would be able to contribute an example from one of them? Normally I'd suggest adding the method to the MetaMathQA benchmark suite in PEFT but Arrow (+GKS) make a lot of assumptions that are not true for other methods so I'd rather skip that at this time. Nevertheless, I think having an example that showcases the benefit of Arrow and Arrow + GKS over baseline would be great.

Comment on lines 65 to 73
def resolve_lora_variant(self, *, use_dora: bool, use_arrow: bool, **kwargs) -> Optional[LoraVariant]:
if use_arrow:
# arrow variant must be imported so it registers itself
from .variants import ArrowLinearVariant

return ArrowLinearVariant()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you tested forward as well? Maybe it makes sense to add a test to test_gpu_examples.py for testing the bnb integration properly. I think class TestEvaInitializationGPU is a good example on how to do this.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@TheTahaaa
Copy link
Author

TheTahaaa commented Aug 22, 2025

There are still some aspects that deserve testing like the hyper-param compatibility check and possibly GPU/bnb related code but in general it looks quite good already.

Regarding the bnb concern, I did test both Phi2 and Phi3-mini with 4-bit quantisation on a few datasets like Arc-easy, Arc-challenge, HellaSwag; they work absolutely fine!

Since the implementation seems reaching maturity I'd suggest that we focus on adding documentation and examples. You can document the new method in a new section in docs/source/developer_guides/lora.md explaining what the method does, that it can be enhanced with GKS and provide brief code examples, similar to the rest of the document. Providing a bit of insight how to manually instantiate the arrow layer for advanced users is a plus. This should include that the naming scheme for loaded adapters is documented (in the code as well) since that is now a major part of how arrow ingests adapters.

Sure!

In your original post of this PR you mentioned that you already have some examples for common benchmark datasets. Do you think you would be able to contribute an example from one of them? Normally I'd suggest adding the method to the MetaMathQA benchmark suite in PEFT but Arrow (+GKS) make a lot of assumptions that are not true for other methods so I'd rather skip that at this time. Nevertheless, I think having an example that showcases the benefit of Arrow and Arrow + GKS over baseline would be great.

Yes, I can. I’ll run some experiments on PIQA, BoolQ, ARC-Easy, etc. to show how GKS and Arrow improve over the base model, using the task adapters trained on the Flan dataset. Should I add this example to docs/source/developer_guides/lora.md, or examples directory?

@TheTahaaa
Copy link
Author

TheTahaaa commented Aug 25, 2025

I’ve added another test to cover the scenario of adding a new adapter, running a forward pass (to build prototype and apply gks), then activating the new adapter and calling forward again (to simulate training/inference with a newly loaded expert), and finally re-activating the arrow_router. The output was identical to the case where all experts were loaded from the start ✅.

In addition, I updated the documentation in docs/source/developer_guides/lora.md. I also prepared a complete code snippet to evaluate the performance of Arrow and GKS against the base model on PIQA, BoolQ, ARC-Easy, etc. I’m just waiting for your confirmation on where to include it — I think the examples directory would be the right place.

Copy link
Collaborator

@githubnemo githubnemo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nicely done! Only some minor comments left.

I think the examples directory would be the right place.

I agree, something like examples/arrow_multitask or something akin to the task you have in mind.

Could you also check the CI? There seem to be a few errors due to some LoRA layers still missing the arrow_config parameter in their __init__ and update_layer methods.

Comment on lines +37 to +44
----------------------------------------------------
(qkv_proj): lora.Linear4bit or lora.Linear(
(base_layer): Linear4bit or Linear (lora_dropout): ModuleDict( ... ) (lora_A): ModuleDict( ... )
(lora_B): ModuleDict( ... ) (lora_embedding_A): ParameterDict( ... ) (lora_embedding_B): ParameterDict(
... ) (lora_magnitude_vector): ModuleDict( ... ) (lora_arrow): ModuleDict(
(arrow_router): ArrowLoraLinearLayer() )
)
----------------------------------------------------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming looks good!

This way, users can’t pass values for them directly; they are always set inside create_arrow_model().

I revised my opinion about this here and I still think that users should have the option to set these config options without calling create_arrow_model (it's for safety & convenience, not because it is not possible otherwise). If you don't object let's remove the checks in post_init otherwise I'm fine keeping it that way as well.

Right now I’ve placed them inside create_arrow_model(), but I’m open to suggestions on the best location (e.g., config.py or utils/constants.py).

I think keeping them at top-level in arrow.py is OK. I'd interpret these values not as part of the public API but users can (on their own risk) import these if they want to build something that is out of scope for create_arrow_model.

model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
base_model_2 = AutoModelForCausalLM.from_pretrained(model_id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
base_model_2 = AutoModelForCausalLM.from_pretrained(model_id)
base_model_2 = copy.deepcopy(base_model_1)

Otherwise the second from_pretrained would not be cached if the model was not seen before. This also avoids the from_pretrained overhead.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, Done!

model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
base_model_2 = AutoModelForCausalLM.from_pretrained(model_id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
base_model_2 = AutoModelForCausalLM.from_pretrained(model_id)
base_model_2 = copy.deepcopy(base_model_1)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 322 to 325
# Find a Conv2d module name in the ResNet test model
probe = AutoModelForImageClassification.from_pretrained(model_id)
conv_names = [n.split(".")[-1] for n, m in probe.named_modules() if isinstance(m, torch.nn.Conv2d)]
assert len(conv_names) > 0, "No Conv2d modules found in the ResNet test model."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Find a Conv2d module name in the ResNet test model
probe = AutoModelForImageClassification.from_pretrained(model_id)
conv_names = [n.split(".")[-1] for n, m in probe.named_modules() if isinstance(m, torch.nn.Conv2d)]
assert len(conv_names) > 0, "No Conv2d modules found in the ResNet test model."

Not necessary anymore I think

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep!

Comment on lines 360 to 361
# Create base in fp16 (no manual assignment to .dtype)
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Create base in fp16 (no manual assignment to .dtype)
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
# Create base in fp16 (no manual assignment to .dtype)
with hub_online_once(model_id):
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I had overlooked that part.

Comment on lines 658 to 661
top_k = 3,
router_temperature = 1.0,
use_gks = True,
rng_seed = 42,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
top_k = 3,
router_temperature = 1.0,
use_gks = True,
rng_seed = 42,
top_k=3,
router_temperature=1.0,
use_gks=True,
rng_seed=42,

Let's stay with PEP8 formatting in the examples as well (no whitespace around = when assigning named parameters). This is done several times in this document.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're all corrected.

Comment on lines 670 to 678
def create_arrow_model(
base_model: PreTrainedModel,
task_specific_adapter_paths: list[str], # path of task-specific LoRAs
ts_repo_id: str,
arrow_config: ArrowConfig,
general_adapter_paths: list[str] | None = None, # path to the trained general-knowledge LoRAs
gen_repo_id: str | None = None,
**adapter_kwargs: Any,
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reading the well-written examples, I wonder: Wouldn't it be more user-friendly to just specify task_specific_adapter_paths fully qualified and remove the _repo_id parameters? I.e.:

create_arrow_model(
    base_model=base_model,
    task_specific_adapter_paths=[
        "TahaBa/phi3-mini-clustered-flan/ts_expert_0",
        "TahaBa/phi3-mini-clustered-flan/ts_expert_1",
        ...
    ],
    [...]
)

We could change get_subfolder_from_path to something like
this:

def split_model_path(path):
    if os.path.exists(path):
        return path, None
    else:
        split_path = path.split("/")
        model_id = "/".join(split_path[0:2])
        subfolder = "/".join(split_path[2:])
        return model_id, subfolder

# usage
model_id, subfolder = split_model_path(path)

PeftModel.from_pretrained(
    model=base_model,
    model_id=model_id,
    subfolder=subfolder,
)

WDYT?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, I found the current behavior a bit unintuitive. The main challenge was supporting three different cases for adapter loading:

  1. Locally saved adapters (path directly contains adapter_config.json)
  2. Adapters stored on the Hub inside a subfolder of a repo
  3. Adapters published as standalone repos (config at the repo root)

Your suggested approach only worked smoothly for Hub repos with subfolders, but not for local paths or adapters that are repos themselves.

To make this more user-friendly, I added a small helper inside create_arrow_model that transparently handles all three cases. With this in place, the user just needs to pass task_specific_adapter_paths and general_adapter_paths to create_arrow_model, and everything else is resolved internally.

Comment on lines +711 to +727
# Quantisation config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=False,
)

# Loading the model
base_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config,
)

# Now call create_arrow_model() as we explained before.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add this as a test case in test_gpu_examples.py with a small model like facebook/opt-125m and randomly initialized, newly created adapters. Just so that we can see that the model loads and generate works.

Copy link
Author

@TheTahaaa TheTahaaa Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added it to the end of the test_gpu_examples.py. ✅

Comment on lines 1829 to 1866

Note:

`lora_config.target_parameters` Note:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Undo formatting

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 142 to 144
def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove __init__

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep it is the reason for "LoraModel.init() got an unexpected keyword argument 'state_dict'" error in CI.

@TheTahaaa
Copy link
Author

TheTahaaa commented Aug 27, 2025

All review suggestions have been addressed, and the new arrow-multitask script has been added to the examples directory. Waiting for your review ✅

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants