From eff16c69860d765ac5f0303cce504955116893a6 Mon Sep 17 00:00:00 2001
From: Martin Mares <mj@ucw.cz>
Date: Fri, 7 Jun 2024 15:47:49 +0200
Subject: [PATCH] Reworked passing of credentials

The previous version was blatantly insecure :)
---
 TODO            |  2 ++
 shipcat/main.py | 33 +++++++++++++++++++--------------
 wrapper.c       | 31 +++++++++++++++----------------
 3 files changed, 36 insertions(+), 30 deletions(-)

diff --git a/TODO b/TODO
index 3313231..74e91d1 100644
--- a/TODO
+++ b/TODO
@@ -1,3 +1,5 @@
+- status: show if enabled
+
 Albireo cleanup:
 - prerouting rule
 - containers
diff --git a/shipcat/main.py b/shipcat/main.py
index fed800a..8466c7d 100755
--- a/shipcat/main.py
+++ b/shipcat/main.py
@@ -17,6 +17,8 @@ from shipcat.util import die, progress, verbose, set_verbose, run_command
 
 
 config: GlobalConfig
+caller_uid: int = 0
+caller_gids: List[int] = [0]
 
 
 def trace(msg: str) -> None:
@@ -30,32 +32,40 @@ def load_config() -> GlobalConfig:
         die(str(e))
 
 
-def setup_container(args: argparse.Namespace, require_root: bool, container: Optional[str] = None) -> ContainerConfig:
+def get_caller_credentials() -> None:
     if os.geteuid() != 0:
         die('This program must be run setuid root')
 
+    global caller_uid, caller_gids
+    if 'SHIPCAT_UID' in os.environ:
+        caller_uid = int(os.environ['SHIPCAT_UID'])
+        caller_gids = list(map(int, os.environ['SHIPCAT_GIDS'].split(',')))
+
+
+def setup_container(args: argparse.Namespace, require_root: bool, container: Optional[str] = None) -> ContainerConfig:
+
     try:
         cc = ContainerConfig.load(config, container or args.name)
     except ConfigError as e:
         die(str(e))
 
-    if require_root and args.as_user != 0:
+    if require_root and caller_uid != 0:
         die('This operation must be performed by root')
 
-    if not check_rights(cc, args.as_user, args.as_groups):
+    if not check_rights(cc):
         die('You do not have permission to operate this container')
 
     return cc
 
 
-def check_rights(cc: ContainerConfig, uid: int, gids: List[int]) -> bool:
-    if uid == 0:
+def check_rights(cc: ContainerConfig) -> bool:
+    if caller_uid == 0:
         return True
 
-    if uid in cc.allowed_users:
+    if caller_uid in cc.allowed_users:
         return True
 
-    for gid in gids:
+    for gid in caller_gids:
         if gid in cc.allowed_groups:
             return True
 
@@ -245,7 +255,7 @@ def overall_status(args: argparse.Namespace) -> None:
     for conf in config.container_config_path.iterdir():
         if conf.suffix == '.toml':
             cc = ContainerConfig.load(config, conf.stem)
-            if args.all or check_rights(cc, args.as_user, args.as_groups):
+            if args.all or check_rights(cc):
                 st = status.get(cc.service_name, {})
                 st_active = st.get('ActiveState', 'not found')
                 st_sub = st.get('SubState', 'not found')
@@ -385,16 +395,10 @@ def main_service_stop():
     run_command(['podman', 'stop', '--ignore', cc.name])
 
 
-def parse_int_list(s: str) -> List[int]:
-    return list(map(int, s.split(',')))
-
-
 def main() -> None:
     parser = argparse.ArgumentParser(
         description="Ship's Cat -- a container management tool",
     )
-    parser.add_argument('--as-user', type=int, default=0, metavar='UID', help='user ID of requesting user')
-    parser.add_argument('--as-groups', type=parse_int_list, metavar='GID,...', help='group IDs of requesting user (primary first)')
     parser.add_argument('--verbose', '-v', default=False, action='store_true', help='be chatty and explain what is going on')
     subparsers = parser.add_subparsers(help='action to perform', dest='action', required=True, metavar='ACTION')
 
@@ -439,6 +443,7 @@ def main() -> None:
     update_parser.add_argument('name', help='name of the container')
 
     args = parser.parse_args()
+    get_caller_credentials()
 
     actions = {
         'disable': cmd_disable,
diff --git a/wrapper.c b/wrapper.c
index ef70805..9f3da10 100644
--- a/wrapper.c
+++ b/wrapper.c
@@ -56,8 +56,8 @@ static void sanitize_fds(void)
   close(fd);
 }
 
-static char as_user[32];
-static char as_groups[1024];
+static char env_uid[32];
+static char env_gids[1024];
 
 static void get_credentials(void)
 {
@@ -65,16 +65,16 @@ static void get_credentials(void)
   struct passwd *pwd = getpwuid(uid);
   if (!pwd)
     die("You don't exist");
-  int ulen = snprintf(as_user, sizeof(as_user), "%u", uid);
-  if (ulen >= (int) sizeof(as_user))
+  int ulen = snprintf(env_uid, sizeof(env_uid), "SHIPCAT_UID=%u", uid);
+  if (ulen >= (int) sizeof(env_uid))
     die("UID too long");
 
   gid_t gid = getgid();
   struct group *grp = getgrgid(gid);
   if (!grp)
     die("Your group does not exist");
-  int glen = snprintf(as_groups, sizeof(as_groups), "%u", gid);
-  if (glen >= (int) sizeof(as_groups))
+  int glen = snprintf(env_gids, sizeof(env_gids), "SHIPCAT_GIDS=%u", gid);
+  if (glen >= (int) sizeof(env_gids))
     die("GID too long");
 
   int max_groups = getgroups(0, NULL);
@@ -90,8 +90,8 @@ static void get_credentials(void)
     {
       if (groups[i] == gid)
 	continue;
-      glen += snprintf(as_groups + glen, (int) sizeof(as_groups) - glen, ",%u", groups[i]);
-      if (glen >= (int) sizeof(as_groups))
+      glen += snprintf(env_gids + glen, (int) sizeof(env_gids) - glen, ",%u", groups[i]);
+      if (glen >= (int) sizeof(env_gids))
 	die("Group list too long");
     }
 
@@ -109,15 +109,11 @@ static void switch_ugid(void)
 
 static char **make_args(int argc, char **argv)
 {
-  char **args = xmalloc(sizeof(char *) * (5 + argc));
+  char **args = xmalloc(sizeof(char *) * (argc + 1));
   args[0] = "/usr/bin/shipcat";
-  args[1] = "--as-user";
-  args[2] = as_user;
-  args[3] = "--as-groups";
-  args[4] = as_groups;
   for (int i=1; i<argc; i++)
-    args[4+i] = argv[i];
-  args[4+argc] = NULL;
+    args[i] = argv[i];
+  args[argc] = NULL;
   return args;
 }
 
@@ -131,12 +127,15 @@ static char **make_env(void)
     "TERM=",
   };
 
-  char **env = xmalloc(sizeof(char *) * (ARRAY_SIZE(set_env) + ARRAY_SIZE(inherit_env) + 1));
+  char **env = xmalloc(sizeof(char *) * (ARRAY_SIZE(set_env) + ARRAY_SIZE(inherit_env) + 3));
   char **ep = env;
 
   for (int i=0; i < ARRAY_SIZE(set_env); i++)
     *ep++ = set_env[i];
 
+  *ep++ = env_uid;
+  *ep++ = env_gids;
+
   for (int i=0; environ[i]; i++)
     {
       char *e = environ[i];
-- 
GitLab