diff --git a/src/refiners/fluxion/layers/chain.py b/src/refiners/fluxion/layers/chain.py index 0cb5f1208..c73a60612 100644 --- a/src/refiners/fluxion/layers/chain.py +++ b/src/refiners/fluxion/layers/chain.py @@ -609,6 +609,7 @@ def replace( new_module._set_parent(self) if isinstance(old_module, ContextModule): old_module._set_parent(old_module_parent) + self._register_provider() def structural_copy(self: TChain) -> TChain: """Copy the structure of the Chain tree. diff --git a/tests/adapters/test_adapter_context.py b/tests/adapters/test_adapter_context.py new file mode 100644 index 000000000..447e7bca2 --- /dev/null +++ b/tests/adapters/test_adapter_context.py @@ -0,0 +1,25 @@ +import refiners.fluxion.layers as fl +from refiners.fluxion.adapters.adapter import Adapter +from refiners.fluxion.context import Contexts + + +class ContextAdapter(fl.Chain, Adapter[fl.Chain]): + def __init__(self, target: fl.Chain): + with self.setup_adapter(target): + super().__init__( + fl.Lambda(lambda: 42), + fl.SetContext("foo", "bar"), + ) + + +class ContextChain(fl.Chain): + def init_context(self) -> Contexts: + return {"foo": {"bar": None}} + + +def test_adapter_can_access_parent_context(): + chain = ContextChain(fl.Chain(), fl.UseContext("foo", "bar")) + adaptee = chain.layer("Chain", fl.Chain) + ContextAdapter(adaptee).inject(chain) + + assert chain() == 42