main
Raw Download raw file
  1// Copyright 2014-2022 Ulrich Kunitz. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package xz
  6
  7import (
  8	"bytes"
  9	"crypto/sha256"
 10	"errors"
 11	"fmt"
 12	"hash"
 13	"hash/crc32"
 14	"io"
 15
 16	"github.com/ulikunitz/xz/lzma"
 17)
 18
 19// allZeros checks whether a given byte slice has only zeros.
 20func allZeros(p []byte) bool {
 21	for _, c := range p {
 22		if c != 0 {
 23			return false
 24		}
 25	}
 26	return true
 27}
 28
 29// padLen returns the length of the padding required for the given
 30// argument.
 31func padLen(n int64) int {
 32	k := int(n % 4)
 33	if k > 0 {
 34		k = 4 - k
 35	}
 36	return k
 37}
 38
 39/*** Header ***/
 40
// headerMagic stores the six magic bytes that open an xz file header.
var headerMagic = []byte{0xfd, '7', 'z', 'X', 'Z', 0x00}

// HeaderLen provides the length of the xz file header in bytes.
const HeaderLen = 12
 46
// Constants for the checksum methods supported by xz. These are the
// flag values that verifyFlags accepts.
const (
	None   byte = 0x0
	CRC32  byte = 0x1
	CRC64  byte = 0x4
	SHA256 byte = 0xa
)

// errInvalidFlags indicates that flags are invalid. It is returned,
// for example, by verifyFlags for a value that is not one of the
// checksum constants.
var errInvalidFlags = errors.New("xz: invalid flags")
 57
 58// verifyFlags returns the error errInvalidFlags if the value is
 59// invalid.
 60func verifyFlags(flags byte) error {
 61	switch flags {
 62	case None, CRC32, CRC64, SHA256:
 63		return nil
 64	default:
 65		return errInvalidFlags
 66	}
 67}
 68
// flagstrings maps flag values to strings. It holds one entry per
// supported checksum method and is consulted by flagString.
var flagstrings = map[byte]string{
	None:   "None",
	CRC32:  "CRC-32",
	CRC64:  "CRC-64",
	SHA256: "SHA-256",
}
 76
 77// flagString returns the string representation for the given flags.
 78func flagString(flags byte) string {
 79	s, ok := flagstrings[flags]
 80	if !ok {
 81		return "invalid"
 82	}
 83	return s
 84}
 85
 86// newHashFunc returns a function that creates hash instances for the
 87// hash method encoded in flags.
 88func newHashFunc(flags byte) (newHash func() hash.Hash, err error) {
 89	switch flags {
 90	case None:
 91		newHash = newNoneHash
 92	case CRC32:
 93		newHash = newCRC32
 94	case CRC64:
 95		newHash = newCRC64
 96	case SHA256:
 97		newHash = sha256.New
 98	default:
 99		err = errInvalidFlags
100	}
101	return
102}
103
// header provides the actual content of the xz file header: the flags.
type header struct {
	// flags is the checksum method byte; it is checked with
	// verifyFlags when the header is marshalled or unmarshalled.
	flags byte
}

// errHeaderMagic is returned by header.UnmarshalBinary when the data
// does not start with headerMagic.
var errHeaderMagic = errors.New("xz: invalid header magic bytes")
111
112// ValidHeader checks whether data is a correct xz file header. The
113// length of data must be HeaderLen.
114func ValidHeader(data []byte) bool {
115	var h header
116	err := h.UnmarshalBinary(data)
117	return err == nil
118}
119
// String returns a string representation of the flags. The text is
// produced by flagString, so an unknown flags value is rendered as
// "invalid".
func (h header) String() string {
	return flagString(h.flags)
}
124
125// UnmarshalBinary reads header from the provided data slice.
126func (h *header) UnmarshalBinary(data []byte) error {
127	// header length
128	if len(data) != HeaderLen {
129		return errors.New("xz: wrong file header length")
130	}
131
132	// magic header
133	if !bytes.Equal(headerMagic, data[:6]) {
134		return errHeaderMagic
135	}
136
137	// checksum
138	crc := crc32.NewIEEE()
139	crc.Write(data[6:8])
140	if uint32LE(data[8:]) != crc.Sum32() {
141		return errors.New("xz: invalid checksum for file header")
142	}
143
144	// stream flags
145	if data[6] != 0 {
146		return errInvalidFlags
147	}
148	flags := data[7]
149	if err := verifyFlags(flags); err != nil {
150		return err
151	}
152
153	h.flags = flags
154	return nil
155}
156
157// MarshalBinary generates the xz file header.
158func (h *header) MarshalBinary() (data []byte, err error) {
159	if err = verifyFlags(h.flags); err != nil {
160		return nil, err
161	}
162
163	data = make([]byte, 12)
164	copy(data, headerMagic)
165	data[7] = h.flags
166
167	crc := crc32.NewIEEE()
168	crc.Write(data[6:8])
169	putUint32LE(data[8:], crc.Sum32())
170
171	return data, nil
172}
173
174/*** Footer ***/
175
// footerLen defines the length of the footer in bytes.
const footerLen = 12

// footerMagic contains the footer magic bytes. They occupy the last
// two bytes of the footer.
var footerMagic = []byte{'Y', 'Z'}
181
// footer represents the content of the xz file footer.
type footer struct {
	// indexSize is the size of the index in bytes. The footer stores
	// it as the backward size value indexSize/4 - 1.
	indexSize int64
	// flags is the checksum method byte, checked with verifyFlags.
	flags byte
}
187
// String prints a string representation of the footer structure: the
// checksum name as produced by flagString followed by the index size.
func (f footer) String() string {
	return fmt.Sprintf("%s index size %d", flagString(f.flags), f.indexSize)
}
192
// Minimum and maximum for the size of the index (backward size) in
// bytes. The footer stores indexSize/4 - 1 in a 32-bit field, so the
// representable sizes range from (0+1)*4 to ((1<<32-1)+1)*4.
const (
	minIndexSize = 4
	maxIndexSize = (1 << 32) * 4
)
198
199// MarshalBinary converts footer values into an xz file footer. Note
200// that the footer value is checked for correctness.
201func (f *footer) MarshalBinary() (data []byte, err error) {
202	if err = verifyFlags(f.flags); err != nil {
203		return nil, err
204	}
205	if !(minIndexSize <= f.indexSize && f.indexSize <= maxIndexSize) {
206		return nil, errors.New("xz: index size out of range")
207	}
208	if f.indexSize%4 != 0 {
209		return nil, errors.New(
210			"xz: index size not aligned to four bytes")
211	}
212
213	data = make([]byte, footerLen)
214
215	// backward size (index size)
216	s := (f.indexSize / 4) - 1
217	putUint32LE(data[4:], uint32(s))
218	// flags
219	data[9] = f.flags
220	// footer magic
221	copy(data[10:], footerMagic)
222
223	// CRC-32
224	crc := crc32.NewIEEE()
225	crc.Write(data[4:10])
226	putUint32LE(data, crc.Sum32())
227
228	return data, nil
229}
230
231// UnmarshalBinary sets the footer value by unmarshalling an xz file
232// footer.
233func (f *footer) UnmarshalBinary(data []byte) error {
234	if len(data) != footerLen {
235		return errors.New("xz: wrong footer length")
236	}
237
238	// magic bytes
239	if !bytes.Equal(data[10:], footerMagic) {
240		return errors.New("xz: footer magic invalid")
241	}
242
243	// CRC-32
244	crc := crc32.NewIEEE()
245	crc.Write(data[4:10])
246	if uint32LE(data) != crc.Sum32() {
247		return errors.New("xz: footer checksum error")
248	}
249
250	var g footer
251	// backward size (index size)
252	g.indexSize = (int64(uint32LE(data[4:])) + 1) * 4
253
254	// flags
255	if data[8] != 0 {
256		return errInvalidFlags
257	}
258	g.flags = data[9]
259	if err := verifyFlags(g.flags); err != nil {
260		return err
261	}
262
263	*f = g
264	return nil
265}
266
267/*** Block Header ***/
268
// blockHeader represents the content of an xz block header.
type blockHeader struct {
	// compressedSize and uncompressedSize are optional header fields.
	// A negative value means the field is absent: readSizeInBlockHeader
	// yields -1 for a missing field and MarshalBinary only writes
	// values that are >= 0.
	compressedSize   int64
	uncompressedSize int64
	// filters lists the filters of the block in header order.
	filters []filter
}
275
276// String converts the block header into a string.
277func (h blockHeader) String() string {
278	var buf bytes.Buffer
279	first := true
280	if h.compressedSize >= 0 {
281		fmt.Fprintf(&buf, "compressed size %d", h.compressedSize)
282		first = false
283	}
284	if h.uncompressedSize >= 0 {
285		if !first {
286			buf.WriteString(" ")
287		}
288		fmt.Fprintf(&buf, "uncompressed size %d", h.uncompressedSize)
289		first = false
290	}
291	for _, f := range h.filters {
292		if !first {
293			buf.WriteString(" ")
294		}
295		fmt.Fprintf(&buf, "filter %s", f)
296		first = false
297	}
298	return buf.String()
299}
300
// Masks for the block flags, the second byte of a block header.
const (
	// filterCountMask selects the bits holding the filter count minus
	// one.
	filterCountMask = 0x03
	// compressedSizePresent and uncompressedSizePresent signal that the
	// respective size field follows the flags byte.
	compressedSizePresent   = 0x40
	uncompressedSizePresent = 0x80
	// reservedBlockFlags selects the bits that must be zero; a header
	// with any of them set is rejected when it is unmarshalled.
	reservedBlockFlags = 0x3C
)

// errIndexIndicator signals that an index indicator (0x00) has been found
// instead of an expected block header indicator.
var errIndexIndicator = errors.New("xz: found index indicator")
312
313// readBlockHeader reads the block header.
314func readBlockHeader(r io.Reader) (h *blockHeader, n int, err error) {
315	var buf bytes.Buffer
316	buf.Grow(20)
317
318	// block header size
319	z, err := io.CopyN(&buf, r, 1)
320	n = int(z)
321	if err != nil {
322		return nil, n, err
323	}
324	s := buf.Bytes()[0]
325	if s == 0 {
326		return nil, n, errIndexIndicator
327	}
328
329	// read complete header
330	headerLen := (int(s) + 1) * 4
331	buf.Grow(headerLen - 1)
332	z, err = io.CopyN(&buf, r, int64(headerLen-1))
333	n += int(z)
334	if err != nil {
335		return nil, n, err
336	}
337
338	// unmarshal block header
339	h = new(blockHeader)
340	if err = h.UnmarshalBinary(buf.Bytes()); err != nil {
341		return nil, n, err
342	}
343
344	return h, n, nil
345}
346
347// readSizeInBlockHeader reads the uncompressed or compressed size
348// fields in the block header. The present value informs the function
349// whether the respective field is actually present in the header.
350func readSizeInBlockHeader(r io.ByteReader, present bool) (n int64, err error) {
351	if !present {
352		return -1, nil
353	}
354	x, _, err := readUvarint(r)
355	if err != nil {
356		return 0, err
357	}
358	if x >= 1<<63 {
359		return 0, errors.New("xz: size overflow in block header")
360	}
361	return int64(x), nil
362}
363
364// UnmarshalBinary unmarshals the block header.
365func (h *blockHeader) UnmarshalBinary(data []byte) error {
366	// Check header length
367	s := data[0]
368	if data[0] == 0 {
369		return errIndexIndicator
370	}
371	headerLen := (int(s) + 1) * 4
372	if len(data) != headerLen {
373		return fmt.Errorf("xz: data length %d; want %d", len(data),
374			headerLen)
375	}
376	n := headerLen - 4
377
378	// Check CRC-32
379	crc := crc32.NewIEEE()
380	crc.Write(data[:n])
381	if crc.Sum32() != uint32LE(data[n:]) {
382		return errors.New("xz: checksum error for block header")
383	}
384
385	// Block header flags
386	flags := data[1]
387	if flags&reservedBlockFlags != 0 {
388		return errors.New("xz: reserved block header flags set")
389	}
390
391	r := bytes.NewReader(data[2:n])
392
393	// Compressed size
394	var err error
395	h.compressedSize, err = readSizeInBlockHeader(
396		r, flags&compressedSizePresent != 0)
397	if err != nil {
398		return err
399	}
400
401	// Uncompressed size
402	h.uncompressedSize, err = readSizeInBlockHeader(
403		r, flags&uncompressedSizePresent != 0)
404	if err != nil {
405		return err
406	}
407
408	h.filters, err = readFilters(r, int(flags&filterCountMask)+1)
409	if err != nil {
410		return err
411	}
412
413	// Check padding
414	// Since headerLen is a multiple of 4 we don't need to check
415	// alignment.
416	k := r.Len()
417	// The standard spec says that the padding should have not more
418	// than 3 bytes. However we found paddings of 4 or 5 in the
419	// wild. See https://github.com/ulikunitz/xz/pull/11 and
420	// https://github.com/ulikunitz/xz/issues/15
421	//
422	// The only reasonable approach seems to be to ignore the
423	// padding size. We still check that all padding bytes are zero.
424	if !allZeros(data[n-k : n]) {
425		return errPadding
426	}
427	return nil
428}
429
430// MarshalBinary marshals the binary header.
431func (h *blockHeader) MarshalBinary() (data []byte, err error) {
432	if !(minFilters <= len(h.filters) && len(h.filters) <= maxFilters) {
433		return nil, errors.New("xz: filter count wrong")
434	}
435	for i, f := range h.filters {
436		if i < len(h.filters)-1 {
437			if f.id() == lzmaFilterID {
438				return nil, errors.New(
439					"xz: LZMA2 filter is not the last")
440			}
441		} else {
442			// last filter
443			if f.id() != lzmaFilterID {
444				return nil, errors.New("xz: " +
445					"last filter must be the LZMA2 filter")
446			}
447		}
448	}
449
450	var buf bytes.Buffer
451	// header size must set at the end
452	buf.WriteByte(0)
453
454	// flags
455	flags := byte(len(h.filters) - 1)
456	if h.compressedSize >= 0 {
457		flags |= compressedSizePresent
458	}
459	if h.uncompressedSize >= 0 {
460		flags |= uncompressedSizePresent
461	}
462	buf.WriteByte(flags)
463
464	p := make([]byte, 10)
465	if h.compressedSize >= 0 {
466		k := putUvarint(p, uint64(h.compressedSize))
467		buf.Write(p[:k])
468	}
469	if h.uncompressedSize >= 0 {
470		k := putUvarint(p, uint64(h.uncompressedSize))
471		buf.Write(p[:k])
472	}
473
474	for _, f := range h.filters {
475		fp, err := f.MarshalBinary()
476		if err != nil {
477			return nil, err
478		}
479		buf.Write(fp)
480	}
481
482	// padding
483	for i := padLen(int64(buf.Len())); i > 0; i-- {
484		buf.WriteByte(0)
485	}
486
487	// crc place holder
488	buf.Write(p[:4])
489
490	data = buf.Bytes()
491	if len(data)%4 != 0 {
492		panic("data length not aligned")
493	}
494	s := len(data)/4 - 1
495	if !(1 < s && s <= 255) {
496		panic("wrong block header size")
497	}
498	data[0] = byte(s)
499
500	crc := crc32.NewIEEE()
501	crc.Write(data[:len(data)-4])
502	putUint32LE(data[len(data)-4:], crc.Sum32())
503
504	return data, nil
505}
506
// Constants used for marshalling and unmarshalling filters in the xz
// block header.
const (
	// minFilters and maxFilters bound the number of filters that
	// blockHeader.MarshalBinary accepts.
	minFilters = 1
	maxFilters = 4
	// minReservedID is the first filter ID that readFilter reports as
	// reserved.
	minReservedID = 1 << 62
)
514
// filter represents a filter in the block header.
type filter interface {
	// id returns the filter ID; blockHeader.MarshalBinary compares it
	// with lzmaFilterID.
	id() uint64
	// UnmarshalBinary initializes the filter from its encoded form.
	// readFilter passes a slice that starts with the filter ID byte.
	UnmarshalBinary(data []byte) error
	// MarshalBinary returns the encoded filter as it is written into
	// the block header.
	MarshalBinary() (data []byte, err error)
	// NOTE(review): presumably wraps r with the decoding side of the
	// filter — confirm against the implementations.
	reader(r io.Reader, c *ReaderConfig) (fr io.Reader, err error)
	// NOTE(review): presumably wraps w with the encoding side of the
	// filter — confirm against the implementations.
	writeCloser(w io.WriteCloser, c *WriterConfig) (fw io.WriteCloser, err error)
	// filter must be last filter
	last() bool
}
525
526// readFilter reads a block filter from the block header. At this point
527// in time only the LZMA2 filter is supported.
528func readFilter(r io.Reader) (f filter, err error) {
529	br := lzma.ByteReader(r)
530
531	// index
532	id, _, err := readUvarint(br)
533	if err != nil {
534		return nil, err
535	}
536
537	var data []byte
538	switch id {
539	case lzmaFilterID:
540		data = make([]byte, lzmaFilterLen)
541		data[0] = lzmaFilterID
542		if _, err = io.ReadFull(r, data[1:]); err != nil {
543			return nil, err
544		}
545		f = new(lzmaFilter)
546	default:
547		if id >= minReservedID {
548			return nil, errors.New(
549				"xz: reserved filter id in block stream header")
550		}
551		return nil, errors.New("xz: invalid filter id")
552	}
553	if err = f.UnmarshalBinary(data); err != nil {
554		return nil, err
555	}
556	return f, err
557}
558
559// readFilters reads count filters. At this point in time only the count
560// 1 is supported.
561func readFilters(r io.Reader, count int) (filters []filter, err error) {
562	if count != 1 {
563		return nil, errors.New("xz: unsupported filter count")
564	}
565	f, err := readFilter(r)
566	if err != nil {
567		return nil, err
568	}
569	return []filter{f}, err
570}
571
572/*** Index ***/
573
// record describes a block in the xz file index. Both values are
// stored as uvarints in the index; readRecord rejects values that
// would be negative as int64.
type record struct {
	unpaddedSize     int64
	uncompressedSize int64
}
579
580// readRecord reads an index record.
581func readRecord(r io.ByteReader) (rec record, n int, err error) {
582	u, k, err := readUvarint(r)
583	n += k
584	if err != nil {
585		return rec, n, err
586	}
587	rec.unpaddedSize = int64(u)
588	if rec.unpaddedSize < 0 {
589		return rec, n, errors.New("xz: unpadded size negative")
590	}
591
592	u, k, err = readUvarint(r)
593	n += k
594	if err != nil {
595		return rec, n, err
596	}
597	rec.uncompressedSize = int64(u)
598	if rec.uncompressedSize < 0 {
599		return rec, n, errors.New("xz: uncompressed size negative")
600	}
601
602	return rec, n, nil
603}
604
605// MarshalBinary converts an index record in its binary encoding.
606func (rec *record) MarshalBinary() (data []byte, err error) {
607	// maximum length of a uvarint is 10
608	p := make([]byte, 20)
609	n := putUvarint(p, uint64(rec.unpaddedSize))
610	n += putUvarint(p[n:], uint64(rec.uncompressedSize))
611	return p[:n], nil
612}
613
614// writeIndex writes the index, a sequence of records.
615func writeIndex(w io.Writer, index []record) (n int64, err error) {
616	crc := crc32.NewIEEE()
617	mw := io.MultiWriter(w, crc)
618
619	// index indicator
620	k, err := mw.Write([]byte{0})
621	n += int64(k)
622	if err != nil {
623		return n, err
624	}
625
626	// number of records
627	p := make([]byte, 10)
628	k = putUvarint(p, uint64(len(index)))
629	k, err = mw.Write(p[:k])
630	n += int64(k)
631	if err != nil {
632		return n, err
633	}
634
635	// list of records
636	for _, rec := range index {
637		p, err := rec.MarshalBinary()
638		if err != nil {
639			return n, err
640		}
641		k, err = mw.Write(p)
642		n += int64(k)
643		if err != nil {
644			return n, err
645		}
646	}
647
648	// index padding
649	k, err = mw.Write(make([]byte, padLen(int64(n))))
650	n += int64(k)
651	if err != nil {
652		return n, err
653	}
654
655	// crc32 checksum
656	putUint32LE(p, crc.Sum32())
657	k, err = w.Write(p[:4])
658	n += int64(k)
659
660	return n, err
661}
662
663// readIndexBody reads the index from the reader. It assumes that the
664// index indicator has already been read.
665func readIndexBody(r io.Reader, expectedRecordLen int) (records []record, n int64, err error) {
666	crc := crc32.NewIEEE()
667	// index indicator
668	crc.Write([]byte{0})
669
670	br := lzma.ByteReader(io.TeeReader(r, crc))
671
672	// number of records
673	u, k, err := readUvarint(br)
674	n += int64(k)
675	if err != nil {
676		return nil, n, err
677	}
678	recLen := int(u)
679	if recLen < 0 || uint64(recLen) != u {
680		return nil, n, errors.New("xz: record number overflow")
681	}
682	if recLen != expectedRecordLen {
683		return nil, n, fmt.Errorf(
684			"xz: index length is %d; want %d",
685			recLen, expectedRecordLen)
686	}
687
688	// list of records
689	records = make([]record, recLen)
690	for i := range records {
691		records[i], k, err = readRecord(br)
692		n += int64(k)
693		if err != nil {
694			return nil, n, err
695		}
696	}
697
698	p := make([]byte, padLen(int64(n+1)), 4)
699	k, err = io.ReadFull(br.(io.Reader), p)
700	n += int64(k)
701	if err != nil {
702		return nil, n, err
703	}
704	if !allZeros(p) {
705		return nil, n, errors.New("xz: non-zero byte in index padding")
706	}
707
708	// crc32
709	s := crc.Sum32()
710	p = p[:4]
711	k, err = io.ReadFull(br.(io.Reader), p)
712	n += int64(k)
713	if err != nil {
714		return records, n, err
715	}
716	if uint32LE(p) != s {
717		return nil, n, errors.New("xz: wrong checksum for index")
718	}
719
720	return records, n, nil
721}