main
Raw Download raw file
  1package packet
  2
// This file implements the pushdown automaton (PDA) from PGPainless (Paul Schaub)
// to verify OpenPGP packet sequences. See Paul's blog post for more details:
// https://blog.jabberhead.tk/2022/10/26/implementing-packet-sequence-validation-using-pushdown-automata/
  6import (
  7	"fmt"
  8
  9	"github.com/ProtonMail/go-crypto/openpgp/errors"
 10)
 11
 12func NewErrMalformedMessage(from State, input InputSymbol, stackSymbol StackSymbol) errors.ErrMalformedMessage {
 13	return errors.ErrMalformedMessage(fmt.Sprintf("state %d, input symbol %d, stack symbol %d ", from, input, stackSymbol))
 14}
 15
 16// InputSymbol defines the input alphabet of the PDA
 17type InputSymbol uint8
 18
 19const (
 20	LDSymbol InputSymbol = iota
 21	SigSymbol
 22	OPSSymbol
 23	CompSymbol
 24	ESKSymbol
 25	EncSymbol
 26	EOSSymbol
 27	UnknownSymbol
 28)
 29
 30// StackSymbol defines the stack alphabet of the PDA
 31type StackSymbol int8
 32
 33const (
 34	MsgStackSymbol StackSymbol = iota
 35	OpsStackSymbol
 36	KeyStackSymbol
 37	EndStackSymbol
 38	EmptyStackSymbol
 39)
 40
 41// State defines the states of the PDA
 42type State int8
 43
 44const (
 45	OpenPGPMessage State = iota
 46	ESKMessage
 47	LiteralMessage
 48	CompressedMessage
 49	EncryptedMessage
 50	ValidMessage
 51)
 52
 53// transition represents a state transition in the PDA
 54type transition func(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error)
 55
 56// SequenceVerifier is a pushdown automata to verify
 57// PGP messages packet sequences according to rfc4880.
 58type SequenceVerifier struct {
 59	stack []StackSymbol
 60	state State
 61}
 62
 63// Next performs a state transition with the given input symbol.
 64// If the transition fails a ErrMalformedMessage is returned.
 65func (sv *SequenceVerifier) Next(input InputSymbol) error {
 66	for {
 67		stackSymbol := sv.popStack()
 68		transitionFunc := getTransition(sv.state)
 69		nextState, newStackSymbols, redo, err := transitionFunc(input, stackSymbol)
 70		if err != nil {
 71			return err
 72		}
 73		if redo {
 74			sv.pushStack(stackSymbol)
 75		}
 76		for _, newStackSymbol := range newStackSymbols {
 77			sv.pushStack(newStackSymbol)
 78		}
 79		sv.state = nextState
 80		if !redo {
 81			break
 82		}
 83	}
 84	return nil
 85}
 86
 87// Valid returns true if RDA is in a valid state.
 88func (sv *SequenceVerifier) Valid() bool {
 89	return sv.state == ValidMessage && len(sv.stack) == 0
 90}
 91
 92func (sv *SequenceVerifier) AssertValid() error {
 93	if !sv.Valid() {
 94		return errors.ErrMalformedMessage("invalid message")
 95	}
 96	return nil
 97}
 98
 99func NewSequenceVerifier() *SequenceVerifier {
100	return &SequenceVerifier{
101		stack: []StackSymbol{EndStackSymbol, MsgStackSymbol},
102		state: OpenPGPMessage,
103	}
104}
105
106func (sv *SequenceVerifier) popStack() StackSymbol {
107	if len(sv.stack) == 0 {
108		return EmptyStackSymbol
109	}
110	elemIndex := len(sv.stack) - 1
111	stackSymbol := sv.stack[elemIndex]
112	sv.stack = sv.stack[:elemIndex]
113	return stackSymbol
114}
115
116func (sv *SequenceVerifier) pushStack(stackSymbol StackSymbol) {
117	sv.stack = append(sv.stack, stackSymbol)
118}
119
120func getTransition(from State) transition {
121	switch from {
122	case OpenPGPMessage:
123		return fromOpenPGPMessage
124	case LiteralMessage:
125		return fromLiteralMessage
126	case CompressedMessage:
127		return fromCompressedMessage
128	case EncryptedMessage:
129		return fromEncryptedMessage
130	case ESKMessage:
131		return fromESKMessage
132	case ValidMessage:
133		return fromValidMessage
134	}
135	return nil
136}
137
138// fromOpenPGPMessage is the transition for the state OpenPGPMessage.
139func fromOpenPGPMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
140	if stackSymbol != MsgStackSymbol {
141		return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol)
142	}
143	switch input {
144	case LDSymbol:
145		return LiteralMessage, nil, false, nil
146	case SigSymbol:
147		return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, false, nil
148	case OPSSymbol:
149		return OpenPGPMessage, []StackSymbol{OpsStackSymbol, MsgStackSymbol}, false, nil
150	case CompSymbol:
151		return CompressedMessage, nil, false, nil
152	case ESKSymbol:
153		return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil
154	case EncSymbol:
155		return EncryptedMessage, nil, false, nil
156	}
157	return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol)
158}
159
160// fromESKMessage is the transition for the state ESKMessage.
161func fromESKMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
162	if stackSymbol != KeyStackSymbol {
163		return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol)
164	}
165	switch input {
166	case ESKSymbol:
167		return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil
168	case EncSymbol:
169		return EncryptedMessage, nil, false, nil
170	}
171	return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol)
172}
173
174// fromLiteralMessage is the transition for the state LiteralMessage.
175func fromLiteralMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
176	switch input {
177	case SigSymbol:
178		if stackSymbol == OpsStackSymbol {
179			return LiteralMessage, nil, false, nil
180		}
181	case EOSSymbol:
182		if stackSymbol == EndStackSymbol {
183			return ValidMessage, nil, false, nil
184		}
185	}
186	return 0, nil, false, NewErrMalformedMessage(LiteralMessage, input, stackSymbol)
187}
188
189// fromLiteralMessage is the transition for the state CompressedMessage.
190func fromCompressedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
191	switch input {
192	case SigSymbol:
193		if stackSymbol == OpsStackSymbol {
194			return CompressedMessage, nil, false, nil
195		}
196	case EOSSymbol:
197		if stackSymbol == EndStackSymbol {
198			return ValidMessage, nil, false, nil
199		}
200	}
201	return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil
202}
203
204// fromEncryptedMessage is the transition for the state EncryptedMessage.
205func fromEncryptedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
206	switch input {
207	case SigSymbol:
208		if stackSymbol == OpsStackSymbol {
209			return EncryptedMessage, nil, false, nil
210		}
211	case EOSSymbol:
212		if stackSymbol == EndStackSymbol {
213			return ValidMessage, nil, false, nil
214		}
215	}
216	return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil
217}
218
219// fromValidMessage is the transition for the state ValidMessage.
220func fromValidMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
221	return 0, nil, false, NewErrMalformedMessage(ValidMessage, input, stackSymbol)
222}