Skip to content

Commit ae5ef4f

Browse files
committed
make sure adapters can access the context of their parent chain
#456
1 parent 590648e commit ae5ef4f

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

src/refiners/fluxion/layers/chain.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ def replace(
609609
new_module._set_parent(self)
610610
if isinstance(old_module, ContextModule):
611611
old_module._set_parent(old_module_parent)
612+
self._register_provider()
612613

613614
def structural_copy(self: TChain) -> TChain:
614615
"""Copy the structure of the Chain tree.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import refiners.fluxion.layers as fl
2+
from refiners.fluxion.adapters.adapter import Adapter
3+
from refiners.fluxion.context import Contexts
4+
5+
6+
class ContextAdapter(fl.Chain, Adapter[fl.Chain]):
7+
def __init__(self, target: fl.Chain):
8+
with self.setup_adapter(target):
9+
super().__init__(
10+
fl.Lambda(lambda: 42),
11+
fl.SetContext("foo", "bar"),
12+
)
13+
14+
15+
class ContextChain(fl.Chain):
16+
def init_context(self) -> Contexts:
17+
return {"foo": {"bar": None}}
18+
19+
20+
def test_adapter_can_access_parent_context():
21+
chain = ContextChain(fl.Chain(), fl.UseContext("foo", "bar"))
22+
adaptee = chain.layer("Chain", fl.Chain)
23+
ContextAdapter(adaptee).inject(chain)
24+
25+
assert chain() == 42

0 commit comments

Comments
 (0)