Skip to content
Snippets Groups Projects
Select Git revision
  • feea6da45ccf8e506e4c8c19bc0b615a592ce761
  • master default protected
2 results

parse_op

Blame
  • compctld.c 6.48 KiB
    #define _GNU_SOURCE /* struct ucred */
    #include <assert.h>
    #include <errno.h>
    #include <signal.h>
    #include <stdarg.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include <sys/stat.h>
    #include <sys/socket.h>
    #include <sys/un.h>
    #include <syslog.h>
    #include <unistd.h>
    
    #include "cgroup.h"
    #include "common.h"
    
    
    /* Nonzero iff string a_ starts with the prefix s_. Note the argument
     * order: prefix first, examined string second. s_ is evaluated twice,
     * so it must be free of side effects. */
    #define begins_with(s_, a_) (!strncmp((s_), (a_), strlen(s_)))
    
    
    /* perror()-like reporter that writes to syslog instead of stderr:
     * logs "<s>: <text for the current errno>" at LOG_ERR priority. */
    void
    logperror(const char *s)
    {
    	const char *reason = strerror(errno);
    
    	syslog(LOG_ERR, "%s: %s", s, reason);
    }
    
    /*
     * Compute the memory bounds used when sizing the computation cgroup.
     *
     * Outputs:
     *   total   - the MemTotal figure from /proc/meminfo multiplied by 1024;
     *             0 if the file cannot be opened or the figure cannot be
     *             parsed.
     *   minuser - total * split_ratio, clamped to the range
     *             [static_minfree, static_maxfree]; memory that has to stay
     *             available to users.
     *   mincomp - static_minfree; smallest limit accepted for computations.
     *   maxcomp - total - minuser, or 0 if that would be negative.
     *
     * maxcomp < mincomp may happen; they are used in different settings.
     * A failure to open /proc/meminfo is reported through syslog.
     */
    void
    memory_limits(size_t *minuser, size_t *mincomp, size_t *maxcomp, size_t *total)
    {
    	/* Make sure callers never observe an uninitialized total, even if
    	 * the MemTotal line is missing or unreadable. */
    	*total = 0;
    
    	FILE *f = fopen("/proc/meminfo", "r");
    	if (f == NULL) {
    		syslog(LOG_ERR, "%s: %s", "/proc/meminfo", strerror(errno));
    	} else {
    		char line[1024];
    		while (fgets(line, sizeof(line), f)) {
    			if (begins_with("MemTotal:", line)) {
    				if (sscanf(line, "MemTotal:%zu", total) != 1)
    					*total = 0;
    				*total *= 1024;
    				break;
    			}
    		}
    		fclose(f);
    	}
    
    	*minuser = *total * split_ratio;
    	if (*minuser < static_minfree)
    		*minuser = static_minfree;
    	if (*minuser > static_maxfree)
    		*minuser = static_maxfree;
    
    	*mincomp = static_minfree;
    	/* Explicit comparison instead of relying on unsigned wrap-around
    	 * followed by a conversion to a signed type. */
    	*maxcomp = *total > *minuser ? *total - *minuser : 0;
    }
    
    /* Default memory limit for the computation cgroup: the maxcomp value
     * computed by memory_limits(). The other three outputs are discarded. */
    size_t
    get_default_mem_limit(void)
    {
    	size_t user_floor, comp_floor, comp_ceiling, mem_total;
    
    	memory_limits(&user_floor, &comp_floor, &comp_ceiling, &mem_total);
    	return comp_ceiling;
    }
    
    
    /* Prepare the "memory" cgroup hierarchy and the computation cgroup.
     * Terminates the process with EXIT_FAILURE if any step fails. The
     * default memory limit is applied only when cgroup_create() returns a
     * positive value, i.e. when the cgroup has just been created. */
    void
    cgroup_init(void)
    {
    	if (cgroup_setup(chier, "memory") < 0)
    		exit(EXIT_FAILURE);
    
    	int created = cgroup_create(chier, cgroup);
    	if (created < 0)
    		exit(EXIT_FAILURE);
    	if (created == 0)
    		return;
    
    	/* CGroup newly created, set limit. */
    	if (cgroup_set_mem_limit(chier, cgroup, get_default_mem_limit()) < 0)
    		exit(EXIT_FAILURE);
    }
    
    
    /*
     * printf-style reply helper: formats the message into a 1024-byte buffer
     * (longer output is truncated by vsnprintf) and sends it on socket fd.
     *
     * The message is sent with MSG_NOSIGNAL so that a client which has
     * already closed its end yields an EPIPE error (logged) instead of a
     * SIGPIPE that would terminate the daemon. Short sends are continued
     * until the whole message has been written. Errors are logged and
     * otherwise ignored; nothing is reported to the caller.
     */
    void
    mprintf(int fd, char *fmt, ...)
    {
    	char buf[1024];
    
    	va_list v;
    	va_start(v, fmt);
    	int formatted = vsnprintf(buf, sizeof(buf), fmt, v);
    	va_end(v);
    	if (formatted < 0) {
    		/* buf contents are indeterminate; do not send them. */
    		syslog(LOG_ERR, "mprintf: cannot format reply");
    		return;
    	}
    
    	size_t len = strlen(buf);
    	size_t done = 0;
    	while (done < len) {
    		struct iovec iov = {
    			.iov_base = buf + done,
    			.iov_len = len - done,
    		};
    		struct msghdr msg = {
    			.msg_iov = &iov,
    			.msg_iovlen = 1,
    		};
    		ssize_t sent = sendmsg(fd, &msg, MSG_NOSIGNAL);
    		if (sent < 0) {
    			if (errno == EINTR)
    				continue;
    			logperror("sendmsg");
    			return;
    		}
    		if (sent == 0) {
    			/* No progress; bail out rather than spin. */
    			syslog(LOG_INFO, "incomplete send %zu < %zu", done, len);
    			return;
    		}
    		done += (size_t) sent;
    	}
    }
    
    
    /*
     * Daemon entry point.
     *
     * Initializes the cgroup while stderr is still attached, forks into the
     * background, and then serves requests on the UNIX stream socket
     * SOCKFILE, one connection at a time. Each connection carries a single
     * command which must be accompanied by the sender's SCM_CREDENTIALS:
     *
     *   "blessme"          add the sender's pid to the computation cgroup
     *   "kill <pid>"       send SIGTERM to <pid> if it is in the cgroup
     *   "killall"          send SIGTERM to every task in the cgroup
     *   "limitmem <bytes>" change the memory limit of the cgroup
     *
     * Every reply starts with a status number: "1" on success, "0" or a
     * negative number on failure. Unknown commands are logged and get no
     * reply.
     *
     * The accept loop ends only when accept() fails, in which case
     * EXIT_FAILURE is returned.
     */
    int
    main(int argc, char *argv[])
    {
    	/* Do this while everyone can still see the error. */
    	cgroup_init();
    
    	pid_t p = fork();
    	if (p < 0) {
    		perror("fork");
    		exit(EXIT_FAILURE);
    	}
    	if (p > 0)
    		exit(EXIT_SUCCESS);
    
    	fclose(stderr);
    	fclose(stdout);
    	fclose(stdin);
    	openlog("compctl", LOG_PID, LOG_DAEMON);
    	cgroup_perror = logperror;
    
    	setsid();
    
    	int s = socket(AF_UNIX, SOCK_STREAM, 0);
    	if (s < 0) {
    		logperror("socket");
    		exit(EXIT_FAILURE);
    	}
    	/* TODO: Protect against double execution? */
    	unlink(SOCKFILE);
    	struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
    	if (bind(s, (struct sockaddr *) &sun, sizeof(sun.sun_family) + strlen(sun.sun_path) + 1) < 0) {
    		logperror(SOCKFILE);
    		exit(EXIT_FAILURE);
    	}
    	/* Any local user may talk to the daemon. A failure here is logged
    	 * but not fatal. */
    	if (chmod(SOCKFILE, 0777) < 0)
    		logperror("chmod");
    	if (listen(s, 10) < 0) {
    		logperror("listen");
    		exit(EXIT_FAILURE);
    	}
    
    	int fd;
    	while ((fd = accept(s, NULL, NULL)) >= 0) {
    		/* We handle only a single client at a time. This means
    		 * that it is rather easy to write a script that will DOS
    		 * the daemon, this is just an attack vector we ignore. */
    		/* TODO: alarm() to wake from stuck clients. */
    
    		/* Decode the message with command and credentials. */
    
    		int on = 1; setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
    
    		struct ucred *cred;
    		char cbuf[CMSG_SPACE(sizeof(*cred))];
    		char line[1024];
    		struct iovec iov = {
    			.iov_base = line,
    			/* Leave room for the terminating NUL that is
    			 * appended after a successful receive. */
    			.iov_len = sizeof(line) - 1,
    		};
    		struct msghdr msg = {
    			.msg_iov = &iov,
    			.msg_iovlen = 1,
    			.msg_control = cbuf,
    			.msg_controllen = sizeof(cbuf),
    		};
    		char *errmsg;
    recvagain:;
    		ssize_t replylen = recvmsg(fd, &msg, 0);
    		if (replylen < 0) {
    			if (errno == EAGAIN)
    				goto recvagain;
    			errmsg = "recvmsg";
    			/* Shared error exit, also entered by goto from the
    			 * control-message checks below. */
    sockerror:
    			logperror(errmsg);
    			close(fd);
    			continue;
    		}
    		struct cmsghdr *cmsg;
    		cmsg = CMSG_FIRSTHDR(&msg);
    		if (cmsg == NULL || cmsg->cmsg_len != CMSG_LEN(sizeof(*cred))) {
    			syslog(LOG_INFO, "want %zu", CMSG_LEN(sizeof(*cred)));
    			errmsg = "cmsg";
    			/* No system call failed here; give logperror() a
    			 * meaningful errno instead of a stale one. */
    			errno = EPROTO;
    			goto sockerror;
    		}
    		if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDENTIALS) {
    			errmsg = "cmsg designation";
    			errno = EPROTO;
    			goto sockerror;
    		}
    		cred = (struct ucred *) CMSG_DATA(cmsg);
    
    		/* replylen <= sizeof(line) - 1, so this stays in bounds. */
    		line[replylen] = 0;
    		if (replylen < 2) {
    			syslog(LOG_WARNING, "protocol error (%s)", line);
    			close(fd);
    			continue;
    		}
    
    		/* Analyze command */
    		if (!strcmp("blessme", line)) {
    			syslog(LOG_INFO, "new computation process %d", cred->pid);
    			if (cgroup_add_task(chier, cgroup, cred->pid) < 0)
    				mprintf(fd, "0 error: %s", strerror(errno));
    			else
    				mprintf(fd, "1 blessed");
    
    		} else if (begins_with("kill ", line)) {
    			pid_t pid = atoi(line + strlen("kill "));
    
    			/* Sanity check. */
    			if (pid < 10) {
    				syslog(LOG_WARNING, "kill: invalid pid (%d)", pid);
    				mprintf(fd, "0 invalid pid");
    				close(fd);
    				continue;
    			}
    			/* Only tasks inside the computation cgroup may be
    			 * killed through the daemon. */
    			if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) {
    				mprintf(fd, "0 task not marked as computation");
    				close(fd);
    				continue;
    			}
    
    			syslog(LOG_INFO, "killing process %d (request by pid %d uid %d)", pid, cred->pid, cred->uid);
    			kill(pid, SIGTERM);
    			/* TODO: Grace period and then kill with SIGKILL. */
    			mprintf(fd, "1 task killed");
    
    		} else if (!strcmp("killall", line)) {
    			pid_t *tasks;
    			int tasks_n = cgroup_task_list(chier, cgroup, &tasks);
    			if (tasks_n < 0) {
    				mprintf(fd, "0 error: %s\r\n", strerror(errno));
    				close(fd);
    				continue;
    			}
    			for (int i = 0; i < tasks_n; i++) {
    				syslog(LOG_INFO, "killing process %d (mass request by pid %d uid %d)", tasks[i], cred->pid, cred->uid);
    				kill(tasks[i], SIGTERM);
    			}
    			/* TODO: Grace period and then kill with SIGKILL. */
    			mprintf(fd, "1 %d tasks killed", tasks_n);
    			free(tasks);
    
    		} else if (begins_with("limitmem ", line)) {
    			size_t limit = atol(line + strlen("limitmem "));
    			size_t minuser, mincomp, maxcomp, total;
    			memory_limits(&minuser, &mincomp, &maxcomp, &total);
    
    			if (limit < mincomp) {
    				mprintf(fd, "-1 at least %zuM must remain available for computations.", mincomp / 1048576);
    				close(fd);
    				continue;
    			}
    			if (limit > total || total - limit < minuser) {
    				mprintf(fd, "-2 at least %zuM must remain available for users; maximum limit for computations is %zuM.", minuser / 1048576, (total - minuser) / 1048576);
    				close(fd);
    				continue;
    			}
    
    			syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred->pid, cred->uid);
    			if (cgroup_set_mem_limit(chier, cgroup, limit) < 0)
    				mprintf(fd, "0 error: %s", strerror(errno));
    			else
    				mprintf(fd, "1 limit set");
    
    		} else {
    			syslog(LOG_WARNING, "invalid command (%s)", line);
    		}
    
    		close(fd);
    	}
    
    	return EXIT_FAILURE;
    }