main
1// Copyright 2014-2022 Ulrich Kunitz. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package lzma
6
7import (
8 "errors"
9 "unicode"
10)
11
// node is a single entry of the binary search tree. The link fields
// p, l and r hold indexes into binTree.node; null marks a missing link.
type node struct {
	x uint32 // key the tree is ordered by
	p uint32 // index of the parent node
	l uint32 // index of the left child
	r uint32 // index of the right child
}
23
// wordLen is the number of bytes packed (big-endian) into the x field
// of a node and into binTree.x.
const wordLen = 4
26
27// binTree supports the identification of the next operation based on a
28// binary tree.
29//
30// Nodes will be identified by their index into the ring buffer.
31type binTree struct {
32 dict *encoderDict
33 // ring buffer of nodes
34 node []node
35 // absolute offset of the entry for the next node. Position 4
36 // byte larger.
37 hoff int64
38 // front position in the node ring buffer
39 front uint32
40 // index of the root node
41 root uint32
42 // current x value
43 x uint32
44 // preallocated array
45 data []byte
46}
47
// null is the sentinel index meaning "no node". Zero cannot serve this
// purpose because it is a valid ring-buffer index, and shifting every
// index by one would complicate each reference.
const null = ^uint32(0)
52
53// newBinTree initializes the binTree structure. The capacity defines
54// the size of the buffer and defines the maximum distance for which
55// matches will be found.
56func newBinTree(capacity int) (t *binTree, err error) {
57 if capacity < 1 {
58 return nil, errors.New(
59 "newBinTree: capacity must be larger than zero")
60 }
61 if int64(capacity) >= int64(null) {
62 return nil, errors.New(
63 "newBinTree: capacity must less 2^{32}-1")
64 }
65 t = &binTree{
66 node: make([]node, capacity),
67 hoff: -int64(wordLen),
68 root: null,
69 data: make([]byte, maxMatchLen),
70 }
71 return t, nil
72}
73
74func (t *binTree) SetDict(d *encoderDict) { t.dict = d }
75
76// WriteByte writes a single byte into the binary tree.
77func (t *binTree) WriteByte(c byte) error {
78 t.x = (t.x << 8) | uint32(c)
79 t.hoff++
80 if t.hoff < 0 {
81 return nil
82 }
83 v := t.front
84 if int64(v) < t.hoff {
85 // We are overwriting old nodes stored in the tree.
86 t.remove(v)
87 }
88 t.node[v].x = t.x
89 t.add(v)
90 t.front++
91 if int64(t.front) >= int64(len(t.node)) {
92 t.front = 0
93 }
94 return nil
95}
96
97// Writes writes a sequence of bytes into the binTree structure.
98func (t *binTree) Write(p []byte) (n int, err error) {
99 for _, c := range p {
100 t.WriteByte(c)
101 }
102 return len(p), nil
103}
104
105// add puts the node v into the tree. The node must not be part of the
106// tree before.
107func (t *binTree) add(v uint32) {
108 vn := &t.node[v]
109 // Set left and right to null indices.
110 vn.l, vn.r = null, null
111 // If the binary tree is empty make v the root.
112 if t.root == null {
113 t.root = v
114 vn.p = null
115 return
116 }
117 x := vn.x
118 p := t.root
119 // Search for the right leave link and add the new node.
120 for {
121 pn := &t.node[p]
122 if x <= pn.x {
123 if pn.l == null {
124 pn.l = v
125 vn.p = p
126 return
127 }
128 p = pn.l
129 } else {
130 if pn.r == null {
131 pn.r = v
132 vn.p = p
133 return
134 }
135 p = pn.r
136 }
137 }
138}
139
140// parent returns the parent node index of v and the pointer to v value
141// in the parent.
142func (t *binTree) parent(v uint32) (p uint32, ptr *uint32) {
143 if t.root == v {
144 return null, &t.root
145 }
146 p = t.node[v].p
147 if t.node[p].l == v {
148 ptr = &t.node[p].l
149 } else {
150 ptr = &t.node[p].r
151 }
152 return
153}
154
// remove unlinks node v from the tree, preserving the search order of
// the remaining nodes. v must currently be part of the tree. The fields
// of v itself are not cleared; in this file WriteByte overwrites x and
// add resets the links before the node is reused.
func (t *binTree) remove(v uint32) {
	vn := &t.node[v]
	// ptr is the link (t.root or the parent's l/r) that refers to v.
	p, ptr := t.parent(v)
	l, r := vn.l, vn.r
	if l == null {
		// Move the right child up.
		*ptr = r
		if r != null {
			t.node[r].p = p
		}
		return
	}
	if r == null {
		// Move the left child up.
		*ptr = l
		t.node[l].p = p
		return
	}

	// v has two children: replace it by its in-order predecessor u,
	// the maximum of the left subtree.
	// Search the in-order predecessor u.
	un := &t.node[l]
	ur := un.r
	if ur == null {
		// In order predecessor is l. Move it up.
		// l has no right subtree, so it adopts r directly and keeps its
		// own left subtree.
		un.r = r
		t.node[r].p = l
		un.p = p
		*ptr = l
		return
	}
	var u uint32
	for {
		// Look for the max value in the tree where l is root.
		u = ur
		ur = t.node[u].r
		if ur == null {
			break
		}
	}
	// replace u with ul
	// u is the right child of up and has no right subtree, so its left
	// subtree ul takes u's place under up.
	un = &t.node[u]
	ul := un.l
	up := un.p
	t.node[up].r = ul
	if ul != null {
		t.node[ul].p = up
	}

	// replace v by u
	// u inherits both of v's children and v's parent link.
	un.l, un.r = l, r
	t.node[l].p = u
	t.node[r].p = u
	*ptr = u
	un.p = p
}
211
212// search looks for the node that have the value x or for the nodes that
213// brace it. The node highest in the tree with the value x will be
214// returned. All other nodes with the same value live in left subtree of
215// the returned node.
216func (t *binTree) search(v uint32, x uint32) (a, b uint32) {
217 a, b = null, null
218 if v == null {
219 return
220 }
221 for {
222 vn := &t.node[v]
223 if x <= vn.x {
224 if x == vn.x {
225 return v, v
226 }
227 b = v
228 if vn.l == null {
229 return
230 }
231 v = vn.l
232 } else {
233 a = v
234 if vn.r == null {
235 return
236 }
237 v = vn.r
238 }
239 }
240}
241
242// max returns the node with maximum value in the subtree with v as
243// root.
244func (t *binTree) max(v uint32) uint32 {
245 if v == null {
246 return null
247 }
248 for {
249 r := t.node[v].r
250 if r == null {
251 return v
252 }
253 v = r
254 }
255}
256
257// min returns the node with the minimum value in the subtree with v as
258// root.
259func (t *binTree) min(v uint32) uint32 {
260 if v == null {
261 return null
262 }
263 for {
264 l := t.node[v].l
265 if l == null {
266 return v
267 }
268 v = l
269 }
270}
271
272// pred returns the in-order predecessor of node v.
273func (t *binTree) pred(v uint32) uint32 {
274 if v == null {
275 return null
276 }
277 u := t.max(t.node[v].l)
278 if u != null {
279 return u
280 }
281 for {
282 p := t.node[v].p
283 if p == null {
284 return null
285 }
286 if t.node[p].r == v {
287 return p
288 }
289 v = p
290 }
291}
292
293// succ returns the in-order successor of node v.
294func (t *binTree) succ(v uint32) uint32 {
295 if v == null {
296 return null
297 }
298 u := t.min(t.node[v].r)
299 if u != null {
300 return u
301 }
302 for {
303 p := t.node[v].p
304 if p == null {
305 return null
306 }
307 if t.node[p].l == v {
308 return p
309 }
310 v = p
311 }
312}
313
// xval packs up to the first four bytes of a into a uint32 in big-endian
// order, so a[0] becomes the most significant byte. When a holds fewer
// than four bytes, the missing low-order bytes are zero.
func xval(a []byte) uint32 {
	var x uint32
	for i, c := range a {
		if i == 4 {
			break
		}
		x |= uint32(c) << (24 - 8*uint(i))
	}
	return x
}
334
// dumpX renders x as a four-character debugging string, one character
// per byte from the most to the least significant byte. Printable ASCII
// bytes (0x20-0x7E) are shown as themselves and every other byte as '.'.
// Bytes >= 0x80 are replaced too: copying them verbatim would produce
// invalid UTF-8 rather than four characters.
func dumpX(x uint32) string {
	var a [4]byte
	for i := range a {
		c := byte(x >> uint((3-i)*8))
		if c < 0x80 && unicode.IsGraphic(rune(c)) {
			a[i] = c
		} else {
			a[i] = '.'
		}
	}
	return string(a[:])
}
348
349/*
350// dumpNode writes a representation of the node v into the io.Writer.
351func (t *binTree) dumpNode(w io.Writer, v uint32, indent int) {
352 if v == null {
353 return
354 }
355
356 vn := &t.node[v]
357
358 t.dumpNode(w, vn.r, indent+2)
359
360 for i := 0; i < indent; i++ {
361 fmt.Fprint(w, " ")
362 }
363 if vn.p == null {
364 fmt.Fprintf(w, "node %d %q parent null\n", v, dumpX(vn.x))
365 } else {
366 fmt.Fprintf(w, "node %d %q parent %d\n", v, dumpX(vn.x), vn.p)
367 }
368
369 t.dumpNode(w, vn.l, indent+2)
370}
371
372// dump prints a representation of the binary tree into the writer.
373func (t *binTree) dump(w io.Writer) error {
374 bw := bufio.NewWriter(w)
375 t.dumpNode(bw, t.root, 0)
376 return bw.Flush()
377}
378*/
379
380func (t *binTree) distance(v uint32) int {
381 dist := int(t.front) - int(v)
382 if dist <= 0 {
383 dist += len(t.node)
384 }
385 return dist
386}
387
// matchParams controls a single call of binTree.match.
type matchParams struct {
	// rep holds the recent distances passed to NextOp. A length-1
	// match is only kept when dist-minDistance equals rep[0].
	rep [4]uint32
	// nAccept is the length at which a match is taken immediately.
	nAccept int
	// check is the maximum number of candidates to examine.
	check int
	// stopShorter ends the search when a candidate does not match at
	// all, or differs at the last position of the current best match.
	stopShorter bool
}
397
// match tries to improve on the match m using the candidate distances
// produced by distIter, and returns the best match found.
//
// distIter yields candidate distances; ok == false means it is
// exhausted. t.data holds the look-ahead bytes to match (NextOp fills
// it). checked reports how many candidates were examined.
//
// accepted is true when the search ended because p.check candidates
// were examined or a match of at least p.nAccept bytes was found. It is
// false when distIter was exhausted or p.stopShorter ended the search.
// The named result r is not used; m is returned directly.
func (t *binTree) match(m match, distIter func() (int, bool), p matchParams,
) (r match, checked int, accepted bool) {
	buf := &t.dict.buf
	for {
		if checked >= p.check {
			return m, checked, true
		}
		dist, ok := distIter()
		if !ok {
			return m, checked, false
		}
		checked++
		if m.n > 0 {
			// Quick reject: compare only the byte at position m.n-1. A
			// candidate differing there cannot be longer than m.
			// i is that byte's index in buf.data, wrapped into range
			// (buf.data is used as a circular buffer here).
			i := buf.rear - dist + m.n - 1
			if i < 0 {
				i += len(buf.data)
			} else if i >= len(buf.data) {
				i -= len(buf.data)
			}
			if buf.data[i] != t.data[m.n-1] {
				if p.stopShorter {
					return m, checked, false
				}
				continue
			}
		}
		// NOTE(review): assumes matchLen returns the number of leading
		// bytes of t.data that match at distance dist - confirm in
		// encoderDict.
		n := buf.matchLen(dist, t.data)
		switch n {
		case 0:
			if p.stopShorter {
				return m, checked, false
			}
			continue
		case 1:
			// A single-byte match is only kept at the rep[0] distance.
			if uint32(dist-minDistance) != p.rep[0] {
				continue
			}
		}
		// Keep the longer match; on equal length prefer the smaller
		// distance.
		if n < m.n || (n == m.n && int64(dist) >= m.distance) {
			continue
		}
		m = match{int64(dist), n}
		if n >= p.nAccept {
			return m, checked, true
		}
	}
}
445
// NextOp determines the next encoder operation at the current position.
// It returns the best match found, or a literal for the next byte if no
// match was found. rep holds the recent distances; match consults only
// rep[0], for length-1 matches.
//
// The search runs in stages that share a budget of 32 candidate checks:
//  1. the distances 3, 2 and 1;
//  2. if the key of the look-ahead bytes is present in the tree and
//     exactly four look-ahead bytes are available, the chain of nodes
//     with that key;
//  3. otherwise the in-order successors and then the predecessors of
//     the search position, stopping early (stopShorter) on bad
//     candidates.
//
// It panics if the dictionary buffer has no data to look ahead.
func (t *binTree) NextOp(rep [4]uint32) operation {
	// retrieve maxMatchLen data
	// The error from Peek is ignored; only n == 0 is treated as fatal.
	// t.data keeps its capacity, so reslicing to maxMatchLen is valid
	// even after it was shortened by a previous call.
	n, _ := t.dict.buf.Peek(t.data[:maxMatchLen])
	if n == 0 {
		panic("no data in buffer")
	}
	t.data = t.data[:n]

	var (
		m                  match
		x, u, v            uint32
		iterPred, iterSucc func() (int, bool)
	)
	p := matchParams{
		rep:     rep,
		nAccept: maxMatchLen,
		check:   32,
	}
	// Stage 1: iterSmall yields the distances 3, 2, 1.
	i := 4
	iterSmall := func() (dist int, ok bool) {
		i--
		if i <= 0 {
			return 0, false
		}
		return i, true
	}
	m, checked, accepted := t.match(m, iterSmall, p)
	if accepted {
		goto end
	}
	p.check -= checked
	// xval zero-pads when fewer than four bytes are available.
	x = xval(t.data)
	u, v = t.search(t.root, x)
	// Stage 2: u == v means a node with key x was found.
	// NOTE(review): this branch requires exactly four look-ahead bytes;
	// with more bytes the exact-key case falls through to stage 3, where
	// both iterators start at the same node and check it twice - confirm
	// whether len(t.data) >= 4 was intended.
	if u == v && len(t.data) == 4 {
		// Walk the nodes with key x; equal keys live in the left subtree.
		iter := func() (dist int, ok bool) {
			if u == null {
				return 0, false
			}
			dist = t.distance(u)
			u, v = t.search(t.node[u].l, x)
			if u != v {
				u = null
			}
			return dist, true
		}
		m, _, _ = t.match(m, iter, p)
		goto end
	}
	// Stage 3: scan the neighbours of the search position, first the
	// successors starting at v, then the predecessors starting at u.
	p.stopShorter = true
	iterSucc = func() (dist int, ok bool) {
		if v == null {
			return 0, false
		}
		dist = t.distance(v)
		v = t.succ(v)
		return dist, true
	}
	m, checked, accepted = t.match(m, iterSucc, p)
	if accepted {
		goto end
	}
	p.check -= checked
	iterPred = func() (dist int, ok bool) {
		if u == null {
			return 0, false
		}
		dist = t.distance(u)
		u = t.pred(u)
		return dist, true
	}
	m, _, _ = t.match(m, iterPred, p)
end:
	if m.n == 0 {
		return lit{t.data[0]}
	}
	return m
}