master
Raw Download raw file
  1import pwd
  2import os
  3import tempfile
  4from typing import Union, List
  5from dataclasses import dataclass, field
  6from dataclasses_jsonschema import JsonSchemaMixin
  7from lib.client import BaseClient
  8
  9
 10@dataclass
 11class Key(JsonSchemaMixin):
 12    id: int
 13    key: str
 14    user: str
 15
 16
 17@dataclass
 18class GithubAuthorizedKeyFile(JsonSchemaMixin):
 19    github_users: Union[str, List[str]]
 20    annotate: bool
 21    filename: str = None
 22    keys: List[Key] = field(default_factory=list)
 23    user: str = None
 24
 25    def __post_init__(self):
 26        if isinstance(self.github_users, str):
 27            self.github_users = [self.github_users]
 28
 29        if self.filename is None:
 30            err, self.filename = self.keyfile(user=self.user, write=True)
 31
 32    async def collect_keys(self):
 33
 34        for user in self.github_users:
 35            client = BaseClient(
 36                host="api.github.com", path=f"/users/{user}/keys"
 37            )
 38            err, data = await client.get_data()
 39            user_keys = [Key(**k, user=user) for k in data]
 40            self.keys = [*self.keys, *user_keys]
 41
 42    def keyfile(
 43        self,
 44        user=None,
 45        write=False,
 46        path=None,
 47        manage_dir=True,
 48        follow=False,
 49    ):
 50        """
 51        Calculate name of authorized keys file, optionally creating the
 52        directories and file, properly setting permissions.
 53    
 54        :param str user: name of user in passwd file
 55        :param bool write: if True, write changes to authorized_keys file (creating directories if needed)
 56        :param str path: if not None, use provided path rather than default of '~user/.ssh/authorized_keys'
 57        :param bool manage_dir: if True, create and set ownership of the parent dir of the authorized_keys file
 58        :param bool follow: if True symlinks will be followed and not replaced
 59        :return: full path string to authorized_keys for user
 60        """
 61
 62        try:
 63            if user is None:
 64                user_entry = pwd.getpwuid(os.getuid())
 65            else:
 66                user_entry = pwd.getpwnam(user)
 67        except KeyError as e:
 68            return (f"Failed to lookup user {user}: {e}", None)
 69
 70        if path is None:
 71            homedir = user_entry.pw_dir
 72            sshdir = os.path.join(homedir, ".ssh")
 73            keysfile = os.path.join(sshdir, "authorized_keys")
 74        else:
 75            sshdir = os.path.dirname(path)
 76            keysfile = path
 77
 78        if follow:
 79            keysfile = os.path.realpath(keysfile)
 80
 81        if not write:
 82            return (None, keysfile)
 83
 84        uid = user_entry.pw_uid
 85        gid = user_entry.pw_gid
 86
 87        if manage_dir:
 88            if not os.path.exists(sshdir):
 89                os.mkdir(sshdir, int("0700", 8))
 90            os.chown(sshdir, uid, gid)
 91            os.chmod(sshdir, int("0700", 8))
 92
 93        if not os.path.exists(keysfile):
 94            basedir = os.path.dirname(keysfile)
 95            if not os.path.exists(basedir):
 96                os.makedirs(basedir)
 97            try:
 98                f = open(
 99                    keysfile, "w"
100                )  # touches file so we can set ownership and perms
101            finally:
102                f.close()
103
104        try:
105            os.chown(keysfile, uid, gid)
106            os.chmod(keysfile, int("0600", 8))
107        except OSError:
108            pass
109
110        return (None, keysfile)
111
112    def serialize(self):
113        lines = []
114        lines.append(f"# managed by ssh-copy-id-from-github\n")
115        for k in self.keys:
116            key_annotation = ""
117            if self.annotate:
118                key_annotation = f" github_user={k.user} key_id={k.id}"
119            line = f"{k.key}{key_annotation}\n"
120            lines.append(line)
121        return "".join(lines)
122
123    def writefile(self):
124        fd, tmp_path = tempfile.mkstemp(
125            "", "tmp", os.path.dirname(self.filename)
126        )
127        f = open(tmp_path, "w")
128
129        try:
130            f.write(self.serialize())
131        except IOError as e:
132            return (f"Failed to write to file {tmp_path}: {e}", None)
133        f.close()
134        try:
135            os.rename(tmp_path, self.filename)
136        except Exception as err:
137            return (err, None)