main
1package wav
2
3import (
4 "encoding/binary"
5 "errors"
6 "fmt"
7 "io"
8 "os"
9 "slices"
10)
11
// WAV is an in-memory audio clip whose samples are stored as int16 PCM.
type WAV struct {
	// SampleRate is the number of frames per second (one frame holds one
	// sample per channel), as stored in the header's sample-rate field.
	SampleRate    int
	// Channels is the number of interleaved channels in each frame of PCM.
	Channels      int
	// BitsPerSample is the bit depth declared in the file header when the
	// clip was read by ReadWAV; check requires it to match when joining
	// clips. The values in PCM are int16 regardless.
	BitsPerSample int
	// PCM holds the samples interleaved by channel: frame i occupies
	// PCM[i*Channels : (i+1)*Channels].
	PCM           []int16
}
18
// Sentinel errors used by this package. They are returned wrapped with
// context, so compare against them with errors.Is.
var (
	// ErrNotWAV is wrapped by errors about input that is not a valid WAV file.
	ErrNotWAV = errors.New("invalid WAV file")

	// ErrIncompatibleWAV is wrapped when two clips differ in sample rate,
	// bits per sample or channel count and therefore cannot be joined.
	ErrIncompatibleWAV = errors.New("incompatible WAV files")
)
23
// _bytesPerSample is the width in bytes of one stored sample: WAV.PCM holds
// 16-bit (int16) values, i.e. 16 bits / 8 bits per byte.
const _bytesPerSample = 16 / 8
27
// findDataChunk positions r at the first byte of the RIFF "data" chunk's
// payload and returns the payload size in bytes as declared in its header.
//
// Scanning starts at offset 12, just past the "RIFF"<size>"WAVE" preamble,
// and walks chunk headers (4-byte ID followed by a 4-byte little-endian
// size). Every other chunk is skipped together with its pad byte, because
// RIFF chunks are word-aligned: an odd-sized payload is followed by one
// padding byte.
//
// If no "data" chunk exists, the error from the failing header read is
// returned (io.EOF or io.ErrUnexpectedEOF). A declared size that does not
// fit in int (possible only where int is 32 bits) is reported as an error.
func findDataChunk(r io.ReadSeeker) (int, error) {
	if _, err := r.Seek(12, io.SeekStart); err != nil {
		return 0, err
	}

	var hdr [8]byte
	for {
		if _, err := io.ReadFull(r, hdr[:]); err != nil {
			return 0, err
		}
		// Do size arithmetic in int64: with a 32-bit int, int(uint32) can be
		// negative, which would make the skip below seek backwards.
		size := int64(binary.LittleEndian.Uint32(hdr[4:8]))

		if string(hdr[0:4]) == "data" {
			n := int(size)
			if int64(n) != size {
				return 0, fmt.Errorf("data chunk size %d does not fit in int", size)
			}
			return n, nil
		}

		// Skip the payload plus the pad byte that follows odd-sized chunks.
		if _, err := r.Seek(size+size%2, io.SeekCurrent); err != nil {
			return 0, err
		}
	}
}
55
56func ReadWAV(path string) (*WAV, error) {
57
58 f, err := os.Open(path)
59 if err != nil {
60 return nil, fmt.Errorf("opening wav: %w", err)
61 }
62 defer f.Close()
63
64 header := make([]byte, 44)
65
66 _, err = io.ReadFull(f, header)
67 if err != nil {
68 return nil, fmt.Errorf("reading wav header: %w", err)
69 }
70
71 magic := string(header[0:4])
72 if magic != "RIFF" && magic != "WAVE" {
73 return nil, fmt.Errorf("magic=%q: %w", magic, ErrNotWAV)
74 }
75
76 channels := int(binary.LittleEndian.Uint16(header[22:24]))
77 sampleRate := int(binary.LittleEndian.Uint32(header[24:28]))
78 // TODO: whats at 28:34?
79 bitsPerSample := int(binary.LittleEndian.Uint16(header[34:36]))
80 // TODO: whats at 36:40?
81 dataLen, err := findDataChunk(f)
82 if err != nil {
83 return nil, fmt.Errorf("finding data chunk: %w", err)
84 }
85 sampleCount := dataLen / _bytesPerSample
86 pcm := make([]int16, sampleCount)
87
88 err = binary.Read(f, binary.LittleEndian, pcm)
89 if err != nil {
90 return nil, fmt.Errorf("reading wav data: %w", err)
91 }
92
93 return &WAV{
94 SampleRate: sampleRate,
95 Channels: channels,
96 BitsPerSample: bitsPerSample,
97 PCM: pcm,
98 }, nil
99}
100
101func (w *WAV) WriteFile(path string) error {
102 out, err := os.Create(path)
103 if err != nil {
104 return err
105 }
106 defer out.Close()
107
108 byteRate := w.SampleRate * w.Channels * _bytesPerSample
109 blockAlign := w.Channels * _bytesPerSample
110 dataSize := len(w.PCM) * _bytesPerSample
111 riffSize := 36 + dataSize
112
113 header := make([]byte, 44)
114 copy(header[0:], []byte("RIFF"))
115 binary.LittleEndian.PutUint32(header[4:], uint32(riffSize))
116 copy(header[8:], []byte("WAVEfmt "))
117 binary.LittleEndian.PutUint32(header[16:], 16) // fmt chunk size
118 binary.LittleEndian.PutUint16(header[20:], 1) // PCM format
119 binary.LittleEndian.PutUint16(header[22:], uint16(w.Channels))
120 binary.LittleEndian.PutUint32(header[24:], uint32(w.SampleRate))
121 binary.LittleEndian.PutUint32(header[28:], uint32(byteRate))
122 binary.LittleEndian.PutUint16(header[32:], uint16(blockAlign))
123 binary.LittleEndian.PutUint16(header[34:], uint16(w.BitsPerSample))
124 copy(header[36:], []byte("data"))
125 binary.LittleEndian.PutUint32(header[40:], uint32(dataSize))
126
127 _, err = out.Write(header)
128 if err != nil {
129 return err
130 }
131 return binary.Write(out, binary.LittleEndian, w.PCM)
132}
133
134func check(first, second *WAV) error {
135 if first.SampleRate != second.SampleRate {
136 return fmt.Errorf("sample rate a=%d b=%d: %w",
137 first.SampleRate,
138 second.SampleRate,
139 ErrIncompatibleWAV)
140 }
141 if first.BitsPerSample != second.BitsPerSample {
142 return fmt.Errorf("bits per sample a=%d b=%d: %w",
143 first.BitsPerSample,
144 second.BitsPerSample,
145 ErrIncompatibleWAV)
146 }
147 if first.Channels != second.Channels {
148 return fmt.Errorf("channels a=%d b=%d: %w",
149 first.Channels,
150 second.Channels,
151 ErrIncompatibleWAV)
152 }
153 return nil
154}
155
156func (first *WAV) Append(second *WAV) (*WAV, error) {
157 err := check(first, second)
158 if err != nil {
159 return nil, err
160 }
161
162 // transition prep and bounds checking
163 reviewDuration := 5
164 reviewSamples := first.SampleRate * first.Channels * reviewDuration
165 if len(first.PCM) < reviewSamples {
166 reviewSamples = len(first.PCM)
167 }
168 if len(second.PCM) < reviewSamples {
169 reviewSamples = len(second.PCM)
170 }
171
172 reviewPCM := append(first.PCM[len(first.PCM)-reviewSamples:], second.PCM[:reviewSamples]...)
173 first.PCM = append(first.PCM, second.PCM...)
174
175 return &WAV{
176 SampleRate: first.SampleRate,
177 Channels: first.Channels,
178 BitsPerSample: first.BitsPerSample,
179 PCM: reviewPCM,
180 }, nil
181}
182
183func (first *WAV) Fade(second *WAV, d float64) (*WAV, error) {
184 err := check(first, second)
185 if err != nil {
186 return nil, err
187 }
188
189 fadeSamples := int(d * float64(first.SampleRate*first.Channels))
190 if fadeSamples <= 0 || fadeSamples > len(first.PCM) || fadeSamples > len(second.PCM) {
191 return first.Append(second)
192 }
193
194 fadeStart := len(first.PCM) - fadeSamples
195 transition := make([]int16, fadeSamples)
196 for i := range fadeSamples {
197 fadeIn := float64(i) / float64(fadeSamples)
198 fadeOut := 1.0 - fadeIn
199 sampleA := float64(first.PCM[fadeStart+i])
200 sampleB := float64(second.PCM[i])
201 transition[i] = int16(sampleA*fadeOut + sampleB*fadeIn)
202 }
203
204 first.PCM = slices.Concat(
205 first.PCM[:fadeStart],
206 transition,
207 second.PCM[fadeSamples:],
208 )
209
210 reviewDuration := 2
211 reviewSamples := first.SampleRate * first.Channels * reviewDuration
212 reviewStart := fadeStart - reviewSamples
213 reviewEnd := fadeStart + fadeSamples + reviewSamples
214 reviewPCM := first.PCM[reviewStart:reviewEnd]
215
216 return &WAV{
217 SampleRate: first.SampleRate,
218 Channels: first.Channels,
219 BitsPerSample: first.BitsPerSample,
220 PCM: reviewPCM,
221 }, nil
222}
223
224func (w *WAV) Stretch(count int) {
225 ch := w.Channels
226 frames := len(w.PCM) / ch
227 stretched := make([]int16, 0, len(w.PCM)*count)
228
229 for i := range frames {
230 frame := w.PCM[i*ch : (i+1)*ch]
231 for range count {
232 stretched = append(stretched, frame...)
233 }
234 }
235
236 w.PCM = stretched
237}
238
239func (w *WAV) TrimEnd(d float64) {
240 if d <= 0 {
241 return
242 }
243
244 trimSamples := int(d * float64(w.SampleRate*w.Channels))
245 newEnd := len(w.PCM) - trimSamples
246 if newEnd < 0 {
247 newEnd = 0
248 }
249
250 w.PCM = w.PCM[:newEnd]
251}
252
253func (w *WAV) TrimStart(d float64) {
254 if d <= 0 {
255 return
256 }
257
258 trimSamples := int(d * float64(w.SampleRate*w.Channels))
259 newStart := len(w.PCM) - trimSamples
260 if newStart < 0 {
261 newStart = 0
262 }
263
264 w.PCM = w.PCM[newStart:]
265}