1616
1717package vadl .vdt .impl .irregular ;
1818
19- import org .apache .commons .lang3 .NotImplementedException ;
19+ import static vadl .vdt .utils .PBit .Value .DONT_CARE ;
20+
21+ import java .math .BigInteger ;
22+ import java .util .ArrayList ;
23+ import java .util .Collection ;
24+ import java .util .Comparator ;
25+ import java .util .Iterator ;
26+ import java .util .List ;
27+ import java .util .NoSuchElementException ;
28+ import java .util .Objects ;
29+ import java .util .Optional ;
30+ import java .util .PriorityQueue ;
31+ import java .util .function .Function ;
32+ import javax .annotation .Nonnull ;
33+ import vadl .utils .Pair ;
2034import vadl .vdt .impl .irregular .model .DecodeEntries ;
35+ import vadl .vdt .impl .irregular .model .DecodeEntry ;
2136import vadl .vdt .model .Node ;
37+ import vadl .vdt .utils .BitPattern ;
38+ import vadl .vdt .utils .BitVector ;
2239
2340/**
2441 * Decode tree generator largely based on the Qin et al. decode tree construction algorithm using
3047 */
3148public class OccurrenceAwareDecodeTreeGenerator extends IrregularDecodeTreeGenerator {
3249
50+ @ SuppressWarnings ("UnusedVariable" )
3351 private final double memoryPenalty ;
3452
3553 public OccurrenceAwareDecodeTreeGenerator (double memoryPenalty ) {
@@ -38,8 +56,280 @@ public OccurrenceAwareDecodeTreeGenerator(double memoryPenalty) {
3856
3957 @ Override
4058 protected Node makeNode (DecodeEntries decodeEntries ) {
41- throw new NotImplementedException ("TODO" );
59+
60+ final var singleDecision = bestSingleDecision (decodeEntries ).orElse (null );
61+ final var multiDecision = bestMultiDecision (decodeEntries ).orElse (null );
62+
63+ if (singleDecision == null && multiDecision == null ) {
64+ throw new IllegalStateException ("Unable to split entry set" );
65+ }
66+
67+ if (multiDecision == null
68+ || (singleDecision != null && singleDecision .left () < multiDecision .left ())) {
69+
70+ return makeConditionNode (decodeEntries ,
71+ Objects .requireNonNull (singleDecision ).right ().pattern ());
72+ }
73+
74+ return makeMultiDecisionNode (decodeEntries , multiDecision .right ());
75+ }
76+
77+ private Pair <Double , SingleSplitEntrySet > calculateSingleCost (DecodeEntries entries ,
78+ BitPattern p ) {
79+ final SingleSplitEntrySet split = split (entries , p );
80+
81+ final double cost = calculateCost (List .of (split .matching (), split .others ()));
82+
83+ // Add a memory penalty
84+ final int s = split .matching ().size () + split .others ().size () - 1 ;
85+ final double meRatio = s / (entries .entries ().size () - 1.0 );
86+ final double penalty = memoryPenalty * (Math .log (meRatio ) / Math .log (2 ));
87+
88+ return new Pair <>(cost + penalty , split );
89+ }
90+
91+ private Pair <Double , MultiSplitEntrySet > calculateMultiCost (DecodeEntries entries ,
92+ BitVector mask ) {
93+ final MultiSplitEntrySet split = split (entries , mask );
94+
95+ final double cost = calculateCost (split .entries ().values ());
96+
97+ // Add a memory penalty
98+ final int m = mask .toValue ().bitCount ();
99+ final double s = split .entries ().values ().stream ()
100+ .filter (e -> !e .isEmpty ())
101+ .reduce (0 , (acc , e ) -> acc + (e .size () - 1 ), Integer ::sum )
102+ + 1 + Math .pow (2 , m );
103+
104+ final double meRatio = s / (entries .entries ().size () - 1.0 );
105+ final double penalty = memoryPenalty * (Math .log (meRatio ) / Math .log (2 ));
106+
107+ return new Pair <>(cost + penalty , split );
108+ }
109+
110+ private double calculateCost (Collection <List <DecodeEntry >> splits ) {
111+ return 1 + splits .stream ()
112+ .reduce (0.0 , (acc , e ) -> {
113+ double prob = e .stream ()
114+ .map (DecodeEntry ::occurrenceProbability ).reduce (0.0 , Double ::sum );
115+ return acc + prob * huffmanTreeHeight (e );
116+ }, Double ::sum );
117+ }
118+
119+ private Optional <Pair <Double , SingleSplitEntrySet >> bestSingleDecision (
120+ DecodeEntries decodeEntries ) {
121+ return findBestPattern (decodeEntries .checkedBits (), p -> {
122+ var split = calculateSingleCost (decodeEntries , p );
123+ if (split .right ().matching ().isEmpty () || split .right ().others ().isEmpty ()) {
124+ return Optional .empty ();
125+ }
126+ return Optional .of (split .left ());
127+ })
128+ .map (decision -> calculateSingleCost (decodeEntries , decision ));
129+ }
130+
131+ private Optional <Pair <Double , MultiSplitEntrySet >> bestMultiDecision (
132+ DecodeEntries decodeEntries ) {
133+
134+ final var candidates = baseMaskCandidates (decodeEntries .checkedBits ());
135+
136+ if (candidates .isEmpty ()) {
137+ return Optional .empty ();
138+ }
139+
140+ final Pair <Double , MultiSplitEntrySet > base = candidates .stream ()
141+ .map (m -> {
142+ var split = calculateMultiCost (decodeEntries , m );
143+ var hasDecision = split .right ().entries ().values ().stream ()
144+ .filter (es -> !es .isEmpty ())
145+ .count () > 1 ;
146+ return hasDecision ? split : null ;
147+ })
148+ .filter (Objects ::nonNull )
149+ .min (Comparator .comparing (Pair ::left )).orElse (null );
150+
151+ final var checked = decodeEntries .checkedBits ().toMaskVector ();
152+
153+ Pair <Double , MultiSplitEntrySet > prev = Objects .requireNonNull (base );
154+ Pair <Double , MultiSplitEntrySet > next = prev ;
155+ do {
156+
157+ prev = next ;
158+
159+ // candidate bits to try next
160+ var currentBase = prev .right ().mask ();
161+ var bits = checked .not ().and (currentBase .not ()).toValue ();
162+
163+ while (bits .getLowestSetBit () >= 0 ) {
164+ int i = bits .getLowestSetBit ();
165+
166+ var candidate = BitVector .fromValue (currentBase .toValue ().setBit (i ), currentBase .width ());
167+ var cost = calculateMultiCost (decodeEntries , candidate );
168+
169+ if (cost .left () < prev .left ()) {
170+ next = cost ;
171+ }
172+
173+ bits = bits .clearBit (i );
174+ }
175+
176+ } while (next .left () < prev .left ());
177+
178+ return Optional .of (prev );
42179 }
43180
181+ /**
182+ * Search for a minimal cost pattern. To not enumerate all possibilities (3^n) we grow the
183+ * pattern bit by bit as long as the cost improves.
184+ *
185+ * @param base the bit pattern to start with / or grow from.
186+ * @param costFunction the function calculating the cost of choosing a pattern.
187+ * @return all candidates of the base pattern extended with an additional decision bit.
188+ */
189+ private Optional <BitPattern > findBestPattern (final BitPattern base ,
190+ final Function <BitPattern ,
191+ Optional <Double >> costFunction ) {
192+
193+ boolean initialized = false ;
194+
195+ BitPattern prev ;
196+ double prevCost ;
197+
198+ BitPattern next = base ;
199+ double nextCost = Double .POSITIVE_INFINITY ;
200+
201+ do {
202+
203+ // Move current best to the 'previous' values
204+ prev = next ;
205+ prevCost = nextCost ;
206+
207+ // Find next best
208+ var candidates = patternCandidates (prev );
209+
210+ if (!initialized && !candidates .hasNext ()) {
211+ return Optional .empty ();
212+ }
213+ initialized = true ;
214+
215+ while (candidates .hasNext ()) {
216+ var candidate = candidates .next ();
217+ var cost = costFunction .apply (candidate );
218+
219+ if (cost .isPresent () && cost .get () < nextCost ) {
220+ next = candidate ;
221+ nextCost = cost .get ();
222+ }
223+ }
224+
225+ } while (nextCost < prevCost );
226+
227+ return Optional .of (prev );
228+ }
229+
230+ /**
231+ * Generator for splitting pattern candidates by growing the base pattern with an additional bit.
232+ *
233+ * @param base the pattern from which to grow the splitting patterns.
234+ * @return all candidates of the base pattern extended with an additional decision bit.
235+ */
236+ private Iterator <BitPattern > patternCandidates (BitPattern base ) {
237+
238+ return new Iterator <>() {
239+
240+ private BigInteger candidates = base .toMaskVector ().not ().toValue ();
241+ private boolean first = true ;
242+
243+ @ Override
244+ public boolean hasNext () {
245+ return !BigInteger .ZERO .equals (candidates ) || !first ;
246+ }
247+
248+ @ Override
249+ public BitPattern next () {
250+ if (!hasNext ()) {
251+ throw new NoSuchElementException ();
252+ }
253+
254+ final int decisionBit = candidates .getLowestSetBit ();
255+
256+ final var mask = base .toMaskVector ().toValue ().setBit (decisionBit );
257+ final var mVector = BitVector .fromValue (mask , base .width ());
258+
259+ if (first ) {
260+ // Flip the lowest candidate bit to 'one' only in the bitmask
261+ first = false ;
262+ return BitPattern .fromBitVector (mVector , base .toBitVector ());
263+ }
264+
265+ // Flip the lowest candidate bit to 'one' in the base pattern
266+ var value = base .toBitVector ().toValue ().setBit (decisionBit );
267+ var vVector = BitVector .fromValue (value , base .width ());
268+
269+ // Clear the candidate bit
270+ candidates = candidates .clearBit (decisionBit );
271+ first = true ;
272+
273+ return BitPattern .fromBitVector (mVector , vVector );
274+ }
275+ };
276+ }
277+
278+ /**
279+ * Generator for all relevant 2-bit mask candidates, given the already known bits in the base
280+ * pattern.
281+ *
282+ * @param base The base pattern, specifying known bits.
283+ * @return all 2-bit mask candidates
284+ */
285+ private List <BitVector > baseMaskCandidates (final BitPattern base ) {
286+ final int w = base .width ();
287+ final List <BitVector > result = new ArrayList <>();
288+ for (int i = w - 1 ; i > 0 ; i --) {
289+ if (base .get (i ).getValue () != DONT_CARE || base .get (i - 1 ).getValue () != DONT_CARE ) {
290+ continue ;
291+ }
292+ var maskValue = BigInteger .ZERO .setBit (w - i ).setBit (w - i - 1 );
293+ result .add (BitVector .fromValue (maskValue , w ));
294+ }
295+ return result ;
296+ }
297+
298+ private record HuffmanNode (double weight , int height )
299+ implements Comparable <HuffmanNode > {
300+
301+ static HuffmanNode of (DecodeEntry entry ) {
302+ return new HuffmanNode (entry .occurrenceProbability (), 0 );
303+ }
304+
305+ @ Override
306+ public int compareTo (@ Nonnull HuffmanNode o ) {
307+ return Double .compare (weight , o .weight );
308+ }
309+
310+ public HuffmanNode merge (HuffmanNode node ) {
311+ final var w = weight + node .weight ;
312+ final var h = 1 + Math .max (height , node .height );
313+ return new HuffmanNode (w , h );
314+ }
315+ }
316+
317+ private static int huffmanTreeHeight (Collection <DecodeEntry > entries ) {
318+ if (entries .isEmpty ()) {
319+ return 0 ;
320+ }
321+
322+ final PriorityQueue <HuffmanNode > priorityQueue = new PriorityQueue <>(entries .size ());
323+ for (var e : entries ) {
324+ priorityQueue .add (HuffmanNode .of (e ));
325+ }
326+
327+ while (priorityQueue .size () > 1 ) {
328+ var a = priorityQueue .poll ();
329+ var b = priorityQueue .poll ();
330+ priorityQueue .add (a .merge (Objects .requireNonNull (b )));
331+ }
332+ return Objects .requireNonNull (priorityQueue .peek ()).height ();
333+ }
44334
45335}
0 commit comments