Skip to content

Commit 41bab99

Browse files
Implemented followManyRelation and followSingleRelation
1 parent a94387a commit 41bab99

18 files changed

+226
-20
lines changed

include/fields/rel_obj_iterator.h

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef REL_OBJ_ITERATOR_H
22
#define REL_OBJ_ITERATOR_H
33

4+
#include <vector>
5+
46
class Obj;
57

68
class RelatedObjectIterator {
@@ -31,6 +33,27 @@ class SingleObjectIterator : public RelatedObjectIterator {
3133
}
3234
};
3335

36+
class VectorObjectIterator : public RelatedObjectIterator {
37+
std::vector<Obj*> _objects;
38+
size_t _currentIndex;
39+
40+
public:
41+
VectorObjectIterator(const std::vector<Obj*>& objects)
42+
: _objects(objects), _currentIndex(0) {}
43+
44+
~VectorObjectIterator() = default;
45+
46+
bool hasNext() override {
47+
return _currentIndex < _objects.size();
48+
}
49+
50+
Obj* next() override {
51+
if (!hasNext()) {
52+
return nullptr;
53+
}
54+
return _objects[_currentIndex++];
55+
}
56+
};
3457

3558
template <typename MapType>
3659
class MapRelObjIterator : public RelatedObjectIterator {
@@ -79,9 +102,20 @@ class SingleObjectIterable : public RelatedObjectIterable {
79102
public:
80103
SingleObjectIterable(Obj* obj) : _obj(obj) {}
81104

82-
SingleObjectIterator* iterator() const {
105+
SingleObjectIterator* iterator() const override {
83106
return new SingleObjectIterator(_obj);
84107
}
85108
};
86109

110+
class VectorObjectIterable : public RelatedObjectIterable {
111+
std::vector<Obj*> _objects;
112+
113+
public:
114+
VectorObjectIterable(const std::vector<Obj*>& objects) : _objects(objects) {}
115+
116+
VectorObjectIterator* iterator() const override {
117+
return new VectorObjectIterator(_objects);
118+
}
119+
};
120+
87121
#endif //REL_OBJ_ITERATOR_H

include/network/activation.h

+4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ class Activation : public Obj, public Element, public ModelProvider {
2626
Activation(ActivationDefinition* t, Activation* parent, int id, Neuron* n, Document* doc, std::map<BSType, BindingSignal*> bindingSignals);
2727
virtual ~Activation();
2828

29+
// Implementation of Obj virtual methods
30+
RelatedObjectIterable* followManyRelation(Relation* rel) const override;
31+
Obj* followSingleRelation(const Relation* rel) override;
32+
2933
ActivationKey getKey();
3034
Activation* getParent();
3135
void addOutputLink(Link* l);

include/network/conjunctive_activation.h

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ class ConjunctiveActivation : public Activation {
99
ConjunctiveActivation(ActivationDefinition* t, Activation* parent, int id, Neuron* n, Document* doc, std::map<BSType, BindingSignal*> bindingSignals);
1010
virtual ~ConjunctiveActivation();
1111

12+
RelatedObjectIterable* followManyRelation(Relation* rel) const override;
13+
1214
void linkIncoming(Activation* excludedInputAct) override;
1315
void addInputLink(Link* l) override;
1416
std::vector<Link*> getInputLinks() override;

include/network/conjunctive_synapse.h

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ class ConjunctiveSynapse : public Synapse {
88
ConjunctiveSynapse(SynapseDefinition* type);
99
ConjunctiveSynapse(SynapseDefinition* type, Neuron* input, Neuron* output);
1010

11+
RelatedObjectIterable* followManyRelation(Relation* rel) const override;
12+
Obj* followSingleRelation(const Relation* rel) override;
13+
1114
void write(DataOutput* out) override;
1215
void readFields(DataInput* in, TypeRegistry* tr) override;
1316

include/network/disjunctive_activation.h

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ class DisjunctiveActivation : public Activation {
99
DisjunctiveActivation(ActivationDefinition* t, Activation* parent, int id, Neuron* n, Document* doc, std::map<BSType, BindingSignal*> bindingSignals);
1010
virtual ~DisjunctiveActivation();
1111

12+
RelatedObjectIterable* followManyRelation(Relation* rel) const override;
13+
1214
void linkIncoming(Activation* excludedInputAct) override;
1315
void addInputLink(Link* l) override;
1416
std::vector<Link*> getInputLinks() override;

include/network/disjunctive_synapse.h

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ class DisjunctiveSynapse : public Synapse {
88
DisjunctiveSynapse(SynapseDefinition* type);
99
DisjunctiveSynapse(SynapseDefinition* type, Neuron* input, Neuron* output);
1010

11+
RelatedObjectIterable* followManyRelation(Relation* rel) const override;
12+
Obj* followSingleRelation(const Relation* rel) override;
13+
1114
void link(Model* m) override;
1215

1316
private:

include/network/inhibitory_activation.h

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ class InhibitoryActivation : public Activation {
1010
public:
1111
InhibitoryActivation(ActivationDefinition* t, Activation* parent, int id, Neuron* n, Document* doc, std::map<BSType*, BindingSignal*> bindingSignals);
1212

13+
RelatedObjectIterable* followManyRelation(Relation* rel) const override;
14+
1315
void addInputLink(Link* l) override;
1416
Link* getInputLink(int bsId);
1517
int getInputKey(Link* l);

include/network/link.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ class Link : public Obj, public Element, public ModelProvider {
2020
public:
2121
Link(LinkDefinition* type, Synapse* s, Activation* input, Activation* output);
2222

23-
RelatedObjectIterable* followManyRelation(Relation* rel) override;
24-
Obj* followSingleRelation(Relation* rel) override;
23+
RelatedObjectIterable* followManyRelation(Relation* rel) const override;
24+
Obj* followSingleRelation(const Relation* rel) override;
2525
Timestamp getFired() override;
2626
Timestamp getCreated() override;
2727
Synapse* getSynapse();

include/network/neuron.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class Neuron : public Obj, public Element, public ModelProvider {
1919
Neuron(NeuronDefinition* type, Model* model, long id);
2020
Neuron(NeuronDefinition* type, Model* model);
2121

22-
RelatedObjectIterable* followManyRelation(Relation* rel) override;
23-
Obj* followSingleRelation(Relation* rel) override;
22+
RelatedObjectIterable* followManyRelation(Relation* rel) const override;
23+
Obj* followSingleRelation(const Relation* rel) override;
2424
long getId() const;
2525
void updatePropagable(Neuron* n, bool isPropagable);
2626
void wakeupPropagable();

include/network/synapse.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class Synapse : public Obj, public Element {
2323

2424
virtual ~Synapse() = default;
2525

26-
virtual RelatedObjectIterable* followManyRelation(Relation* rel) = 0;
27-
virtual Obj* followSingleRelation(Relation* rel) = 0;
26+
virtual RelatedObjectIterable* followManyRelation(Relation* rel) const = 0;
27+
virtual Obj* followSingleRelation(const Relation* rel) = 0;
2828

2929
int getSynapseId() const;
3030
void setSynapseId(int synapseId);

src/network/activation.cpp

+34-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
#include "network/activation.h"
2+
#include "network/direction.h"
3+
#include "network/input.h"
4+
#include "network/output.h"
5+
#include "network/activation_definition.h"
26

37
const std::function<bool(Activation*, Activation*)> Activation::ID_COMPARATOR = [](Activation* a1, Activation* a2) {
48
return a1->getId() < a2->getId();
59
};
610

711
Activation::Activation(ActivationDefinition* t, Activation* parent, int id, Neuron* n, Document* doc, std::map<BSType, BindingSignal*> bindingSignals)
8-
: id(id), neuron(n), doc(doc), bindingSignals(bindingSignals), parent(parent), created(Timestamp::NOT_SET), fired(Timestamp::NOT_SET), firedStep(new Fired(this)) {
12+
: Obj(t), id(id), neuron(n), doc(doc), bindingSignals(bindingSignals), parent(parent), created(Timestamp::NOT_SET), fired(Timestamp::NOT_SET), firedStep(new Fired(this)) {
913
doc->addActivation(this);
1014
neuron->updateLastUsed(doc->getId());
1115
setCreated(doc->getCurrentTimestamp());
@@ -15,6 +19,35 @@ Activation::~Activation() {
1519
delete firedStep;
1620
}
1721

22+
RelatedObjectIterable* Activation::followManyRelation(Relation* rel) const {
23+
// Create a custom iterable for each relation type
24+
if (rel->getRelationName() == "INPUT") {
25+
// Since getInputLinks() is pure virtual, derived classes should implement specialized behavior
26+
// This base implementation handles the common case for OUTPUT relations
27+
return nullptr;
28+
} else if (rel->getRelationName() == "OUTPUT") {
29+
// Convert getOutputLinks() vector to an iterable
30+
std::vector<Link*> links = const_cast<Activation*>(this)->getOutputLinks();
31+
std::vector<Obj*> objs;
32+
for (Link* link : links) {
33+
objs.push_back(static_cast<Obj*>(link));
34+
}
35+
return new VectorObjectIterable(objs);
36+
} else {
37+
throw std::runtime_error("Invalid Relation: " + rel->getRelationName());
38+
}
39+
}
40+
41+
Obj* Activation::followSingleRelation(const Relation* rel) {
42+
if (rel->getRelationName() == "SELF") {
43+
return this;
44+
} else if (rel->getRelationName() == "NEURON") {
45+
return neuron;
46+
} else {
47+
throw std::runtime_error("Invalid Relation");
48+
}
49+
}
50+
1851
ActivationKey Activation::getKey() {
1952
return ActivationKey(neuron->getId(), id);
2053
}

src/network/conjunctive_activation.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
#include "network/conjunctive_activation.h"
22
#include "network/synapse.h"
33
#include "network/synapse_definition.h"
4+
#include "fields/relation.h"
5+
#include "fields/rel_obj_iterator.h"
46

57
ConjunctiveActivation::ConjunctiveActivation(ActivationDefinition* t, Activation* parent, int id, Neuron* n, Document* doc, std::map<BSType, BindingSignal*> bindingSignals)
68
: Activation(t, parent, id, n, doc, bindingSignals) {}
79

810
ConjunctiveActivation::~ConjunctiveActivation() {}
911

12+
RelatedObjectIterable* ConjunctiveActivation::followManyRelation(Relation* rel) const {
13+
if (rel->getRelationName() == "INPUT") {
14+
// Convert inputLinks to a vector of Obj*
15+
std::vector<Obj*> objs;
16+
for (const auto& pair : inputLinks) {
17+
objs.push_back(static_cast<Obj*>(pair.second));
18+
}
19+
return new VectorObjectIterable(objs);
20+
} else {
21+
// Use base class implementation for other relations
22+
return Activation::followManyRelation(rel);
23+
}
24+
}
25+
1026
void ConjunctiveActivation::linkIncoming(Activation* excludedInputAct) {
1127
for (auto& s : neuron->getInputSynapsesAsStream()) {
1228
if (static_cast<SynapseDefinition*>(s->getType())->isIncomingLinkingCandidate(getBindingSignals().keySet())) {

src/network/conjunctive_synapse.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
11
#include "network/conjunctive_synapse.h"
2+
#include "fields/relation.h"
3+
#include "fields/rel_obj_iterator.h"
24

35
ConjunctiveSynapse::ConjunctiveSynapse(SynapseDefinition* type) : Synapse(type), propagable(false) {}
46

57
ConjunctiveSynapse::ConjunctiveSynapse(SynapseDefinition* type, Neuron* input, Neuron* output) : Synapse(type, input, output), propagable(false) {}
68

9+
RelatedObjectIterable* ConjunctiveSynapse::followManyRelation(Relation* rel) const {
10+
// Typically synapses don't have "many" relationships
11+
throw std::runtime_error("Invalid Relation for ConjunctiveSynapse: " + rel->getRelationName());
12+
}
13+
14+
Obj* ConjunctiveSynapse::followSingleRelation(const Relation* rel) {
15+
if (rel->getRelationName() == "SELF") {
16+
return this;
17+
} else if (rel->getRelationName() == "INPUT") {
18+
return getInput();
19+
} else if (rel->getRelationName() == "OUTPUT") {
20+
return getOutput();
21+
} else {
22+
throw std::runtime_error("Invalid Relation for ConjunctiveSynapse: " + rel->getRelationName());
23+
}
24+
}
25+
726
void ConjunctiveSynapse::write(DataOutput* out) {
827
Synapse::write(out);
928
out->writeBoolean(propagable);

src/network/disjunctive_activation.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
11
#include "network/disjunctive_activation.h"
2+
#include "fields/relation.h"
3+
#include "fields/rel_obj_iterator.h"
24

35
DisjunctiveActivation::DisjunctiveActivation(ActivationDefinition* t, Activation* parent, int id, Neuron* n, Document* doc, std::map<BSType, BindingSignal*> bindingSignals)
46
: Activation(t, parent, id, n, doc, bindingSignals) {}
57

68
DisjunctiveActivation::~DisjunctiveActivation() {}
79

10+
RelatedObjectIterable* DisjunctiveActivation::followManyRelation(Relation* rel) const {
11+
if (rel->getRelationName() == "INPUT") {
12+
// Convert inputLinks to a vector of Obj*
13+
std::vector<Obj*> objs;
14+
for (const auto& pair : inputLinks) {
15+
objs.push_back(static_cast<Obj*>(pair.second));
16+
}
17+
return new VectorObjectIterable(objs);
18+
} else {
19+
// Use base class implementation for other relations
20+
return Activation::followManyRelation(rel);
21+
}
22+
}
23+
824
void DisjunctiveActivation::linkIncoming(Activation* excludedInputAct) {
925
// Implementation for linking incoming activations
1026
}

src/network/disjunctive_synapse.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
11
#include "network/disjunctive_synapse.h"
2+
#include "fields/relation.h"
3+
#include "fields/rel_obj_iterator.h"
24

35
DisjunctiveSynapse::DisjunctiveSynapse(SynapseDefinition* type) : Synapse(type), propagable(true) {}
46

57
DisjunctiveSynapse::DisjunctiveSynapse(SynapseDefinition* type, Neuron* input, Neuron* output) : Synapse(type, input, output), propagable(true) {}
68

9+
RelatedObjectIterable* DisjunctiveSynapse::followManyRelation(Relation* rel) const {
10+
// Typically synapses don't have "many" relationships
11+
throw std::runtime_error("Invalid Relation for DisjunctiveSynapse: " + rel->getRelationName());
12+
}
13+
14+
Obj* DisjunctiveSynapse::followSingleRelation(const Relation* rel) {
15+
if (rel->getRelationName() == "SELF") {
16+
return this;
17+
} else if (rel->getRelationName() == "INPUT") {
18+
return getInput();
19+
} else if (rel->getRelationName() == "OUTPUT") {
20+
return getOutput();
21+
} else {
22+
throw std::runtime_error("Invalid Relation for DisjunctiveSynapse: " + rel->getRelationName());
23+
}
24+
}
25+
726
void DisjunctiveSynapse::link(Model* m) {
827
getInput(m)->addOutputSynapse(this);
928
}

src/network/inhibitory_activation.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,32 @@
22
#include "network/activation_definition.h"
33
#include "network/synapse_definition.h"
44
#include "network/binding_signal.h"
5+
#include "fields/relation.h"
6+
#include "fields/rel_obj_iterator.h"
57

68
InhibitoryActivation::InhibitoryActivation(ActivationDefinition* t, Activation* parent, int id, Neuron* n, Document* doc, std::map<BSType*, BindingSignal*> bindingSignals)
79
: Activation(t, parent, id, n, doc, bindingSignals) {}
10+
11+
RelatedObjectIterable* InhibitoryActivation::followManyRelation(Relation* rel) const {
12+
if (rel->getRelationName() == "INPUT") {
13+
// Convert inputLinks to a vector of Obj*
14+
std::vector<Obj*> objs;
15+
for (const auto& pair : inputLinks) {
16+
objs.push_back(static_cast<Obj*>(pair.second));
17+
}
18+
return new VectorObjectIterable(objs);
19+
} else if (rel->getRelationName() == "OUTPUT") {
20+
// For InhibitoryActivation, we override both INPUT and OUTPUT handling
21+
std::vector<Obj*> objs;
22+
for (const auto& pair : outputLinks) {
23+
objs.push_back(static_cast<Obj*>(pair.second));
24+
}
25+
return new VectorObjectIterable(objs);
26+
} else {
27+
// Use base class implementation for other relations
28+
return Activation::followManyRelation(rel);
29+
}
30+
}
831

932
void InhibitoryActivation::addInputLink(Link* l) {
1033
int bsId = getInputKey(l);

src/network/link.cpp

+12-12
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ Link::Link(LinkDefinition* type, Synapse* s, Activation* input, Activation* outp
1717
output->addInputLink(this);
1818
}
1919

20-
Stream<Obj*> Link::followManyRelation(Relation* rel) {
21-
throw std::runtime_error("Invalid Relation");
22-
}
23-
24-
Obj* Link::followSingleRelation(Relation* rel) {
25-
if (rel == SELF) return this;
26-
if (rel == LinkDefinition::INPUT) return input;
27-
if (rel == LinkDefinition::OUTPUT) return output;
28-
if (rel == LinkDefinition::SYNAPSE) return synapse;
29-
if (rel == LinkDefinition::CORRESPONDING_INPUT_LINK) return input->getCorrespondingInputLink(this);
30-
if (rel == LinkDefinition::CORRESPONDING_OUTPUT_LINK) return output->getCorrespondingOutputLink(this);
31-
throw std::runtime_error("Invalid Relation");
20+
RelatedObjectIterable* Link::followManyRelation(Relation* rel) const {
21+
throw std::runtime_error("Invalid Relation: " + rel->getRelationName());
22+
}
23+
24+
Obj* Link::followSingleRelation(const Relation* rel) {
25+
if (rel->getRelationName() == "SELF") return this;
26+
if (rel->getRelationName() == "INPUT") return input;
27+
if (rel->getRelationName() == "OUTPUT") return output;
28+
if (rel->getRelationName() == "SYNAPSE") return synapse;
29+
if (rel->getRelationName() == "CORRESPONDING_INPUT_LINK") return input->getCorrespondingInputLink(this);
30+
if (rel->getRelationName() == "CORRESPONDING_OUTPUT_LINK") return output->getCorrespondingOutputLink(this);
31+
throw std::runtime_error("Invalid Relation: " + rel->getRelationName());
3232
}
3333

3434
Timestamp Link::getFired() {

0 commit comments

Comments
 (0)