#define _GNU_SOURCE /* struct ucred */ #include <assert.h> #include <errno.h> #include <signal.h> #include <stdarg.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <sys/stat.h> #include <sys/socket.h> #include <sys/un.h> #include <syslog.h> #include <unistd.h> #include "cgroup.h" #include "common.h" #define begins_with(s_, a_) (!strncmp(s_, a_, sizeof(s_) - 1)) void logperror(const char *s) { syslog(LOG_ERR, "%s: %s", s, strerror(errno)); } void memory_limits(size_t *minuser, size_t *mincomp, size_t *maxcomp, size_t *total) { FILE *f = fopen("/proc/meminfo", "r"); char line[1024]; while (fgets(line, sizeof(line), f)) { if (begins_with("MemTotal:", line)) { *total = 0; sscanf(line, "MemTotal:%zu", total); *total *= 1024; break; } } fclose(f); *minuser = *total * split_ratio; if (*minuser < static_minfree) *minuser = static_minfree; if (*minuser > static_maxfree) *minuser = static_maxfree; *mincomp = static_minfree; *maxcomp = *total - *minuser; if (*maxcomp < 0) *maxcomp = 0; /* maxcomp < mincomp may happen; they are used in different * settings. */ } size_t get_default_mem_limit(void) { size_t minuser, mincomp, maxcomp, total; memory_limits(&minuser, &mincomp, &maxcomp, &total); return maxcomp; } void cgroup_init(void) { if (cgroup_setup(chier, "memory") < 0) exit(EXIT_FAILURE); int ret = cgroup_create(chier, cgroup); if (ret < 0) exit(EXIT_FAILURE); if (ret > 0) { /* CGroup newly created, set limit. */ if (cgroup_set_mem_limit(chier, cgroup, get_default_mem_limit()) < 0) exit(EXIT_FAILURE); } } void mprintf(int fd, char *fmt, ...) { char buf[1024]; va_list v; va_start(v, fmt); vsnprintf(buf, sizeof(buf), fmt, v); va_end(v); struct iovec iov = { .iov_base = buf, .iov_len = strlen(buf), }; struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1, }; ssize_t sent = sendmsg(fd, &msg, 0); if (sent < 0) { logperror("sendmsg"); return; } if (sent < iov.iov_len) { syslog(LOG_INFO, "incomplete send %zd < %zu, FIXME", sent, iov.iov_len); return; } } int main(int argc, char *argv[]) { /* Do this while everyone can still see the error. */ cgroup_init(); pid_t p = fork(); if (p < 0) { perror("fork"); exit(EXIT_FAILURE); } if (p > 0) exit(EXIT_SUCCESS); fclose(stderr); fclose(stdout); fclose(stdin); openlog("compctl", LOG_PID, LOG_DAEMON); cgroup_perror = logperror; setsid(); int s = socket(AF_UNIX, SOCK_STREAM, 0); /* TODO: Protect against double execution? */ unlink(SOCKFILE); struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE }; if (bind(s, (struct sockaddr *) &sun, sizeof(sun.sun_family) + strlen(sun.sun_path) + 1) < 0) { logperror(SOCKFILE); exit(EXIT_FAILURE); } chmod(SOCKFILE, 0777); listen(s, 10); int fd; while ((fd = accept(s, NULL, NULL)) >= 0) { /* We handle only a single client at a time. This means * that it is rather easy to write a script that will DOS * the daemon, this is just an attack vector we ignore. */ /* TODO: alarm() to wake from stuck clients. */ /* Decode the message with command and credentials. */ int on = 1; setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on)); struct ucred *cred; char cbuf[CMSG_SPACE(sizeof(*cred))]; char line[1024]; struct iovec iov = { .iov_base = line, .iov_len = sizeof(line), }; struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1, .msg_control = cbuf, .msg_controllen = sizeof(cbuf), }; char *errmsg; recvagain:; int replylen = recvmsg(fd, &msg, 0); if (replylen < 0) { if (errno == EAGAIN) goto recvagain; errmsg = "recvmsg"; sockerror: logperror(errmsg); close(fd); continue; } struct cmsghdr *cmsg; cmsg = CMSG_FIRSTHDR(&msg); if (cmsg == NULL || cmsg->cmsg_len != CMSG_LEN(sizeof(*cred))) { syslog(LOG_INFO, "want %lu", CMSG_LEN(sizeof(*cred))); errmsg = "cmsg"; goto sockerror; } if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDENTIALS) { errmsg = "cmsg designation"; goto sockerror; } cred = (struct ucred *) CMSG_DATA(cmsg); line[replylen] = 0; if (replylen < 2) { syslog(LOG_WARNING, "protocol error (%s)", line); close(fd); continue; } /* Analyze command */ if (!strcmp("blessme", line)) { syslog(LOG_INFO, "new computation process %d", cred->pid); if (cgroup_add_task(chier, cgroup, cred->pid) < 0) mprintf(fd, "0 error: %s", strerror(errno)); else mprintf(fd, "1 blessed"); } else if (begins_with("stop ", line)) { pid_t pid = atoi(line + sizeof("stop ")); /* Sanity check. */ if (pid < 10 || pid > 32768) { syslog(LOG_WARNING, "stop: invalid pid (%d)", pid); mprintf(fd, "0 invalid pid"); close(fd); continue; } if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) { mprintf(fd, "0 task not marked as computation"); close(fd); continue; } syslog(LOG_INFO, "stopping process %d (request by pid %d uid %d)", pid, cred->pid, cred->uid); kill(pid, SIGTERM); /* TODO: Grace period and then kill with SIGKILL. */ mprintf(fd, "1 task stopped"); } else if (!strcmp("stopall", line)) { pid_t *tasks; int tasks_n = cgroup_task_list(chier, cgroup, &tasks); if (tasks_n < 0) { mprintf(fd, "0 error: %s\r\n", strerror(errno)); close(fd); continue; } for (int i = 0; i < tasks_n; i++) { syslog(LOG_INFO, "stopping process %d (mass request by pid %d uid %d)", tasks[i], cred->pid, cred->uid); kill(tasks[i], SIGTERM); } /* TODO: Grace period and then kill with SIGKILL. */ mprintf(fd, "1 %d tasks stopped", tasks_n); free(tasks); } else if (begins_with("limitmem ", line)) { size_t limit = atol(line + sizeof("limitmem ")); size_t minuser, mincomp, maxcomp, total; memory_limits(&minuser, &mincomp, &maxcomp, &total); /* Sanity check. */ if (limit < 1024 || limit > total) { syslog(LOG_WARNING, "limitmem: invalid limit (%zu)", limit); mprintf(fd, "0 invalid limit value"); close(fd); continue; } if (limit < mincomp) { mprintf(fd, "-1 at least %zuM must remain available for computations.", mincomp / 1048576); close(fd); continue; } if (total - limit < minuser) { mprintf(fd, "-2 at least %zuM must remain available for users.", minuser / 1048576); close(fd); continue; } syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred->pid, cred->uid); if (cgroup_set_mem_limit(chier, cgroup, limit) < 0) mprintf(fd, "0 error: %s", strerror(errno)); else mprintf(fd, "1 limit set"); } else { syslog(LOG_WARNING, "invalid command (%s)", line); } close(fd); } return EXIT_FAILURE; }