main
Raw Download raw file
  1// Package ssh_config provides tools for manipulating SSH config files.
  2//
  3// Importantly, this parser attempts to preserve comments in a given file, so
  4// you can manipulate a `ssh_config` file from a program, if your heart desires.
  5//
  6// The Get() and GetStrict() functions will attempt to read values from
  7// $HOME/.ssh/config, falling back to /etc/ssh/ssh_config. The first argument is
  8// the host name to match on ("example.com"), and the second argument is the key
  9// you want to retrieve ("Port"). The keywords are case insensitive.
 10//
 11// 		port := ssh_config.Get("myhost", "Port")
 12//
 13// You can also manipulate an SSH config file and then print it or write it back
 14// to disk.
 15//
 16//	f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
 17//	cfg, _ := ssh_config.Decode(f)
 18//	for _, host := range cfg.Hosts {
 19//		fmt.Println("patterns:", host.Patterns)
 20//		for _, node := range host.Nodes {
 21//			fmt.Println(node.String())
 22//		}
 23//	}
 24//
 25//	// Write the cfg back to disk:
 26//	fmt.Println(cfg.String())
 27//
 28// BUG: the Match directive is currently unsupported; parsing a config with
 29// a Match directive will trigger an error.
 30package ssh_config
 31
 32import (
 33	"bytes"
 34	"errors"
 35	"fmt"
 36	"io"
 37	"os"
 38	osuser "os/user"
 39	"path/filepath"
 40	"regexp"
 41	"runtime"
 42	"strings"
 43	"sync"
 44)
 45
 46const version = "1.2"
 47
 48var _ = version
 49
 50type configFinder func() string
 51
 52// UserSettings checks ~/.ssh and /etc/ssh for configuration files. The config
 53// files are parsed and cached the first time Get() or GetStrict() is called.
 54type UserSettings struct {
 55	IgnoreErrors       bool
 56	systemConfig       *Config
 57	systemConfigFinder configFinder
 58	userConfig         *Config
 59	userConfigFinder   configFinder
 60	loadConfigs        sync.Once
 61	onceErr            error
 62}
 63
 64func homedir() string {
 65	user, err := osuser.Current()
 66	if err == nil {
 67		return user.HomeDir
 68	} else {
 69		return os.Getenv("HOME")
 70	}
 71}
 72
 73func userConfigFinder() string {
 74	return filepath.Join(homedir(), ".ssh", "config")
 75}
 76
 77// DefaultUserSettings is the default UserSettings and is used by Get and
 78// GetStrict. It checks both $HOME/.ssh/config and /etc/ssh/ssh_config for keys,
 79// and it will return parse errors (if any) instead of swallowing them.
 80var DefaultUserSettings = &UserSettings{
 81	IgnoreErrors:       false,
 82	systemConfigFinder: systemConfigFinder,
 83	userConfigFinder:   userConfigFinder,
 84}
 85
 86func systemConfigFinder() string {
 87	return filepath.Join("/", "etc", "ssh", "ssh_config")
 88}
 89
 90func findVal(c *Config, alias, key string) (string, error) {
 91	if c == nil {
 92		return "", nil
 93	}
 94	val, err := c.Get(alias, key)
 95	if err != nil || val == "" {
 96		return "", err
 97	}
 98	if err := validate(key, val); err != nil {
 99		return "", err
100	}
101	return val, nil
102}
103
104func findAll(c *Config, alias, key string) ([]string, error) {
105	if c == nil {
106		return nil, nil
107	}
108	return c.GetAll(alias, key)
109}
110
111// Get finds the first value for key within a declaration that matches the
112// alias. Get returns the empty string if no value was found, or if IgnoreErrors
113// is false and we could not parse the configuration file. Use GetStrict to
114// disambiguate the latter cases.
115//
116// The match for key is case insensitive.
117//
118// Get is a wrapper around DefaultUserSettings.Get.
119func Get(alias, key string) string {
120	return DefaultUserSettings.Get(alias, key)
121}
122
123// GetAll retrieves zero or more directives for key for the given alias. GetAll
124// returns nil if no value was found, or if IgnoreErrors is false and we could
125// not parse the configuration file. Use GetAllStrict to disambiguate the
126// latter cases.
127//
128// In most cases you want to use Get or GetStrict, which returns a single value.
129// However, a subset of ssh configuration values (IdentityFile, for example)
130// allow you to specify multiple directives.
131//
132// The match for key is case insensitive.
133//
134// GetAll is a wrapper around DefaultUserSettings.GetAll.
135func GetAll(alias, key string) []string {
136	return DefaultUserSettings.GetAll(alias, key)
137}
138
139// GetStrict finds the first value for key within a declaration that matches the
140// alias. If key has a default value and no matching configuration is found, the
141// default will be returned. For more information on default values and the way
142// patterns are matched, see the manpage for ssh_config.
143//
144// The returned error will be non-nil if and only if a user's configuration file
145// or the system configuration file could not be parsed, and u.IgnoreErrors is
146// false.
147//
148// GetStrict is a wrapper around DefaultUserSettings.GetStrict.
149func GetStrict(alias, key string) (string, error) {
150	return DefaultUserSettings.GetStrict(alias, key)
151}
152
153// GetAllStrict retrieves zero or more directives for key for the given alias.
154//
155// In most cases you want to use Get or GetStrict, which returns a single value.
156// However, a subset of ssh configuration values (IdentityFile, for example)
157// allow you to specify multiple directives.
158//
159// The returned error will be non-nil if and only if a user's configuration file
160// or the system configuration file could not be parsed, and u.IgnoreErrors is
161// false.
162//
163// GetAllStrict is a wrapper around DefaultUserSettings.GetAllStrict.
164func GetAllStrict(alias, key string) ([]string, error) {
165	return DefaultUserSettings.GetAllStrict(alias, key)
166}
167
168// Get finds the first value for key within a declaration that matches the
169// alias. Get returns the empty string if no value was found, or if IgnoreErrors
170// is false and we could not parse the configuration file. Use GetStrict to
171// disambiguate the latter cases.
172//
173// The match for key is case insensitive.
174func (u *UserSettings) Get(alias, key string) string {
175	val, err := u.GetStrict(alias, key)
176	if err != nil {
177		return ""
178	}
179	return val
180}
181
182// GetAll retrieves zero or more directives for key for the given alias. GetAll
183// returns nil if no value was found, or if IgnoreErrors is false and we could
184// not parse the configuration file. Use GetStrict to disambiguate the latter
185// cases.
186//
187// The match for key is case insensitive.
188func (u *UserSettings) GetAll(alias, key string) []string {
189	val, _ := u.GetAllStrict(alias, key)
190	return val
191}
192
193// GetStrict finds the first value for key within a declaration that matches the
194// alias. If key has a default value and no matching configuration is found, the
195// default will be returned. For more information on default values and the way
196// patterns are matched, see the manpage for ssh_config.
197//
198// error will be non-nil if and only if a user's configuration file or the
199// system configuration file could not be parsed, and u.IgnoreErrors is false.
200func (u *UserSettings) GetStrict(alias, key string) (string, error) {
201	u.doLoadConfigs()
202	//lint:ignore S1002 I prefer it this way
203	if u.onceErr != nil && u.IgnoreErrors == false {
204		return "", u.onceErr
205	}
206	val, err := findVal(u.userConfig, alias, key)
207	if err != nil || val != "" {
208		return val, err
209	}
210	val2, err2 := findVal(u.systemConfig, alias, key)
211	if err2 != nil || val2 != "" {
212		return val2, err2
213	}
214	return Default(key), nil
215}
216
217// GetAllStrict retrieves zero or more directives for key for the given alias.
218// If key has a default value and no matching configuration is found, the
219// default will be returned. For more information on default values and the way
220// patterns are matched, see the manpage for ssh_config.
221//
222// The returned error will be non-nil if and only if a user's configuration file
223// or the system configuration file could not be parsed, and u.IgnoreErrors is
224// false.
225func (u *UserSettings) GetAllStrict(alias, key string) ([]string, error) {
226	u.doLoadConfigs()
227	//lint:ignore S1002 I prefer it this way
228	if u.onceErr != nil && u.IgnoreErrors == false {
229		return nil, u.onceErr
230	}
231	val, err := findAll(u.userConfig, alias, key)
232	if err != nil || val != nil {
233		return val, err
234	}
235	val2, err2 := findAll(u.systemConfig, alias, key)
236	if err2 != nil || val2 != nil {
237		return val2, err2
238	}
239	// TODO: IdentityFile has multiple default values that we should return.
240	if def := Default(key); def != "" {
241		return []string{def}, nil
242	}
243	return []string{}, nil
244}
245
246func (u *UserSettings) doLoadConfigs() {
247	u.loadConfigs.Do(func() {
248		// can't parse user file, that's ok.
249		var filename string
250		if u.userConfigFinder == nil {
251			filename = userConfigFinder()
252		} else {
253			filename = u.userConfigFinder()
254		}
255		var err error
256		u.userConfig, err = parseFile(filename)
257		//lint:ignore S1002 I prefer it this way
258		if err != nil && os.IsNotExist(err) == false {
259			u.onceErr = err
260			return
261		}
262		if u.systemConfigFinder == nil {
263			filename = systemConfigFinder()
264		} else {
265			filename = u.systemConfigFinder()
266		}
267		u.systemConfig, err = parseFile(filename)
268		//lint:ignore S1002 I prefer it this way
269		if err != nil && os.IsNotExist(err) == false {
270			u.onceErr = err
271			return
272		}
273	})
274}
275
276func parseFile(filename string) (*Config, error) {
277	return parseWithDepth(filename, 0)
278}
279
280func parseWithDepth(filename string, depth uint8) (*Config, error) {
281	b, err := os.ReadFile(filename)
282	if err != nil {
283		return nil, err
284	}
285	return decodeBytes(b, isSystem(filename), depth)
286}
287
// isSystem reports whether filename, once cleaned, is /etc/ssh itself or a
// path beneath it.
//
// The comparison is made on a path-component boundary: a bare string-prefix
// test would also accept unrelated siblings such as /etc/sshd-backup/config.
func isSystem(filename string) bool {
	// TODO: not sure this is the best way to detect a system repo
	const systemDir = "/etc/ssh"
	cleaned := filepath.Clean(filename)
	return cleaned == systemDir || strings.HasPrefix(cleaned, systemDir+"/")
}
292
293// Decode reads r into a Config, or returns an error if r could not be parsed as
294// an SSH config file.
295func Decode(r io.Reader) (*Config, error) {
296	b, err := io.ReadAll(r)
297	if err != nil {
298		return nil, err
299	}
300	return decodeBytes(b, false, 0)
301}
302
303// DecodeBytes reads b into a Config, or returns an error if r could not be
304// parsed as an SSH config file.
305func DecodeBytes(b []byte) (*Config, error) {
306	return decodeBytes(b, false, 0)
307}
308
309func decodeBytes(b []byte, system bool, depth uint8) (c *Config, err error) {
310	defer func() {
311		if r := recover(); r != nil {
312			if _, ok := r.(runtime.Error); ok {
313				panic(r)
314			}
315			if e, ok := r.(error); ok && e == ErrDepthExceeded {
316				err = e
317				return
318			}
319			err = errors.New(r.(string))
320		}
321	}()
322
323	c = parseSSH(lexSSH(b), system, depth)
324	return c, err
325}
326
327// Config represents an SSH config file.
328type Config struct {
329	// A list of hosts to match against. The file begins with an implicit
330	// "Host *" declaration matching all hosts.
331	Hosts    []*Host
332	depth    uint8
333	position Position
334}
335
336// Get finds the first value in the configuration that matches the alias and
337// contains key. Get returns the empty string if no value was found, or if the
338// Config contains an invalid conditional Include value.
339//
340// The match for key is case insensitive.
341func (c *Config) Get(alias, key string) (string, error) {
342	lowerKey := strings.ToLower(key)
343	for _, host := range c.Hosts {
344		if !host.Matches(alias) {
345			continue
346		}
347		for _, node := range host.Nodes {
348			switch t := node.(type) {
349			case *Empty:
350				continue
351			case *KV:
352				// "keys are case insensitive" per the spec
353				lkey := strings.ToLower(t.Key)
354				if lkey == "match" {
355					panic("can't handle Match directives")
356				}
357				if lkey == lowerKey {
358					return t.Value, nil
359				}
360			case *Include:
361				val := t.Get(alias, key)
362				if val != "" {
363					return val, nil
364				}
365			default:
366				return "", fmt.Errorf("unknown Node type %v", t)
367			}
368		}
369	}
370	return "", nil
371}
372
373// GetAll returns all values in the configuration that match the alias and
374// contains key, or nil if none are present.
375func (c *Config) GetAll(alias, key string) ([]string, error) {
376	lowerKey := strings.ToLower(key)
377	all := []string(nil)
378	for _, host := range c.Hosts {
379		if !host.Matches(alias) {
380			continue
381		}
382		for _, node := range host.Nodes {
383			switch t := node.(type) {
384			case *Empty:
385				continue
386			case *KV:
387				// "keys are case insensitive" per the spec
388				lkey := strings.ToLower(t.Key)
389				if lkey == "match" {
390					panic("can't handle Match directives")
391				}
392				if lkey == lowerKey {
393					all = append(all, t.Value)
394				}
395			case *Include:
396				val, _ := t.GetAll(alias, key)
397				if len(val) > 0 {
398					all = append(all, val...)
399				}
400			default:
401				return nil, fmt.Errorf("unknown Node type %v", t)
402			}
403		}
404	}
405
406	return all, nil
407}
408
409// String returns a string representation of the Config file.
410func (c Config) String() string {
411	return marshal(c).String()
412}
413
414func (c Config) MarshalText() ([]byte, error) {
415	return marshal(c).Bytes(), nil
416}
417
418func marshal(c Config) *bytes.Buffer {
419	var buf bytes.Buffer
420	for i := range c.Hosts {
421		buf.WriteString(c.Hosts[i].String())
422	}
423	return &buf
424}
425
// Pattern is a pattern in a Host declaration. Patterns are read-only values;
// create a new one with NewPattern().
type Pattern struct {
	// str is the pattern exactly as given to NewPattern, including a leading
	// '!' for negated patterns. It is not the value that gets compiled.
	str   string
	regex *regexp.Regexp
	not   bool // True if this is a negated match
}

// String prints the string representation of the pattern, as it was passed to
// NewPattern.
func (p Pattern) String() string {
	return p.str
}

// Copied from regexp.go with * and ? removed.
var specialBytes = []byte(`\.+()|[]{}^$`)

// special reports whether b must be escaped to be matched literally in a
// regular expression. '*' and '?' are excluded because NewPattern translates
// them itself.
func special(b byte) bool {
	return bytes.IndexByte(specialBytes, b) >= 0
}

// NewPattern creates a new Pattern for matching hosts. NewPattern("*") creates
// a Pattern that matches all hosts. An empty s is an error.
//
// From the manpage, a pattern consists of zero or more non-whitespace
// characters, `*' (a wildcard that matches zero or more characters), or `?' (a
// wildcard that matches exactly one character). For example, to specify a set
// of declarations for any host in the ".co.uk" set of domains, the following
// pattern could be used:
//
//	Host *.co.uk
//
// The following pattern would match any host in the 192.168.0.[0-9] network range:
//
//	Host 192.168.0.?
//
// A leading `!' marks the pattern as negated. The `!' is not part of the
// compiled expression, but it is kept in the value returned by String so that
// a pattern prints the way it was written.
func NewPattern(s string) (*Pattern, error) {
	if s == "" {
		return nil, errors.New("ssh_config: empty pattern")
	}
	glob := s
	negated := false
	if glob[0] == '!' {
		negated = true
		glob = glob[1:]
	}
	var buf bytes.Buffer
	buf.WriteByte('^')
	for i := 0; i < len(glob); i++ {
		// A byte loop is correct because all metacharacters are ASCII.
		switch b := glob[i]; b {
		case '*':
			buf.WriteString(".*")
		case '?':
			// Exactly one character; ".?" would also accept none.
			buf.WriteByte('.')
		default:
			// borrowing from QuoteMeta here.
			if special(b) {
				buf.WriteByte('\\')
			}
			buf.WriteByte(b)
		}
	}
	buf.WriteByte('$')
	r, err := regexp.Compile(buf.String())
	if err != nil {
		return nil, err
	}
	return &Pattern{str: s, regex: r, not: negated}, nil
}
493
494// Host describes a Host directive and the keywords that follow it.
495type Host struct {
496	// A list of host patterns that should match this host.
497	Patterns []*Pattern
498	// A Node is either a key/value pair or a comment line.
499	Nodes []Node
500	// EOLComment is the comment (if any) terminating the Host line.
501	EOLComment string
502	// Whitespace if any between the Host declaration and a trailing comment.
503	spaceBeforeComment string
504
505	hasEquals    bool
506	leadingSpace int // TODO: handle spaces vs tabs here.
507	// The file starts with an implicit "Host *" declaration.
508	implicit bool
509}
510
511// Matches returns true if the Host matches for the given alias. For
512// a description of the rules that provide a match, see the manpage for
513// ssh_config.
514func (h *Host) Matches(alias string) bool {
515	found := false
516	for i := range h.Patterns {
517		if h.Patterns[i].regex.MatchString(alias) {
518			if h.Patterns[i].not {
519				// Negated match. "A pattern entry may be negated by prefixing
520				// it with an exclamation mark (`!'). If a negated entry is
521				// matched, then the Host entry is ignored, regardless of
522				// whether any other patterns on the line match. Negated matches
523				// are therefore useful to provide exceptions for wildcard
524				// matches."
525				return false
526			}
527			found = true
528		}
529	}
530	return found
531}
532
533// String prints h as it would appear in a config file. Minor tweaks may be
534// present in the whitespace in the printed file.
535func (h *Host) String() string {
536	var buf strings.Builder
537	//lint:ignore S1002 I prefer to write it this way
538	if h.implicit == false {
539		buf.WriteString(strings.Repeat(" ", int(h.leadingSpace)))
540		buf.WriteString("Host")
541		if h.hasEquals {
542			buf.WriteString(" = ")
543		} else {
544			buf.WriteString(" ")
545		}
546		for i, pat := range h.Patterns {
547			buf.WriteString(pat.String())
548			if i < len(h.Patterns)-1 {
549				buf.WriteString(" ")
550			}
551		}
552		buf.WriteString(h.spaceBeforeComment)
553		if h.EOLComment != "" {
554			buf.WriteByte('#')
555			buf.WriteString(h.EOLComment)
556		}
557		buf.WriteByte('\n')
558	}
559	for i := range h.Nodes {
560		buf.WriteString(h.Nodes[i].String())
561		buf.WriteByte('\n')
562	}
563	return buf.String()
564}
565
566// Node represents a line in a Config.
567type Node interface {
568	Pos() Position
569	String() string
570}
571
572// KV is a line in the config file that contains a key, a value, and possibly
573// a comment.
574type KV struct {
575	Key   string
576	Value string
577	// Whitespace after the value but before any comment
578	spaceAfterValue string
579	Comment         string
580	hasEquals       bool
581	leadingSpace    int // Space before the key. TODO handle spaces vs tabs.
582	position        Position
583}
584
585// Pos returns k's Position.
586func (k *KV) Pos() Position {
587	return k.position
588}
589
590// String prints k as it was parsed in the config file.
591func (k *KV) String() string {
592	if k == nil {
593		return ""
594	}
595	equals := " "
596	if k.hasEquals {
597		equals = " = "
598	}
599	line := strings.Repeat(" ", int(k.leadingSpace)) + k.Key + equals + k.Value + k.spaceAfterValue
600	if k.Comment != "" {
601		line += "#" + k.Comment
602	}
603	return line
604}
605
606// Empty is a line in the config file that contains only whitespace or comments.
607type Empty struct {
608	Comment      string
609	leadingSpace int // TODO handle spaces vs tabs.
610	position     Position
611}
612
613// Pos returns e's Position.
614func (e *Empty) Pos() Position {
615	return e.position
616}
617
618// String prints e as it was parsed in the config file.
619func (e *Empty) String() string {
620	if e == nil {
621		return ""
622	}
623	if e.Comment == "" {
624		return ""
625	}
626	return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment)
627}
628
629// Include holds the result of an Include directive, including the config files
630// that have been parsed as part of that directive. At most 5 levels of Include
631// statements will be parsed.
632type Include struct {
633	// Comment is the contents of any comment at the end of the Include
634	// statement.
635	Comment string
636	// an include directive can include several different files, and wildcards
637	directives []string
638
639	mu sync.Mutex
640	// 1:1 mapping between matches and keys in files array; matches preserves
641	// ordering
642	matches []string
643	// actual filenames are listed here
644	files        map[string]*Config
645	leadingSpace int
646	position     Position
647	depth        uint8
648	hasEquals    bool
649}
650
// maxRecurseDepth is the deepest Include nesting that NewInclude accepts.
const maxRecurseDepth = 5

// ErrDepthExceeded is returned when Include directives nest more deeply than
// maxRecurseDepth. Usually this indicates a recursive loop (an Include
// directive pointing to the file it contains).
var ErrDepthExceeded = errors.New("ssh_config: max recurse depth exceeded")
657
// removeDups returns the elements of arr with repeats dropped, keeping the
// first occurrence of each and preserving order. The result is never nil.
func removeDups(arr []string) []string {
	seen := make(map[string]bool, len(arr))
	unique := make([]string, 0)
	for _, item := range arr {
		if seen[item] {
			continue
		}
		seen[item] = true
		unique = append(unique, item)
	}
	return unique
}
672
673// NewInclude creates a new Include with a list of file globs to include.
674// Configuration files are parsed greedily (e.g. as soon as this function runs).
675// Any error encountered while parsing nested configuration files will be
676// returned.
677func NewInclude(directives []string, hasEquals bool, pos Position, comment string, system bool, depth uint8) (*Include, error) {
678	if depth > maxRecurseDepth {
679		return nil, ErrDepthExceeded
680	}
681	inc := &Include{
682		Comment:      comment,
683		directives:   directives,
684		files:        make(map[string]*Config),
685		position:     pos,
686		leadingSpace: pos.Col - 1,
687		depth:        depth,
688		hasEquals:    hasEquals,
689	}
690	// no need for inc.mu.Lock() since nothing else can access this inc
691	matches := make([]string, 0)
692	for i := range directives {
693		var path string
694		if filepath.IsAbs(directives[i]) {
695			path = directives[i]
696		} else if system {
697			path = filepath.Join("/etc/ssh", directives[i])
698		} else {
699			path = filepath.Join(homedir(), ".ssh", directives[i])
700		}
701		theseMatches, err := filepath.Glob(path)
702		if err != nil {
703			return nil, err
704		}
705		matches = append(matches, theseMatches...)
706	}
707	matches = removeDups(matches)
708	inc.matches = matches
709	for i := range matches {
710		config, err := parseWithDepth(matches[i], depth)
711		if err != nil {
712			return nil, err
713		}
714		inc.files[matches[i]] = config
715	}
716	return inc, nil
717}
718
719// Pos returns the position of the Include directive in the larger file.
720func (i *Include) Pos() Position {
721	return i.position
722}
723
724// Get finds the first value in the Include statement matching the alias and the
725// given key.
726func (inc *Include) Get(alias, key string) string {
727	inc.mu.Lock()
728	defer inc.mu.Unlock()
729	// TODO: we search files in any order which is not correct
730	for i := range inc.matches {
731		cfg := inc.files[inc.matches[i]]
732		if cfg == nil {
733			panic("nil cfg")
734		}
735		val, err := cfg.Get(alias, key)
736		if err == nil && val != "" {
737			return val
738		}
739	}
740	return ""
741}
742
743// GetAll finds all values in the Include statement matching the alias and the
744// given key.
745func (inc *Include) GetAll(alias, key string) ([]string, error) {
746	inc.mu.Lock()
747	defer inc.mu.Unlock()
748	var vals []string
749
750	// TODO: we search files in any order which is not correct
751	for i := range inc.matches {
752		cfg := inc.files[inc.matches[i]]
753		if cfg == nil {
754			panic("nil cfg")
755		}
756		val, err := cfg.GetAll(alias, key)
757		if err == nil && len(val) != 0 {
758			// In theory if SupportsMultiple was false for this key we could
759			// stop looking here. But the caller has asked us to find all
760			// instances of the keyword (and could use Get() if they wanted) so
761			// let's keep looking.
762			vals = append(vals, val...)
763		}
764	}
765	return vals, nil
766}
767
768// String prints out a string representation of this Include directive. Note
769// included Config files are not printed as part of this representation.
770func (inc *Include) String() string {
771	equals := " "
772	if inc.hasEquals {
773		equals = " = "
774	}
775	line := fmt.Sprintf("%sInclude%s%s", strings.Repeat(" ", int(inc.leadingSpace)), equals, strings.Join(inc.directives, " "))
776	if inc.Comment != "" {
777		line += " #" + inc.Comment
778	}
779	return line
780}
781
782var matchAll *Pattern
783
784func init() {
785	var err error
786	matchAll, err = NewPattern("*")
787	if err != nil {
788		panic(err)
789	}
790}
791
792func newConfig() *Config {
793	return &Config{
794		Hosts: []*Host{
795			&Host{
796				implicit: true,
797				Patterns: []*Pattern{matchAll},
798				Nodes:    make([]Node, 0),
799			},
800		},
801		depth: 0,
802	}
803}