main
1package packet
2
// This file implements the pushdown automaton (PDA) from PGPainless (Paul Schaub)
// to verify OpenPGP packet sequences. See Paul's blog post for more details:
// https://blog.jabberhead.tk/2022/10/26/implementing-packet-sequence-validation-using-pushdown-automata/
6import (
7 "fmt"
8
9 "github.com/ProtonMail/go-crypto/openpgp/errors"
10)
11
12func NewErrMalformedMessage(from State, input InputSymbol, stackSymbol StackSymbol) errors.ErrMalformedMessage {
13 return errors.ErrMalformedMessage(fmt.Sprintf("state %d, input symbol %d, stack symbol %d ", from, input, stackSymbol))
14}
15
// InputSymbol defines the input alphabet of the PDA. One symbol is fed to
// SequenceVerifier.Next per step. The numeric values appear in
// ErrMalformedMessage texts, so they are spelled out explicitly.
type InputSymbol uint8

const (
	LDSymbol      InputSymbol = 0 // literal data; leads to the LiteralMessage state
	SigSymbol     InputSymbol = 1 // signature
	OPSSymbol     InputSymbol = 2 // one-pass signature; pushes OpsStackSymbol
	CompSymbol    InputSymbol = 3 // compressed data; leads to the CompressedMessage state
	ESKSymbol     InputSymbol = 4 // encrypted session key; leads to the ESKMessage state
	EncSymbol     InputSymbol = 5 // encrypted data; leads to the EncryptedMessage state
	EOSSymbol     InputSymbol = 6 // end of stream; the only symbol that consumes EndStackSymbol
	UnknownSymbol InputSymbol = 7 // any other input; no transition accepts it
)
29
// StackSymbol defines the stack alphabet of the PDA. Values are explicit
// because they are printed numerically in ErrMalformedMessage texts.
type StackSymbol int8

const (
	MsgStackSymbol   StackSymbol = 0 // a (nested) message is still expected
	OpsStackSymbol   StackSymbol = 1 // a one-pass signature awaits its trailing SigSymbol
	KeyStackSymbol   StackSymbol = 2 // inside a run of ESK symbols
	EndStackSymbol   StackSymbol = 3 // bottom marker; consumed only by EOSSymbol
	EmptyStackSymbol StackSymbol = 4 // returned by popStack when the stack is empty
)
40
// State defines the states of the PDA. The zero value is OpenPGPMessage, the
// start state. Values are explicit because they appear numerically in
// ErrMalformedMessage texts.
type State int8

const (
	OpenPGPMessage    State = 0 // expecting the start of a message (MsgStackSymbol on top)
	ESKMessage        State = 1 // after ESK symbol(s); expecting ESKSymbol or EncSymbol
	LiteralMessage    State = 2 // after LDSymbol
	CompressedMessage State = 3 // after CompSymbol
	EncryptedMessage  State = 4 // after EncSymbol
	ValidMessage      State = 5 // accepting and terminal: every further input is rejected
)
52
53// transition represents a state transition in the PDA
54type transition func(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error)
55
// SequenceVerifier is a pushdown automaton to verify
// PGP message packet sequences according to RFC 4880.
// Input symbols are fed with Next. Valid/AssertValid report acceptance,
// which is only reachable after EOSSymbol has been fed.
type SequenceVerifier struct {
	// stack is the PDA stack; the last element is the top.
	stack []StackSymbol
	// state is the current PDA state.
	state State
}
62
63// Next performs a state transition with the given input symbol.
64// If the transition fails a ErrMalformedMessage is returned.
65func (sv *SequenceVerifier) Next(input InputSymbol) error {
66 for {
67 stackSymbol := sv.popStack()
68 transitionFunc := getTransition(sv.state)
69 nextState, newStackSymbols, redo, err := transitionFunc(input, stackSymbol)
70 if err != nil {
71 return err
72 }
73 if redo {
74 sv.pushStack(stackSymbol)
75 }
76 for _, newStackSymbol := range newStackSymbols {
77 sv.pushStack(newStackSymbol)
78 }
79 sv.state = nextState
80 if !redo {
81 break
82 }
83 }
84 return nil
85}
86
87// Valid returns true if RDA is in a valid state.
88func (sv *SequenceVerifier) Valid() bool {
89 return sv.state == ValidMessage && len(sv.stack) == 0
90}
91
92func (sv *SequenceVerifier) AssertValid() error {
93 if !sv.Valid() {
94 return errors.ErrMalformedMessage("invalid message")
95 }
96 return nil
97}
98
99func NewSequenceVerifier() *SequenceVerifier {
100 return &SequenceVerifier{
101 stack: []StackSymbol{EndStackSymbol, MsgStackSymbol},
102 state: OpenPGPMessage,
103 }
104}
105
106func (sv *SequenceVerifier) popStack() StackSymbol {
107 if len(sv.stack) == 0 {
108 return EmptyStackSymbol
109 }
110 elemIndex := len(sv.stack) - 1
111 stackSymbol := sv.stack[elemIndex]
112 sv.stack = sv.stack[:elemIndex]
113 return stackSymbol
114}
115
116func (sv *SequenceVerifier) pushStack(stackSymbol StackSymbol) {
117 sv.stack = append(sv.stack, stackSymbol)
118}
119
120func getTransition(from State) transition {
121 switch from {
122 case OpenPGPMessage:
123 return fromOpenPGPMessage
124 case LiteralMessage:
125 return fromLiteralMessage
126 case CompressedMessage:
127 return fromCompressedMessage
128 case EncryptedMessage:
129 return fromEncryptedMessage
130 case ESKMessage:
131 return fromESKMessage
132 case ValidMessage:
133 return fromValidMessage
134 }
135 return nil
136}
137
138// fromOpenPGPMessage is the transition for the state OpenPGPMessage.
139func fromOpenPGPMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
140 if stackSymbol != MsgStackSymbol {
141 return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol)
142 }
143 switch input {
144 case LDSymbol:
145 return LiteralMessage, nil, false, nil
146 case SigSymbol:
147 return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, false, nil
148 case OPSSymbol:
149 return OpenPGPMessage, []StackSymbol{OpsStackSymbol, MsgStackSymbol}, false, nil
150 case CompSymbol:
151 return CompressedMessage, nil, false, nil
152 case ESKSymbol:
153 return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil
154 case EncSymbol:
155 return EncryptedMessage, nil, false, nil
156 }
157 return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol)
158}
159
160// fromESKMessage is the transition for the state ESKMessage.
161func fromESKMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
162 if stackSymbol != KeyStackSymbol {
163 return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol)
164 }
165 switch input {
166 case ESKSymbol:
167 return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil
168 case EncSymbol:
169 return EncryptedMessage, nil, false, nil
170 }
171 return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol)
172}
173
174// fromLiteralMessage is the transition for the state LiteralMessage.
175func fromLiteralMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
176 switch input {
177 case SigSymbol:
178 if stackSymbol == OpsStackSymbol {
179 return LiteralMessage, nil, false, nil
180 }
181 case EOSSymbol:
182 if stackSymbol == EndStackSymbol {
183 return ValidMessage, nil, false, nil
184 }
185 }
186 return 0, nil, false, NewErrMalformedMessage(LiteralMessage, input, stackSymbol)
187}
188
189// fromLiteralMessage is the transition for the state CompressedMessage.
190func fromCompressedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
191 switch input {
192 case SigSymbol:
193 if stackSymbol == OpsStackSymbol {
194 return CompressedMessage, nil, false, nil
195 }
196 case EOSSymbol:
197 if stackSymbol == EndStackSymbol {
198 return ValidMessage, nil, false, nil
199 }
200 }
201 return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil
202}
203
204// fromEncryptedMessage is the transition for the state EncryptedMessage.
205func fromEncryptedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
206 switch input {
207 case SigSymbol:
208 if stackSymbol == OpsStackSymbol {
209 return EncryptedMessage, nil, false, nil
210 }
211 case EOSSymbol:
212 if stackSymbol == EndStackSymbol {
213 return ValidMessage, nil, false, nil
214 }
215 }
216 return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil
217}
218
219// fromValidMessage is the transition for the state ValidMessage.
220func fromValidMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
221 return 0, nil, false, NewErrMalformedMessage(ValidMessage, input, stackSymbol)
222}