#define _GNU_SOURCE /* struct ucred */
#include <assert.h>
#include <errno.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <syslog.h>
#include <unistd.h>

#include "cgroup.h"
#include "common.h"


/* Nonzero if string a_ starts with the string literal s_.
 * s_ MUST be a string literal: its length is taken with sizeof at
 * compile time.  The ("" s_ "") concatenation turns an accidental
 * char-pointer argument into a compile error instead of silently
 * comparing only sizeof(char *) - 1 bytes. */
#define begins_with(s_, a_) (!strncmp("" s_ "", (a_), sizeof("" s_ "") - 1))


/* syslog() counterpart of perror(): log the caller-supplied context
 * string s followed by the text of the current errno, at LOG_ERR. */
void
logperror(const char *s)
{
	const char *reason = strerror(errno);

	syslog(LOG_ERR, "%s: %s", s, reason);
}

/* Compute the memory split between interactive users and computations.
 *
 * Outputs (all in bytes):
 *   *total   - MemTotal from /proc/meminfo; 0 if it cannot be read.
 *   *minuser - memory reserved for users: total * split_ratio, clamped
 *              to [static_minfree, static_maxfree].
 *   *mincomp - minimum memory that must stay available to computations
 *              (static_minfree).
 *   *maxcomp - total - minuser, clamped at 0.
 */
void
memory_limits(size_t *minuser, size_t *mincomp, size_t *maxcomp, size_t *total)
{
	*total = 0;

	FILE *f = fopen("/proc/meminfo", "r");
	if (f == NULL) {
		logperror("/proc/meminfo");
	} else {
		char line[1024];
		while (fgets(line, sizeof(line), f)) {
			if (begins_with("MemTotal:", line)) {
				/* The value is in kB; convert to bytes. */
				size_t kib = 0;
				if (sscanf(line, "MemTotal:%zu", &kib) == 1)
					*total = kib * 1024;
				break;
			}
		}
		fclose(f);
	}

	*minuser = *total * split_ratio;
	if (*minuser < static_minfree)
		*minuser = static_minfree;
	if (*minuser > static_maxfree)
		*minuser = static_maxfree;

	*mincomp = static_minfree;
	/* size_t cannot be negative: clamp explicitly rather than letting
	 * the subtraction wrap around when minuser exceeds total. */
	*maxcomp = *total > *minuser ? *total - *minuser : 0;
	/* maxcomp < mincomp may happen; they are used in different
	 * settings. */
}

/* Default memory limit for the computation cgroup: the maxcomp value
 * produced by memory_limits() (total memory minus the part reserved
 * for users). */
size_t
get_default_mem_limit(void)
{
	size_t reserved_user, comp_floor, comp_ceiling, mem_total;

	memory_limits(&reserved_user, &comp_floor, &comp_ceiling, &mem_total);
	return comp_ceiling;
}


/* Prepare the cgroup hierarchy chier for the "memory" controller and
 * ensure the computation cgroup exists.  Only when cgroup_create()
 * reports a freshly created cgroup (positive return) is the default
 * memory limit applied; otherwise the limit is left as it is.
 * Any failure terminates the process with EXIT_FAILURE. */
void
cgroup_init(void)
{
	int created;

	if (cgroup_setup(chier, "memory") < 0)
		exit(EXIT_FAILURE);

	created = cgroup_create(chier, cgroup);
	if (created < 0)
		exit(EXIT_FAILURE);
	if (created == 0)
		return;

	/* Newly created cgroup: give it the default limit. */
	if (cgroup_set_mem_limit(chier, cgroup, get_default_mem_limit()) < 0)
		exit(EXIT_FAILURE);
}


/* printf()-style reply to a client: format fmt/... (truncated to 1023
 * bytes) and send all of it over the connected socket fd.
 * Errors are logged via logperror()/syslog(); nothing is returned. */
void
mprintf(int fd, char *fmt, ...)
{
	char buf[1024];

	va_list v;
	va_start(v, fmt);
	int len = vsnprintf(buf, sizeof(buf), fmt, v);
	va_end(v);
	if (len < 0) {
		/* buf contents are indeterminate on failure; don't send them. */
		logperror("vsnprintf");
		return;
	}

	/* A stream socket may accept only part of the data per call;
	 * keep sending until the whole reply has been written. */
	char *p = buf;
	size_t left = strlen(buf);
	while (left > 0) {
		struct iovec iov = {
			.iov_base = p,
			.iov_len = left,
		};
		struct msghdr msg = {
			.msg_iov = &iov,
			.msg_iovlen = 1,
		};
		/* MSG_NOSIGNAL: a peer that already hung up yields EPIPE
		 * instead of raising SIGPIPE, whose default action would
		 * terminate the daemon. */
		ssize_t sent = sendmsg(fd, &msg, MSG_NOSIGNAL);
		if (sent < 0) {
			if (errno == EINTR)
				continue;
			logperror("sendmsg");
			return;
		}
		if (sent == 0) {
			/* No progress possible; avoid looping forever. */
			syslog(LOG_INFO, "incomplete send, %zu bytes left", left);
			return;
		}
		p += sent;
		left -= (size_t) sent;
	}
}


/* compctl daemon entry point.
 *
 * Initializes the computation cgroup (while stderr is still visible),
 * daemonizes, then serves one client at a time on the UNIX socket
 * SOCKFILE.  Each client sends a single command; the sender's
 * credentials arrive as SCM_CREDENTIALS ancillary data.  Commands:
 *   blessme          - add the sending process to the cgroup
 *   stop <pid>       - SIGTERM a task that is in the cgroup
 *   stopall          - SIGTERM every task in the cgroup
 *   limitmem <bytes> - set the cgroup memory limit
 * Replies start with "1" on success, "0" (or a negative code) on error.
 * Returns only if accept() fails. */
int
main(int argc, char *argv[])
{
	/* Do this while everyone can still see the error. */
	cgroup_init();

	pid_t p = fork();
	if (p < 0) {
		perror("fork");
		exit(EXIT_FAILURE);
	}
	if (p > 0)
		exit(EXIT_SUCCESS);

	fclose(stderr);
	fclose(stdout);
	fclose(stdin);
	openlog("compctl", LOG_PID, LOG_DAEMON);
	cgroup_perror = logperror;

	setsid();

	/* A client that disconnects before reading its reply must not kill
	 * the daemon; failed writes report EPIPE instead. */
	signal(SIGPIPE, SIG_IGN);

	int s = socket(AF_UNIX, SOCK_STREAM, 0);
	if (s < 0) {
		logperror("socket");
		exit(EXIT_FAILURE);
	}
	/* TODO: Protect against double execution? */
	unlink(SOCKFILE);
	struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
	if (bind(s, (struct sockaddr *) &sun, sizeof(sun.sun_family) + strlen(sun.sun_path) + 1) < 0) {
		logperror(SOCKFILE);
		exit(EXIT_FAILURE);
	}
	chmod(SOCKFILE, 0777);
	if (listen(s, 10) < 0) {
		logperror("listen");
		exit(EXIT_FAILURE);
	}

	int fd;
	while ((fd = accept(s, NULL, NULL)) >= 0) {
		/* We handle only a single client at a time. This means
		 * that it is rather easy to write a script that will DOS
		 * the daemon, this is just an attack vector we ignore. */
		/* TODO: alarm() to wake from stuck clients. */

		/* Decode the message with command and credentials. */

		int on = 1;
		setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));

		/* Room for one SCM_CREDENTIALS message; the union gives the
		 * buffer the alignment struct cmsghdr requires. */
		union {
			struct cmsghdr align;
			char buf[CMSG_SPACE(sizeof(struct ucred))];
		} cbuf;
		char line[1024];
		struct iovec iov = {
			.iov_base = line,
			/* Reserve one byte for the NUL terminator added below. */
			.iov_len = sizeof(line) - 1,
		};
		struct msghdr msg = {
			.msg_iov = &iov,
			.msg_iovlen = 1,
			.msg_control = cbuf.buf,
			.msg_controllen = sizeof(cbuf.buf),
		};

		ssize_t replylen;
		do {
			replylen = recvmsg(fd, &msg, 0);
		} while (replylen < 0 && (errno == EAGAIN || errno == EINTR));

		const char *errmsg = NULL;
		struct cmsghdr *cmsg = NULL;
		if (replylen < 0) {
			errmsg = "recvmsg";
		} else if ((cmsg = CMSG_FIRSTHDR(&msg)) == NULL
			   || cmsg->cmsg_len != CMSG_LEN(sizeof(struct ucred))) {
			syslog(LOG_INFO, "want %zu", (size_t) CMSG_LEN(sizeof(struct ucred)));
			errmsg = "cmsg";
		} else if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDENTIALS) {
			errmsg = "cmsg designation";
		}
		if (errmsg != NULL) {
			logperror(errmsg);
			close(fd);
			continue;
		}
		/* CMSG_DATA() need not be suitably aligned for struct ucred;
		 * copy it out instead of dereferencing it in place. */
		struct ucred cred;
		memcpy(&cred, CMSG_DATA(cmsg), sizeof(cred));

		line[replylen] = 0;
		if (replylen < 2) {
			syslog(LOG_WARNING, "protocol error (%s)", line);
			close(fd);
			continue;
		}

		/* Analyze command */
		if (!strcmp("blessme", line)) {
			syslog(LOG_INFO, "new computation process %d", cred.pid);
			if (cgroup_add_task(chier, cgroup, cred.pid) < 0)
				mprintf(fd, "0 error: %s", strerror(errno));
			else
				mprintf(fd, "1 blessed");

		} else if (begins_with("stop ", line)) {
			/* sizeof counts the literal's NUL; subtract it so the
			 * first digit of the argument is not skipped. */
			const char *arg = line + sizeof("stop ") - 1;
			char *end;
			errno = 0;
			long pidarg = strtol(arg, &end, 10);

			/* Sanity check. */
			if (end == arg || errno == ERANGE || pidarg < 10 || pidarg > 32768) {
				syslog(LOG_WARNING, "stop: invalid pid (%s)", arg);
				mprintf(fd, "0 invalid pid");
				close(fd);
				continue;
			}
			pid_t pid = (pid_t) pidarg;
			if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) {
				mprintf(fd, "0 task not marked as computation");
				close(fd);
				continue;
			}

			syslog(LOG_INFO, "stopping process %d (request by pid %d uid %d)", pid, cred.pid, cred.uid);
			if (kill(pid, SIGTERM) < 0) {
				mprintf(fd, "0 error: %s", strerror(errno));
			} else {
				/* TODO: Grace period and then kill with SIGKILL. */
				mprintf(fd, "1 task stopped");
			}

		} else if (!strcmp("stopall", line)) {
			pid_t *tasks;
			int tasks_n = cgroup_task_list(chier, cgroup, &tasks);
			if (tasks_n < 0) {
				mprintf(fd, "0 error: %s", strerror(errno));
				close(fd);
				continue;
			}
			for (int i = 0; i < tasks_n; i++) {
				syslog(LOG_INFO, "stopping process %d (mass request by pid %d uid %d)", tasks[i], cred.pid, cred.uid);
				kill(tasks[i], SIGTERM);
			}
			/* TODO: Grace period and then kill with SIGKILL. */
			mprintf(fd, "1 %d tasks stopped", tasks_n);
			free(tasks);

		} else if (begins_with("limitmem ", line)) {
			/* See the "stop " case: skip the prefix without its NUL. */
			const char *arg = line + sizeof("limitmem ") - 1;
			char *end;
			errno = 0;
			unsigned long larg = strtoul(arg, &end, 10);
			/* Unparsable or out-of-range input becomes 0, which the
			 * sanity check below rejects. */
			size_t limit = (end == arg || errno == ERANGE) ? 0 : (size_t) larg;
			size_t minuser, mincomp, maxcomp, total;
			memory_limits(&minuser, &mincomp, &maxcomp, &total);

			/* Sanity check. */
			if (limit < 1024 || limit > total) {
				syslog(LOG_WARNING, "limitmem: invalid limit (%zu)", limit);
				mprintf(fd, "0 invalid limit value");
				close(fd);
				continue;
			}

			if (limit < mincomp) {
				mprintf(fd, "-1 at least %zuM must remain available for computations.", mincomp / 1048576);
				close(fd);
				continue;
			}
			if (total - limit < minuser) {
				mprintf(fd, "-2 at least %zuM must remain available for users.", minuser / 1048576);
				close(fd);
				continue;
			}

			syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred.pid, cred.uid);
			if (cgroup_set_mem_limit(chier, cgroup, limit) < 0)
				mprintf(fd, "0 error: %s", strerror(errno));
			else
				mprintf(fd, "1 limit set");

		} else {
			syslog(LOG_WARNING, "invalid command (%s)", line);
		}

		close(fd);
	}

	return EXIT_FAILURE;
}