1
1
#include " network/activation.h"
2
+ #include " network/direction.h"
3
+ #include " network/input.h"
4
+ #include " network/output.h"
5
+ #include " network/activation_definition.h"
2
6
3
7
const std::function<bool (Activation*, Activation*)> Activation::ID_COMPARATOR = [](Activation* a1, Activation* a2) {
4
8
return a1->getId () < a2->getId ();
5
9
};
6
10
7
11
Activation::Activation (ActivationDefinition* t, Activation* parent, int id, Neuron* n, Document* doc, std::map<BSType, BindingSignal*> bindingSignals)
8
- : id(id), neuron(n), doc(doc), bindingSignals(bindingSignals), parent(parent), created(Timestamp::NOT_SET), fired(Timestamp::NOT_SET), firedStep(new Fired(this )) {
12
+ : Obj(t), id(id), neuron(n), doc(doc), bindingSignals(bindingSignals), parent(parent), created(Timestamp::NOT_SET), fired(Timestamp::NOT_SET), firedStep(new Fired(this )) {
9
13
doc->addActivation (this );
10
14
neuron->updateLastUsed (doc->getId ());
11
15
setCreated (doc->getCurrentTimestamp ());
@@ -15,6 +19,35 @@ Activation::~Activation() {
15
19
delete firedStep;
16
20
}
17
21
22
+ RelatedObjectIterable* Activation::followManyRelation (Relation* rel) const {
23
+ // Create a custom iterable for each relation type
24
+ if (rel->getRelationName () == " INPUT" ) {
25
+ // Since getInputLinks() is pure virtual, derived classes should implement specialized behavior
26
+ // This base implementation handles the common case for OUTPUT relations
27
+ return nullptr ;
28
+ } else if (rel->getRelationName () == " OUTPUT" ) {
29
+ // Convert getOutputLinks() vector to an iterable
30
+ std::vector<Link*> links = const_cast <Activation*>(this )->getOutputLinks ();
31
+ std::vector<Obj*> objs;
32
+ for (Link* link : links) {
33
+ objs.push_back (static_cast <Obj*>(link ));
34
+ }
35
+ return new VectorObjectIterable (objs);
36
+ } else {
37
+ throw std::runtime_error (" Invalid Relation: " + rel->getRelationName ());
38
+ }
39
+ }
40
+
41
+ Obj* Activation::followSingleRelation (const Relation* rel) {
42
+ if (rel->getRelationName () == " SELF" ) {
43
+ return this ;
44
+ } else if (rel->getRelationName () == " NEURON" ) {
45
+ return neuron;
46
+ } else {
47
+ throw std::runtime_error (" Invalid Relation" );
48
+ }
49
+ }
50
+
18
51
ActivationKey Activation::getKey () {
19
52
return ActivationKey (neuron->getId (), id);
20
53
}
0 commit comments