Skip to content

Commit ebea7ff

Browse files
committed
decoder: Implement descision heuristics
1 parent 966c51a commit ebea7ff

File tree

3 files changed

+308
-14
lines changed

3 files changed

+308
-14
lines changed

vadl/main/vadl/vdt/impl/irregular/IrregularDecodeTreeGenerator.java

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,19 @@ protected Node makeNode(DecodeEntries decodeEntries) {
127127
final MultiPatterns patterns = makePatterns(decodeEntries);
128128

129129
if (!patterns.hasDecision()) {
130-
// Split entry set by exclusion conditions instead
131-
return makeConditionNode(decodeEntries);
130+
// Select best splitting pattern based on exclusion conditions
131+
final BitPattern pattern = selectPattern(decodeEntries);
132+
return makeConditionNode(decodeEntries, pattern);
132133
}
133134

134135
final MultiSplitEntrySet splitEntries = split(decodeEntries, patterns.mask());
135136

137+
return makeMultiDecisionNode(decodeEntries, splitEntries);
138+
}
139+
140+
protected MultiDecisionNode makeMultiDecisionNode(DecodeEntries decodeEntries,
141+
MultiSplitEntrySet splitEntries) {
142+
136143
final Map<BitPattern, Node> children = new HashMap<>();
137144
for (var branches : splitEntries.entries().entrySet()) {
138145

@@ -150,13 +157,10 @@ protected Node makeNode(DecodeEntries decodeEntries) {
150157
children.put(pattern, childNode);
151158
}
152159

153-
return new MultiDecisionNode(patterns.mask(), children);
160+
return new MultiDecisionNode(splitEntries.mask(), children);
154161
}
155162

156-
private Node makeConditionNode(DecodeEntries decodeEntries) {
157-
158-
// Select best splitting pattern based on exclusion conditions
159-
final BitPattern pattern = selectPattern(decodeEntries);
163+
protected Node makeConditionNode(DecodeEntries decodeEntries, BitPattern pattern) {
160164

161165
// Split the entry set
162166
final SingleSplitEntrySet splitEntries = split(decodeEntries, pattern);
@@ -192,7 +196,7 @@ private MultiPatterns makePatterns(DecodeEntries decodeEntries, BitVector mask)
192196

193197
// We don't need to check bits more than once
194198
BitVector checked = decodeEntries.checkedBits().toMaskVector();
195-
mask = mask.xor(checked);
199+
mask = mask.and(checked.not());
196200

197201
final Set<BitPattern> options = new LinkedHashSet<>();
198202
for (DecodeEntry e : decodeEntries.entries()) {
@@ -220,7 +224,7 @@ protected record MultiSplitEntrySet(BitVector mask, Map<BitPattern, List<DecodeE
220224
* @param pattern the splitting pattern.
221225
* @return the split entry set.
222226
*/
223-
private SingleSplitEntrySet split(DecodeEntries decodeEntries, BitPattern pattern) {
227+
protected SingleSplitEntrySet split(DecodeEntries decodeEntries, BitPattern pattern) {
224228

225229
final List<DecodeEntry> matchingEntries = new ArrayList<>();
226230
final List<DecodeEntry> otherEntries = new ArrayList<>();
@@ -266,7 +270,7 @@ private SingleSplitEntrySet split(DecodeEntries decodeEntries, BitPattern patter
266270
* @param mask the splitting mask.
267271
* @return the split entry set.
268272
*/
269-
private MultiSplitEntrySet split(DecodeEntries decodeEntries, BitVector mask) {
273+
protected MultiSplitEntrySet split(DecodeEntries decodeEntries, BitVector mask) {
270274

271275
final MultiPatterns patterns = makePatterns(decodeEntries, mask);
272276
final Map<BitPattern, List<DecodeEntry>> entries = new LinkedHashMap<>();
@@ -664,7 +668,7 @@ private Diagnostic toConstructionDiagnostic(DecodeEntries decodeEntries) {
664668
.map(Definition::simpleName)
665669
.toList();
666670

667-
var diagnostic = error(("Unable to split instruction set during decoder generation: %s")
671+
var diagnostic = error("Unable to split instruction set during decoder generation: %s"
668672
.formatted(insnNames), primary);
669673

670674
for (DecodeEntry e : decodeEntries.entries()) {

vadl/main/vadl/vdt/impl/irregular/OccurrenceAwareDecodeTreeGenerator.java

Lines changed: 292 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,26 @@
1616

1717
package vadl.vdt.impl.irregular;
1818

19-
import org.apache.commons.lang3.NotImplementedException;
19+
import static vadl.vdt.utils.PBit.Value.DONT_CARE;
20+
21+
import java.math.BigInteger;
22+
import java.util.ArrayList;
23+
import java.util.Collection;
24+
import java.util.Comparator;
25+
import java.util.Iterator;
26+
import java.util.List;
27+
import java.util.NoSuchElementException;
28+
import java.util.Objects;
29+
import java.util.Optional;
30+
import java.util.PriorityQueue;
31+
import java.util.function.Function;
32+
import javax.annotation.Nonnull;
33+
import vadl.utils.Pair;
2034
import vadl.vdt.impl.irregular.model.DecodeEntries;
35+
import vadl.vdt.impl.irregular.model.DecodeEntry;
2136
import vadl.vdt.model.Node;
37+
import vadl.vdt.utils.BitPattern;
38+
import vadl.vdt.utils.BitVector;
2239

2340
/**
2441
* Decode tree generator largely based on the Qin et al. decode tree construction algorithm using
@@ -30,6 +47,7 @@
3047
*/
3148
public class OccurrenceAwareDecodeTreeGenerator extends IrregularDecodeTreeGenerator {
3249

50+
@SuppressWarnings("UnusedVariable")
3351
private final double memoryPenalty;
3452

3553
public OccurrenceAwareDecodeTreeGenerator(double memoryPenalty) {
@@ -38,8 +56,280 @@ public OccurrenceAwareDecodeTreeGenerator(double memoryPenalty) {
3856

3957
@Override
4058
protected Node makeNode(DecodeEntries decodeEntries) {
41-
throw new NotImplementedException("TODO");
59+
60+
final var singleDecision = bestSingleDecision(decodeEntries).orElse(null);
61+
final var multiDecision = bestMultiDecision(decodeEntries).orElse(null);
62+
63+
if (singleDecision == null && multiDecision == null) {
64+
throw new IllegalStateException("Unable to split entry set");
65+
}
66+
67+
if (multiDecision == null
68+
|| (singleDecision != null && singleDecision.left() < multiDecision.left())) {
69+
70+
return makeConditionNode(decodeEntries,
71+
Objects.requireNonNull(singleDecision).right().pattern());
72+
}
73+
74+
return makeMultiDecisionNode(decodeEntries, multiDecision.right());
75+
}
76+
77+
private Pair<Double, SingleSplitEntrySet> calculateSingleCost(DecodeEntries entries,
78+
BitPattern p) {
79+
final SingleSplitEntrySet split = split(entries, p);
80+
81+
final double cost = calculateCost(List.of(split.matching(), split.others()));
82+
83+
// Add a memory penalty
84+
final int s = split.matching().size() + split.others().size() - 1;
85+
final double meRatio = s / (entries.entries().size() - 1.0);
86+
final double penalty = memoryPenalty * (Math.log(meRatio) / Math.log(2));
87+
88+
return new Pair<>(cost + penalty, split);
89+
}
90+
91+
private Pair<Double, MultiSplitEntrySet> calculateMultiCost(DecodeEntries entries,
92+
BitVector mask) {
93+
final MultiSplitEntrySet split = split(entries, mask);
94+
95+
final double cost = calculateCost(split.entries().values());
96+
97+
// Add a memory penalty
98+
final int m = mask.toValue().bitCount();
99+
final double s = split.entries().values().stream()
100+
.filter(e -> !e.isEmpty())
101+
.reduce(0, (acc, e) -> acc + (e.size() - 1), Integer::sum)
102+
+ 1 + Math.pow(2, m);
103+
104+
final double meRatio = s / (entries.entries().size() - 1.0);
105+
final double penalty = memoryPenalty * (Math.log(meRatio) / Math.log(2));
106+
107+
return new Pair<>(cost + penalty, split);
108+
}
109+
110+
private double calculateCost(Collection<List<DecodeEntry>> splits) {
111+
return 1 + splits.stream()
112+
.reduce(0.0, (acc, e) -> {
113+
double prob = e.stream()
114+
.map(DecodeEntry::occurrenceProbability).reduce(0.0, Double::sum);
115+
return acc + prob * huffmanTreeHeight(e);
116+
}, Double::sum);
117+
}
118+
119+
private Optional<Pair<Double, SingleSplitEntrySet>> bestSingleDecision(
120+
DecodeEntries decodeEntries) {
121+
return findBestPattern(decodeEntries.checkedBits(), p -> {
122+
var split = calculateSingleCost(decodeEntries, p);
123+
if (split.right().matching().isEmpty() || split.right().others().isEmpty()) {
124+
return Optional.empty();
125+
}
126+
return Optional.of(split.left());
127+
})
128+
.map(decision -> calculateSingleCost(decodeEntries, decision));
129+
}
130+
131+
private Optional<Pair<Double, MultiSplitEntrySet>> bestMultiDecision(
132+
DecodeEntries decodeEntries) {
133+
134+
final var candidates = baseMaskCandidates(decodeEntries.checkedBits());
135+
136+
if (candidates.isEmpty()) {
137+
return Optional.empty();
138+
}
139+
140+
final Pair<Double, MultiSplitEntrySet> base = candidates.stream()
141+
.map(m -> {
142+
var split = calculateMultiCost(decodeEntries, m);
143+
var hasDecision = split.right().entries().values().stream()
144+
.filter(es -> !es.isEmpty())
145+
.count() > 1;
146+
return hasDecision ? split : null;
147+
})
148+
.filter(Objects::nonNull)
149+
.min(Comparator.comparing(Pair::left)).orElse(null);
150+
151+
final var checked = decodeEntries.checkedBits().toMaskVector();
152+
153+
Pair<Double, MultiSplitEntrySet> prev = Objects.requireNonNull(base);
154+
Pair<Double, MultiSplitEntrySet> next = prev;
155+
do {
156+
157+
prev = next;
158+
159+
// candidate bits to try next
160+
var currentBase = prev.right().mask();
161+
var bits = checked.not().and(currentBase.not()).toValue();
162+
163+
while (bits.getLowestSetBit() >= 0) {
164+
int i = bits.getLowestSetBit();
165+
166+
var candidate = BitVector.fromValue(currentBase.toValue().setBit(i), currentBase.width());
167+
var cost = calculateMultiCost(decodeEntries, candidate);
168+
169+
if (cost.left() < prev.left()) {
170+
next = cost;
171+
}
172+
173+
bits = bits.clearBit(i);
174+
}
175+
176+
} while (next.left() < prev.left());
177+
178+
return Optional.of(prev);
42179
}
43180

181+
/**
182+
* Search for a minimal cost pattern. To not enumerate all possibilities (3^n) we grow the
183+
* pattern bit by bit as long as the cost improves.
184+
*
185+
* @param base the bit pattern to start with / or grow from.
186+
* @param costFunction the function calculating the cost of choosing a pattern.
187+
* @return all candidates of the base pattern extended with an additional decision bit.
188+
*/
189+
private Optional<BitPattern> findBestPattern(final BitPattern base,
190+
final Function<BitPattern,
191+
Optional<Double>> costFunction) {
192+
193+
boolean initialized = false;
194+
195+
BitPattern prev;
196+
double prevCost;
197+
198+
BitPattern next = base;
199+
double nextCost = Double.POSITIVE_INFINITY;
200+
201+
do {
202+
203+
// Move current best to the 'previous' values
204+
prev = next;
205+
prevCost = nextCost;
206+
207+
// Find next best
208+
var candidates = patternCandidates(prev);
209+
210+
if (!initialized && !candidates.hasNext()) {
211+
return Optional.empty();
212+
}
213+
initialized = true;
214+
215+
while (candidates.hasNext()) {
216+
var candidate = candidates.next();
217+
var cost = costFunction.apply(candidate);
218+
219+
if (cost.isPresent() && cost.get() < nextCost) {
220+
next = candidate;
221+
nextCost = cost.get();
222+
}
223+
}
224+
225+
} while (nextCost < prevCost);
226+
227+
return Optional.of(prev);
228+
}
229+
230+
/**
231+
* Generator for splitting pattern candidates by growing the base pattern with an additional bit.
232+
*
233+
* @param base the pattern from which to grow the splitting patterns.
234+
* @return all candidates of the base pattern extended with an additional decision bit.
235+
*/
236+
private Iterator<BitPattern> patternCandidates(BitPattern base) {
237+
238+
return new Iterator<>() {
239+
240+
private BigInteger candidates = base.toMaskVector().not().toValue();
241+
private boolean first = true;
242+
243+
@Override
244+
public boolean hasNext() {
245+
return !BigInteger.ZERO.equals(candidates) || !first;
246+
}
247+
248+
@Override
249+
public BitPattern next() {
250+
if (!hasNext()) {
251+
throw new NoSuchElementException();
252+
}
253+
254+
final int decisionBit = candidates.getLowestSetBit();
255+
256+
final var mask = base.toMaskVector().toValue().setBit(decisionBit);
257+
final var mVector = BitVector.fromValue(mask, base.width());
258+
259+
if (first) {
260+
// Flip the lowest candidate bit to 'one' only in the bitmask
261+
first = false;
262+
return BitPattern.fromBitVector(mVector, base.toBitVector());
263+
}
264+
265+
// Flip the lowest candidate bit to 'one' in the base pattern
266+
var value = base.toBitVector().toValue().setBit(decisionBit);
267+
var vVector = BitVector.fromValue(value, base.width());
268+
269+
// Clear the candidate bit
270+
candidates = candidates.clearBit(decisionBit);
271+
first = true;
272+
273+
return BitPattern.fromBitVector(mVector, vVector);
274+
}
275+
};
276+
}
277+
278+
/**
279+
* Generator for all relevant 2-bit mask candidates, given the already known bits in the base
280+
* pattern.
281+
*
282+
* @param base The base pattern, specifying known bits.
283+
* @return all 2-bit mask candidates
284+
*/
285+
private List<BitVector> baseMaskCandidates(final BitPattern base) {
286+
final int w = base.width();
287+
final List<BitVector> result = new ArrayList<>();
288+
for (int i = w - 1; i > 0; i--) {
289+
if (base.get(i).getValue() != DONT_CARE || base.get(i - 1).getValue() != DONT_CARE) {
290+
continue;
291+
}
292+
var maskValue = BigInteger.ZERO.setBit(w - i).setBit(w - i - 1);
293+
result.add(BitVector.fromValue(maskValue, w));
294+
}
295+
return result;
296+
}
297+
298+
private record HuffmanNode(double weight, int height)
299+
implements Comparable<HuffmanNode> {
300+
301+
static HuffmanNode of(DecodeEntry entry) {
302+
return new HuffmanNode(entry.occurrenceProbability(), 0);
303+
}
304+
305+
@Override
306+
public int compareTo(@Nonnull HuffmanNode o) {
307+
return Double.compare(weight, o.weight);
308+
}
309+
310+
public HuffmanNode merge(HuffmanNode node) {
311+
final var w = weight + node.weight;
312+
final var h = 1 + Math.max(height, node.height);
313+
return new HuffmanNode(w, h);
314+
}
315+
}
316+
317+
private static int huffmanTreeHeight(Collection<DecodeEntry> entries) {
318+
if (entries.isEmpty()) {
319+
return 0;
320+
}
321+
322+
final PriorityQueue<HuffmanNode> priorityQueue = new PriorityQueue<>(entries.size());
323+
for (var e : entries) {
324+
priorityQueue.add(HuffmanNode.of(e));
325+
}
326+
327+
while (priorityQueue.size() > 1) {
328+
var a = priorityQueue.poll();
329+
var b = priorityQueue.poll();
330+
priorityQueue.add(a.merge(Objects.requireNonNull(b)));
331+
}
332+
return Objects.requireNonNull(priorityQueue.peek()).height();
333+
}
44334

45335
}

0 commit comments

Comments
 (0)