Select Git revision
-
Martin Mareš authored
compctld.c 6.48 KiB
#define _GNU_SOURCE /* struct ucred */
#include <assert.h>
#include <errno.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <syslog.h>
#include <unistd.h>
#include "cgroup.h"
#include "common.h"
/*
 * Return nonzero when string S starts with PREFIX (note: prefix comes first).
 * A function rather than a macro so that PREFIX is evaluated only once and
 * both arguments are type-checked.
 */
static inline int
begins_with(const char *prefix, const char *s)
{
	return strncmp(prefix, s, strlen(prefix)) == 0;
}
/*
 * syslog(3) analogue of perror(3): log S, a colon, and the description of
 * the current errno value at LOG_ERR priority.
 */
void
logperror(const char *s)
{
	const char *reason = strerror(errno);

	syslog(LOG_ERR, "%s: %s", s, reason);
}
/*
 * Derive memory budgets from the machine's total RAM, read from the
 * MemTotal line of /proc/meminfo (which reports kB).
 *
 * Outputs (all in bytes):
 *   *total   - total RAM; 0 if /proc/meminfo is unreadable or lacks MemTotal.
 *   *minuser - memory kept for users: total * split_ratio, clamped to
 *              [static_minfree, static_maxfree].
 *   *mincomp - smallest limit accepted for computations (static_minfree).
 *   *maxcomp - total - minuser, or 0 when minuser exceeds total.
 */
void
memory_limits(size_t *minuser, size_t *mincomp, size_t *maxcomp, size_t *total)
{
	*total = 0;

	FILE *f = fopen("/proc/meminfo", "r");
	if (f) {
		char line[1024];
		while (fgets(line, sizeof(line), f)) {
			if (begins_with("MemTotal:", line)) {
				size_t kib;
				if (sscanf(line, "MemTotal:%zu", &kib) == 1)
					*total = kib * 1024;
				break;
			}
		}
		fclose(f);
	} else {
		/* Previously fgets() was called on a NULL stream here. */
		logperror("/proc/meminfo");
	}

	*minuser = *total * split_ratio;
	if (*minuser < static_minfree)
		*minuser = static_minfree;
	if (*minuser > static_maxfree)
		*minuser = static_maxfree;

	*mincomp = static_minfree;

	/* Clamp explicitly instead of relying on wrap-around of size_t subtraction. */
	*maxcomp = *total > *minuser ? *total - *minuser : 0;

	/* maxcomp < mincomp may happen; they are used in different
	 * settings. */
}
/*
 * Memory limit applied to a freshly created computation cgroup: the
 * "maxcomp" budget computed by memory_limits().
 */
size_t
get_default_mem_limit(void)
{
	size_t ram, user_floor, comp_floor, comp_ceiling;

	memory_limits(&user_floor, &comp_floor, &comp_ceiling, &ram);
	return comp_ceiling;
}
/*
 * Prepare the "memory" cgroup hierarchy and the computation cgroup.
 * The default memory limit is applied only when cgroup_create() reports
 * that the group was newly created (positive return). Any failure
 * terminates the process with EXIT_FAILURE.
 */
void
cgroup_init(void)
{
	if (cgroup_setup(chier, "memory") < 0)
		exit(EXIT_FAILURE);

	int created = cgroup_create(chier, cgroup);
	if (created < 0)
		exit(EXIT_FAILURE);
	if (created == 0)
		return; /* Already existed: keep whatever limit it has. */

	if (cgroup_set_mem_limit(chier, cgroup, get_default_mem_limit()) < 0)
		exit(EXIT_FAILURE);
}
/*
 * Format a reply printf-style and send all of it to the stream socket FD.
 *
 * The formatted text is limited to 1023 bytes; longer replies are truncated
 * and a warning is logged. Send failures are logged and otherwise ignored,
 * because the client may legitimately have gone away.
 */
void
mprintf(int fd, char *fmt, ...)
{
	char buf[1024];
	va_list v;
	va_start(v, fmt);
	int n = vsnprintf(buf, sizeof(buf), fmt, v);
	va_end(v);
	if (n < 0) {
		logperror("vsnprintf");
		return;
	}
	if ((size_t) n >= sizeof(buf))
		syslog(LOG_WARNING, "reply truncated (%d bytes)", n);

	struct iovec iov = {
		.iov_base = buf,
		.iov_len = strlen(buf),
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	while (iov.iov_len > 0) {
		/* MSG_NOSIGNAL: a peer that already hung up yields EPIPE here
		 * instead of a SIGPIPE that would terminate the daemon. */
		ssize_t sent = sendmsg(fd, &msg, MSG_NOSIGNAL);
		if (sent < 0) {
			if (errno == EINTR)
				continue;
			logperror("sendmsg");
			return;
		}
		/* Stream sockets may take only part of the data; send the rest. */
		iov.iov_base = (char *) iov.iov_base + sent;
		iov.iov_len -= (size_t) sent;
	}
}
int
main(int argc, char *argv[])
{
/* Do this while everyone can still see the error. */
cgroup_init();
pid_t p = fork();
if (p < 0) {
perror("fork");
exit(EXIT_FAILURE);
}
if (p > 0)
exit(EXIT_SUCCESS);
fclose(stderr);
fclose(stdout);
fclose(stdin);
openlog("compctl", LOG_PID, LOG_DAEMON);
cgroup_perror = logperror;
setsid();
int s = socket(AF_UNIX, SOCK_STREAM, 0);
/* TODO: Protect against double execution? */
unlink(SOCKFILE);
struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
if (bind(s, (struct sockaddr *) &sun, sizeof(sun.sun_family) + strlen(sun.sun_path) + 1) < 0) {
logperror(SOCKFILE);
exit(EXIT_FAILURE);
}
chmod(SOCKFILE, 0777);
listen(s, 10);
int fd;
while ((fd = accept(s, NULL, NULL)) >= 0) {
/* We handle only a single client at a time. This means
* that it is rather easy to write a script that will DOS
* the daemon, this is just an attack vector we ignore. */
/* TODO: alarm() to wake from stuck clients. */
/* Decode the message with command and credentials. */
int on = 1; setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
struct ucred *cred;
char cbuf[CMSG_SPACE(sizeof(*cred))];
char line[1024];
struct iovec iov = {
.iov_base = line,
.iov_len = sizeof(line),
};
struct msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = cbuf,
.msg_controllen = sizeof(cbuf),
};
char *errmsg;
recvagain:;
int replylen = recvmsg(fd, &msg, 0);
if (replylen < 0) {
if (errno == EAGAIN)
goto recvagain;
errmsg = "recvmsg";
sockerror:
logperror(errmsg);
close(fd);
continue;
}
struct cmsghdr *cmsg;
cmsg = CMSG_FIRSTHDR(&msg);
if (cmsg == NULL || cmsg->cmsg_len != CMSG_LEN(sizeof(*cred))) {
syslog(LOG_INFO, "want %zu", CMSG_LEN(sizeof(*cred)));
errmsg = "cmsg";
goto sockerror;
}
if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDENTIALS) {
errmsg = "cmsg designation";
goto sockerror;
}
cred = (struct ucred *) CMSG_DATA(cmsg);
line[replylen] = 0;
if (replylen < 2) {
syslog(LOG_WARNING, "protocol error (%s)", line);
close(fd);
continue;
}
/* Analyze command */
if (!strcmp("blessme", line)) {
syslog(LOG_INFO, "new computation process %d", cred->pid);
if (cgroup_add_task(chier, cgroup, cred->pid) < 0)
mprintf(fd, "0 error: %s", strerror(errno));
else
mprintf(fd, "1 blessed");
} else if (begins_with("kill ", line)) {
pid_t pid = atoi(line + strlen("kill "));
/* Sanity check. */
if (pid < 10) {
syslog(LOG_WARNING, "kill: invalid pid (%d)", pid);
mprintf(fd, "0 invalid pid");
close(fd);
continue;
}
if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) {
mprintf(fd, "0 task not marked as computation");
close(fd);
continue;
}
syslog(LOG_INFO, "killing process %d (request by pid %d uid %d)", pid, cred->pid, cred->uid);
kill(pid, SIGTERM);
/* TODO: Grace period and then kill with SIGKILL. */
mprintf(fd, "1 task killed");
} else if (!strcmp("killall", line)) {
pid_t *tasks;
int tasks_n = cgroup_task_list(chier, cgroup, &tasks);
if (tasks_n < 0) {
mprintf(fd, "0 error: %s\r\n", strerror(errno));
close(fd);
continue;
}
for (int i = 0; i < tasks_n; i++) {
syslog(LOG_INFO, "killing process %d (mass request by pid %d uid %d)", tasks[i], cred->pid, cred->uid);
kill(tasks[i], SIGTERM);
}
/* TODO: Grace period and then kill with SIGKILL. */
mprintf(fd, "1 %d tasks killed", tasks_n);
free(tasks);
} else if (begins_with("limitmem ", line)) {
size_t limit = atol(line + strlen("limitmem "));
size_t minuser, mincomp, maxcomp, total;
memory_limits(&minuser, &mincomp, &maxcomp, &total);
if (limit < mincomp) {
mprintf(fd, "-1 at least %zuM must remain available for computations.", mincomp / 1048576);
close(fd);
continue;
}
if (limit > total || total - limit < minuser) {
mprintf(fd, "-2 at least %zuM must remain available for users; maximum limit for computations is %zuM.", minuser / 1048576, (total - minuser) / 1048576);
close(fd);
continue;
}
syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred->pid, cred->uid);
if (cgroup_set_mem_limit(chier, cgroup, limit) < 0)
mprintf(fd, "0 error: %s", strerror(errno));
else
mprintf(fd, "1 limit set");
} else {
syslog(LOG_WARNING, "invalid command (%s)", line);
}
close(fd);
}
return EXIT_FAILURE;
}