29
29
< link href ="../../../_static/styles/theme.css?digest=8878045cc6db502f8baf " rel ="stylesheet " />
30
30
< link href ="../../../_static/styles/pydata-sphinx-theme.css?digest=8878045cc6db502f8baf " rel ="stylesheet " />
31
31
32
- < link rel ="stylesheet " type ="text/css " href ="../../../_static/pygments.css?v=a746c00c " />
32
+ < link rel ="stylesheet " type ="text/css " href ="../../../_static/pygments.css?v=8f2a1f02 " />
33
33
< link rel ="stylesheet " type ="text/css " href ="../../../_static/copybutton.css?v=76b2166b " />
34
34
< link rel ="stylesheet " type ="text/css " href ="../../../_static/sg_gallery.css?v=d2d258e8 " />
35
35
364
364
365
365
< h1 > Source code for topomodelx.base.conv</ h1 > < div class ="highlight "> < pre >
366
366
< span > </ span > < span class ="sd "> """Convolutional layer for message passing."""</ span >
367
- < span class ="kn "> from</ span > < span class ="nn "> typing</ span > < span class ="kn "> import</ span > < span class ="n "> Literal</ span >
367
+ < span class ="kn "> from</ span > < span class =" w " > </ span > < span class ="nn "> typing</ span > < span class =" w " > </ span > < span class ="kn "> import</ span > < span class ="n "> Literal</ span >
368
368
369
- < span class ="kn "> import</ span > < span class ="nn "> torch</ span >
370
- < span class ="kn "> from</ span > < span class ="nn "> torch.nn.parameter</ span > < span class ="kn "> import</ span > < span class ="n "> Parameter</ span >
369
+ < span class ="kn "> import</ span > < span class =" w " > </ span > < span class ="nn "> torch</ span >
370
+ < span class ="kn "> from</ span > < span class =" w " > </ span > < span class ="nn "> torch.nn.parameter</ span > < span class =" w " > </ span > < span class ="kn "> import</ span > < span class ="n "> Parameter</ span >
371
371
372
- < span class ="kn "> from</ span > < span class ="nn "> topomodelx.base.message_passing</ span > < span class ="kn "> import</ span > < span class ="n "> MessagePassing</ span >
372
+ < span class ="kn "> from</ span > < span class =" w " > </ span > < span class ="nn "> topomodelx.base.message_passing</ span > < span class =" w " > </ span > < span class ="kn "> import</ span > < span class ="n "> MessagePassing</ span >
373
373
374
374
375
375
< div class ="viewcode-block " id ="Conv ">
376
376
< a class ="viewcode-back " href ="../../../api/base/conv.html#topomodelx.base.conv.Conv "> [docs]</ a >
377
- < span class ="k "> class</ span > < span class ="nc "> Conv</ span > < span class ="p "> (</ span > < span class ="n "> MessagePassing</ span > < span class ="p "> ):</ span >
377
+ < span class ="k "> class</ span > < span class =" w " > </ span > < span class ="nc "> Conv</ span > < span class ="p "> (</ span > < span class ="n "> MessagePassing</ span > < span class ="p "> ):</ span >
378
378
< span class ="w "> </ span > < span class ="sd "> """Message passing: steps 1, 2, and 3.</ span >
379
379
380
380
< span class ="sd "> Builds the message passing route given by one neighborhood matrix.</ span >
@@ -401,7 +401,7 @@ <h1>Source code for topomodelx.base.conv</h1><div class="highlight"><pre>
401
401
< span class ="sd "> NB: if `False` in_channels has to be equal to out_channels.</ span >
402
402
< span class ="sd "> """</ span >
403
403
404
- < span class ="k "> def</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span >
404
+ < span class ="k "> def</ span > < span class =" w " > </ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span >
405
405
< span class ="bp "> self</ span > < span class ="p "> ,</ span >
406
406
< span class ="n "> in_channels</ span > < span class ="p "> ,</ span >
407
407
< span class ="n "> out_channels</ span > < span class ="p "> ,</ span >
@@ -443,7 +443,7 @@ <h1>Source code for topomodelx.base.conv</h1><div class="highlight"><pre>
443
443
444
444
< div class ="viewcode-block " id ="Conv.update ">
445
445
< a class ="viewcode-back " href ="../../../api/base/conv.html#topomodelx.base.conv.Conv.update "> [docs]</ a >
446
- < span class ="k "> def</ span > < span class ="nf "> update</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> x_message_on_target</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> :</ span >
446
+ < span class ="k "> def</ span > < span class =" w " > </ span > < span class ="nf "> update</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> x_message_on_target</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> :</ span >
447
447
< span class ="w "> </ span > < span class ="sd "> """Update embeddings on each cell (step 4).</ span >
448
448
449
449
< span class ="sd "> Parameters</ span >
@@ -465,7 +465,7 @@ <h1>Source code for topomodelx.base.conv</h1><div class="highlight"><pre>
465
465
466
466
< div class ="viewcode-block " id ="Conv.forward ">
467
467
< a class ="viewcode-back " href ="../../../api/base/conv.html#topomodelx.base.conv.Conv.forward "> [docs]</ a >
468
- < span class ="k "> def</ span > < span class ="nf "> forward</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> x_source</ span > < span class ="p "> ,</ span > < span class ="n "> neighborhood</ span > < span class ="p "> ,</ span > < span class ="n "> x_target</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> :</ span >
468
+ < span class ="k "> def</ span > < span class =" w " > </ span > < span class ="nf "> forward</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> x_source</ span > < span class ="p "> ,</ span > < span class ="n "> neighborhood</ span > < span class ="p "> ,</ span > < span class ="n "> x_target</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> :</ span >
469
469
< span class ="w "> </ span > < span class ="sd "> """Forward pass.</ span >
470
470
471
471
< span class ="sd "> This implements message passing:</ span >
0 commit comments