Skip to content

Commit 3949131

Browse files
committed
feat(merge): module graph
2 parents 0bee665 + 7ec6a0a commit 3949131

File tree

9 files changed

+520
-38
lines changed

9 files changed

+520
-38
lines changed

components/acoustics-porting/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,18 @@ file(GLOB_RECURSE ACOUSTICS_SRCS_HAL
4848
${CMAKE_CURRENT_LIST_DIR}/../acoustics/hal/*.cpp
4949
)
5050

51+
file(GLOB_RECURSE ACOUSTICS_SRCS_MODULE
52+
${CMAKE_CURRENT_LIST_DIR}/module/*.c
53+
${CMAKE_CURRENT_LIST_DIR}/module/*.hpp
54+
${CMAKE_CURRENT_LIST_DIR}/module/*.cpp
55+
)
56+
message(STATUS "ACOUSTICS_SRCS_MODULE: ${ACOUSTICS_SRCS_MODULE}")
57+
5158
set(ACOUSTICS_SRCS
5259
${ACOUSTICS_SRCS_API}
5360
${ACOUSTICS_SRCS_CORE}
5461
${ACOUSTICS_SRCS_HAL}
62+
${ACOUSTICS_SRCS_MODULE}
5563
)
5664

5765
list(APPEND ACOUSTICS_PORTING_SRCS ${ACOUSTICS_SRCS})

components/acoustics/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,18 @@ file(GLOB_RECURSE ACOUSTICS_SRCS_HAL
5252
)
5353
message(STATUS "ACOUSTICS_SRCS_HAL: ${ACOUSTICS_SRCS_HAL}")
5454

55+
file(GLOB_RECURSE ACOUSTICS_SRCS_MODULE
56+
${CMAKE_CURRENT_LIST_DIR}/module/*.c
57+
${CMAKE_CURRENT_LIST_DIR}/module/*.hpp
58+
${CMAKE_CURRENT_LIST_DIR}/module/*.cpp
59+
)
60+
message(STATUS "ACOUSTICS_SRCS_MODULE: ${ACOUSTICS_SRCS_MODULE}")
61+
5562
set(ACOUSTICS_SRCS
5663
${ACOUSTICS_SRCS_API}
5764
${ACOUSTICS_SRCS_CORE}
5865
${ACOUSTICS_SRCS_HAL}
66+
${ACOUSTICS_SRCS_MODULE}
5967
)
6068

6169
get_property(ACOUSTICS_SDK_TARGET GLOBAL PROPERTY ACOUSTICS_SDK_TARGET)

components/acoustics/core/tensor.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "logger.hpp"
66

77
#include <algorithm>
8+
#include <cmath>
89
#include <cstddef>
910
#include <cstdint>
1011
#include <memory>
@@ -73,6 +74,11 @@ class Tensor final
7374
return _dims[index];
7475
}
7576

77+
inline bool operator==(const Shape &other) const noexcept
78+
{
79+
return _dot == other._dot && std::equal(_dims.cbegin(), _dims.cend(), other._dims.cbegin());
80+
}
81+
7682
inline int dot() const noexcept
7783
{
7884
return _dot;
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "module_dag.hpp"
2+
3+
namespace module {
4+
5+
core::Status MDAGBuilderRegistry::registerDAGBuilder(std::string_view name, DAGBuilder builder) noexcept
6+
{
7+
if (name.empty()) [[unlikely]]
8+
{
9+
LOG(ERROR, "DAG name cannot be empty");
10+
return STATUS(EINVAL, "DAG name cannot be empty");
11+
}
12+
if (!builder) [[unlikely]]
13+
{
14+
LOG(ERROR, "DAG builder cannot be null");
15+
return STATUS(EINVAL, "DAG builder cannot be null");
16+
}
17+
18+
auto it = _dags.find(name);
19+
if (it != _dags.end()) [[unlikely]]
20+
{
21+
LOG(WARNING, "DAG builder for %s already exists", name.data());
22+
return STATUS(EEXIST, "DAG builder already exists");
23+
}
24+
25+
_dags.emplace(name, std::move(builder));
26+
27+
return STATUS_OK();
28+
}
29+
30+
MDAGBuilderRegistry::DAGBuilderMap MDAGBuilderRegistry::_dags;
31+
32+
} // namespace module
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
#pragma once
2+
#ifndef MODULE_DAG_HPP
3+
#define MODULE_DAG_HPP
4+
5+
#include "module_node.hpp"
6+
7+
#include "core/config_object.hpp"
8+
#include "core/logger.hpp"
9+
#include "core/status.hpp"
10+
#include "core/tensor.hpp"
11+
12+
#include <algorithm>
13+
#include <forward_list>
14+
#include <functional>
15+
#include <memory>
16+
#include <queue>
17+
#include <string_view>
18+
#include <unordered_map>
19+
#include <utility>
20+
#include <vector>
21+
22+
namespace module {
23+
24+
class MDAG;
25+
26+
class MDAGBuilderRegistry final
27+
{
28+
public:
29+
using DAGBuilder = std::shared_ptr<MDAG> (*)(const core::ConfigMap &);
30+
using DAGBuilderMap = std::unordered_map<std::string_view, DAGBuilder>;
31+
32+
MDAGBuilderRegistry() = default;
33+
~MDAGBuilderRegistry() = default;
34+
35+
inline static std::shared_ptr<MDAG> getDAG(std::string_view name, const core::ConfigMap &configs) noexcept
36+
{
37+
auto it = _dags.find(name);
38+
if (it != _dags.end()) [[likely]]
39+
{
40+
return it->second(configs);
41+
}
42+
return {};
43+
}
44+
45+
static const DAGBuilderMap &getDAGBuilderMap() noexcept
46+
{
47+
return _dags;
48+
}
49+
50+
static core::Status registerDAGBuilder(std::string_view name, DAGBuilder builder) noexcept;
51+
52+
private:
53+
static DAGBuilderMap _dags;
54+
};
55+
56+
class MDAG final
57+
{
58+
public:
59+
explicit MDAG(std::string_view name) noexcept : _name(name), _nodes(), _adj(), _in_degree(), _execution_order()
60+
{
61+
if (name.empty()) [[unlikely]]
62+
{
63+
LOG(WARNING, "DAG name cannot be empty");
64+
}
65+
}
66+
67+
~MDAG() = default;
68+
69+
MNode *addNode(std::shared_ptr<MNode> node) noexcept
70+
{
71+
if (!node) [[unlikely]]
72+
{
73+
LOG(ERROR, "Attempted to add a null node to the DAG");
74+
return nullptr;
75+
}
76+
77+
auto ptr = node.get();
78+
if (std::find_if(_nodes.begin(), _nodes.end(),
79+
[ptr](const std::shared_ptr<MNode> &n) { return n.get() == ptr; })
80+
!= _nodes.end()) [[unlikely]]
81+
{
82+
LOG(WARNING, "Node %s already exists in the DAG", node->name().data());
83+
return ptr;
84+
}
85+
86+
_execution_order.clear();
87+
88+
_nodes.push_front(std::move(node));
89+
90+
if (_adj.find(ptr) == _adj.end()) [[likely]]
91+
{
92+
_adj[ptr] = {};
93+
}
94+
if (_in_degree.find(ptr) == _in_degree.end()) [[likely]]
95+
{
96+
_in_degree[ptr] = 0;
97+
}
98+
99+
return ptr;
100+
}
101+
102+
bool addEdge(MNode *from, MNode *to) noexcept
103+
{
104+
if (!from || !to) [[unlikely]]
105+
{
106+
LOG(ERROR, "Attempted to add an edge with null nodes");
107+
return false;
108+
}
109+
if (_adj.find(from) == _adj.end() || _adj.find(to) == _adj.end()) [[unlikely]]
110+
{
111+
LOG(ERROR, "One or both nodes not found in the DAG");
112+
return false;
113+
}
114+
115+
_execution_order.clear();
116+
117+
_adj[from].push_front(to);
118+
++_in_degree[to];
119+
120+
return true;
121+
}
122+
123+
bool computeExecutionOrder() const noexcept
124+
{
125+
std::priority_queue<MNode *, std::vector<MNode *>, MNode::Lower> pq;
126+
127+
for (const auto &id: _in_degree)
128+
{
129+
if (id.second == 0) [[likely]]
130+
{
131+
pq.push(id.first);
132+
}
133+
}
134+
135+
while (!pq.empty())
136+
{
137+
auto *node = pq.top();
138+
pq.pop();
139+
_execution_order.push_back(node);
140+
141+
for (auto *neighbor: _adj.at(node))
142+
{
143+
const auto id = --_in_degree[neighbor];
144+
if (id == 0) [[likely]]
145+
{
146+
pq.push(neighbor);
147+
}
148+
}
149+
}
150+
151+
if (_execution_order.size() != std::distance(_nodes.begin(), _nodes.end())) [[unlikely]]
152+
{
153+
LOG(ERROR, "Cycle detected in the DAG, execution order incomplete");
154+
_execution_order.clear();
155+
return false;
156+
}
157+
158+
_execution_order.shrink_to_fit();
159+
160+
return true;
161+
}
162+
163+
MNode *inputNode() const noexcept
164+
{
165+
if (_execution_order.empty()) [[unlikely]]
166+
{
167+
if (!computeExecutionOrder()) [[unlikely]]
168+
{
169+
LOG(ERROR, "Failed to compute execution order for the DAG");
170+
return nullptr;
171+
}
172+
if (_execution_order.empty()) [[unlikely]]
173+
{
174+
LOG(ERROR, "No input node found in the DAG");
175+
return nullptr;
176+
}
177+
}
178+
return _execution_order.front();
179+
}
180+
181+
MNode *outputNode() const noexcept
182+
{
183+
if (_execution_order.empty()) [[unlikely]]
184+
{
185+
if (!computeExecutionOrder()) [[unlikely]]
186+
{
187+
LOG(ERROR, "Failed to compute execution order for the DAG");
188+
return nullptr;
189+
}
190+
if (_execution_order.empty()) [[unlikely]]
191+
{
192+
LOG(ERROR, "No input node found in the DAG");
193+
return nullptr;
194+
}
195+
}
196+
return _execution_order.back();
197+
}
198+
199+
MNode *node(std::string_view name) const noexcept
200+
{
201+
for (const auto &node: _nodes)
202+
{
203+
if (node->name() == name)
204+
{
205+
return node.get();
206+
}
207+
}
208+
return nullptr;
209+
}
210+
211+
inline core::Status operator()() noexcept
212+
{
213+
if (_execution_order.empty()) [[unlikely]]
214+
{
215+
if (!computeExecutionOrder()) [[unlikely]]
216+
{
217+
return STATUS(EFAULT, "Failed to compute execution order for the DAG");
218+
}
219+
}
220+
221+
for (auto *node: _execution_order)
222+
{
223+
if (const auto &status = (*node)(); !status) [[unlikely]]
224+
{
225+
return status;
226+
}
227+
}
228+
229+
return STATUS_OK();
230+
}
231+
232+
private:
233+
const std::string_view _name;
234+
std::forward_list<std::shared_ptr<module::MNode>> _nodes;
235+
std::unordered_map<MNode *, std::forward_list<MNode *>> _adj;
236+
mutable std::unordered_map<MNode *, int> _in_degree;
237+
mutable std::vector<MNode *> _execution_order;
238+
};
239+
240+
} // namespace module
241+
242+
namespace bridge {
243+
244+
extern void __REGISTER_MODULE_DAG_BUILDER__();
245+
246+
} // namespace bridge
247+
248+
#endif

0 commit comments

Comments
 (0)