#define _GNU_SOURCE /* struct ucred */
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include "cgroup.h"
#include "common.h"


/*
 * Connect to the compctl daemon over the UNIX stream socket SOCKFILE and
 * send a control message carrying our credentials (pid, uid, gid) as
 * SCM_CREDENTIALS.
 *
 * Returns a stdio stream opened for both reading and writing on the
 * connected socket.  On any failure an error is printed and the process
 * exits with EXIT_FAILURE, so the result is never NULL.
 *
 * Callers that write a request and then read the reply on the returned
 * stream must fflush() it in between (C11 7.21.5.3p7).
 */
FILE *
connectd(void)
{
	int s = socket(AF_UNIX, SOCK_STREAM, 0);
	if (s < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
	if (connect(s, (struct sockaddr *) &sun, sizeof(sun.sun_family) + strlen(sun.sun_path) + 1) < 0) {
		perror(SOCKFILE);
		exit(EXIT_FAILURE);
	}

	/* Send message with credentials. */

	struct ucred cred = {
		.pid = getpid(),
		.uid = getuid(),
		.gid = getgid(),
	};

	/* The union guarantees the buffer is suitably aligned for
	 * struct cmsghdr (cf. the cmsg(3) example); a bare char array is not. */
	union {
		char buf[CMSG_SPACE(sizeof(cred))];
		struct cmsghdr align;
	} cbuf;
	memset(&cbuf, 0, sizeof(cbuf));

	struct msghdr msg = {0};
	msg.msg_control = cbuf.buf;
	msg.msg_controllen = sizeof(cbuf.buf);

	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_CREDENTIALS;
	cmsg->cmsg_len = CMSG_LEN(sizeof(cred));
	memcpy(CMSG_DATA(cmsg), &cred, sizeof(cred));

	/* NOTE(review): this message carries no ordinary payload bytes.
	 * unix(7) states that on SOCK_STREAM sockets ancillary data is only
	 * passed along with at least one byte of ordinary data, so the daemon
	 * presumably identifies us via SO_PEERCRED/SO_PASSCRED instead --
	 * confirm against the daemon before changing the wire format. */
	if (sendmsg(s, &msg, 0) < 0) {
		perror("sendmsg");
		exit(EXIT_FAILURE);
	}

	/* "r+" opens the stream for reading and writing; the previous "rw"
	 * is not a valid mode and is parsed by glibc/musl as read-only,
	 * which made every request written to the stream fail. */
	FILE *f = fdopen(s, "r+");
	if (!f) {
		perror("fdopen");
		exit(EXIT_FAILURE);
	}
	return f;
}


/*
 * Write the program's help text to stream f.  The middle of the text is
 * pulled in from "help-in-quotes", which presumably holds one or more C
 * string literals that get concatenated between the title line and the
 * contact line -- TODO confirm against the build that generates it.
 */
void
help(FILE *f)
{
	static const char text[] =
		"compctl - Computations under control\n\n"
#include "help-in-quotes"
		"Contact <wizards@kam.mff.cuni.cz> with bug reports and comments.\n";

	fputs(text, f);
}

/*
 * Send the "blessme" request to the daemon (presumably asking it to take
 * this process under its control -- TODO confirm).  If the reply starts
 * with '1', replace this process with argv[0] via execvp(), passing
 * argv[0..argc-1] as its arguments; any other reply is copied to stderr.
 *
 * argv need not be NULL-terminated: a terminated copy is built here,
 * which lets screen() pass a hand-assembled array.
 *
 * Returns EXIT_FAILURE on a negative reply, a hangup, or execvp()
 * failure; does not return when execvp() succeeds.
 */
int
run(int argc, char *argv[])
{
	FILE *f = connectd();
	fputs("blessme\r\n", f);
	/* Output must not be followed by input on the same stream without an
	 * intervening fflush() (C11 7.21.5.3p7); this also pushes the request
	 * onto the socket before we block waiting for the reply. */
	fflush(f);
	char line[1024];
	/* fgets() leaves the buffer untouched on EOF (indeterminate on error);
	 * empty it so the hangup message is printed instead of reading
	 * uninitialized memory. */
	if (!fgets(line, sizeof(line), f))
		line[0] = '\0';
	fclose(f);
	if (line[0] != '1') {
		fputs(*line ? line : "unexpected hangup\n", stderr);
		return EXIT_FAILURE;
	}

	char *argvx[argc + 1];
	for (int i = 0; i < argc; i++)
		argvx[i] = argv[i];
	argvx[argc] = NULL;
	execvp(argvx[0], argvx);
	perror("execvp");
	return EXIT_FAILURE;
}

/*
 * Run the command argv[0..argc-1] through screen: the command is prefixed
 * with "screen" "-m" and the resulting argument vector is handed to
 * run(), whose result is returned.  The vector is not NULL-terminated
 * here; run() builds its own terminated copy.
 */
int
screen(int argc, char *argv[])
{
	int total = argc + 2;
	char *cmdline[total];

	cmdline[0] = "screen";
	cmdline[1] = "-m";
	memcpy(&cmdline[2], argv, (size_t) argc * sizeof(cmdline[0]));

	return run(total, cmdline);
}

/*
 * Send a "stop <pid>" request to the daemon.  If the reply does not start
 * with '1', print it (or "unexpected hangup" when there is no reply) to
 * stderr and exit with EXIT_FAILURE.
 */
void
stop(pid_t pid)
{
	FILE *f = connectd();
	/* pid_t is only specified as a signed integer type no wider than
	 * long, so convert explicitly to match the conversion specifier. */
	fprintf(f, "stop %ld\r\n", (long) pid);
	/* Required between output and input on an update stream (C11
	 * 7.21.5.3p7); also sends the request before we wait for the reply. */
	fflush(f);
	char line[1024];
	/* On EOF/error fgets() does not give us a string; use an empty one so
	 * the hangup message is printed instead of uninitialized memory. */
	if (!fgets(line, sizeof(line), f))
		line[0] = '\0';
	fclose(f);
	if (line[0] != '1') {
		fputs(*line ? line : "unexpected hangup\n", stderr);
		exit(EXIT_FAILURE);
	}
}

/*
 * Send a "stopall" request to the daemon.  On a reply starting with '1',
 * print the rest of the reply after its first two characters (presumably
 * a "1 " status prefix -- TODO confirm) to stdout.  Otherwise print the
 * reply (or "unexpected hangup") to stderr and exit with EXIT_FAILURE.
 */
void
stop_all(void)
{
	FILE *f = connectd();
	fputs("stopall\r\n", f);
	/* Required between output and input on an update stream (C11
	 * 7.21.5.3p7); also sends the request before we wait for the reply. */
	fflush(f);
	char line[1024];
	/* On EOF/error fgets() does not give us a string; use an empty one so
	 * the hangup message is printed instead of uninitialized memory. */
	if (!fgets(line, sizeof(line), f))
		line[0] = '\0';
	fclose(f);
	if (line[0] != '1') {
		fputs(*line ? line : "unexpected hangup\n", stderr);
		exit(EXIT_FAILURE);
	}
	/* A reply shorter than the prefix (e.g. a bare "1" at EOF) would make
	 * line + 2 point past the terminating NUL. */
	if (strlen(line) >= 2)
		fputs(line + 2, stdout);
}

/*
 * Send a "limitmem <limit>" request to the daemon.  If the reply does not
 * start with '1', print it (or "unexpected hangup") to stderr and exit
 * with EXIT_FAILURE.  The unit of `limit` is whatever the daemon expects
 * (not visible here -- presumably bytes, TODO confirm).
 */
void
limit_mem(size_t limit)
{
	FILE *f = connectd();
	fprintf(f, "limitmem %zu\r\n", limit);
	/* Required between output and input on an update stream (C11
	 * 7.21.5.3p7); also sends the request before we wait for the reply. */
	fflush(f);
	char line[1024];
	/* On EOF/error fgets() does not give us a string; use an empty one so
	 * the hangup message is printed instead of uninitialized memory. */
	if (!fgets(line, sizeof(line), f))
		line[0] = '\0';
	fclose(f);
	if (line[0] != '1') {
		/* TODO: Error message postprocessing. */
		fputs(*line ? line : "unexpected hangup\n", stderr);
		exit(EXIT_FAILURE);
	}
}

/*
 * Print the values returned by cgroup_get_mem_usage() and
 * cgroup_get_mem_limit() for chier/cgroup (both declared outside this
 * file), each divided by 1048576, i.e. in whole MiB rounded down.
 */
void
usage(void)
{
	const size_t mib = 1048576;
	/* Named so they don't shadow this function's own name. */
	size_t used = cgroup_get_mem_usage(chier, cgroup);
	size_t cap = cgroup_get_mem_limit(chier, cgroup);

	printf("Memory usage:\t%zuM / %zuM\n", used / mib, cap / mib);
}

/*
 * Print, one per line, the task IDs that cgroup_task_list() reports for
 * chier/cgroup.  Exits with EXIT_FAILURE if it returns a negative count.
 */
void
list(void)
{
	pid_t *tasks;
	int tasks_n = cgroup_task_list(chier, cgroup, &tasks);
	if (tasks_n < 0)
		exit(EXIT_FAILURE);
	for (int i = 0; i < tasks_n; i++) {
		/* TODO: Print process details. */
		/* pid_t is only specified as a signed integer type no wider
		 * than long; convert explicitly to match the specifier. */
		printf("%ld\n", (long) tasks[i]);
	}
	/* NOTE(review): ownership of `tasks` is not visible here; if
	 * cgroup_task_list() heap-allocates it, it should be freed. */
}

/*
 * Parse the argument of --stop: a positive decimal process ID that fits
 * in an int, with no trailing characters.  Exits with an error otherwise.
 */
static pid_t
parse_pid(const char *arg)
{
	char *end;
	errno = 0;
	long val = strtol(arg, &end, 10);
	if (end == arg || *end != '\0' || errno == ERANGE || val <= 0 || val > INT_MAX) {
		fprintf(stderr, "invalid process ID: %s\n", arg);
		exit(EXIT_FAILURE);
	}
	return (pid_t) val;
}

/*
 * Parse the argument of --limitmem: a non-negative decimal integer that
 * fits in size_t, with no trailing characters.  Exits with an error
 * otherwise.
 */
static size_t
parse_size(const char *arg)
{
	char *end;
	errno = 0;
	unsigned long long val = strtoull(arg, &end, 10);
	/* strtoull() accepts a leading '-' and negates the result, which
	 * would turn "-1" into a huge limit, so reject any '-' explicitly. */
	if (strchr(arg, '-') || end == arg || *end != '\0' || errno == ERANGE || val > SIZE_MAX) {
		fprintf(stderr, "invalid memory limit: %s\n", arg);
		exit(EXIT_FAILURE);
	}
	return (size_t) val;
}

/*
 * Command-line entry point.  Options are processed left to right.
 * --run and --screen consume all remaining arguments and return their
 * result (they do not return at all when the exec succeeds).  With no
 * arguments the help text is printed to stderr and EXIT_FAILURE returned;
 * an unrecognized option is reported and also fails.
 */
int
main(int argc, char *argv[])
{
	/* Not called optind: that would shadow the getopt() global declared
	 * in <unistd.h>. */
	int argi = 1;

	if (argc == argi) {
		help(stderr);
		return EXIT_FAILURE;
	}

	while (argc > argi) {
		char *cmd = argv[argi++];
		if (!strcmp(cmd, "--run")) {
			if (argc <= argi) {
				fputs("missing arguments for --run\n", stderr);
				exit(EXIT_FAILURE);
			}
			return run(argc - argi, &argv[argi]);

		} else if (!strcmp(cmd, "--screen")) {
			if (argc <= argi) {
				fputs("missing arguments for --screen\n", stderr);
				exit(EXIT_FAILURE);
			}
			return screen(argc - argi, &argv[argi]);

		} else if (!strcmp(cmd, "--usage")) {
			usage();

		} else if (!strcmp(cmd, "--list")) {
			list();

		} else if (!strcmp(cmd, "--stop")) {
			if (argc <= argi) {
				fputs("missing argument for --stop\n", stderr);
				exit(EXIT_FAILURE);
			}
			stop(parse_pid(argv[argi++]));

		} else if (!strcmp(cmd, "--stopall")) {
			stop_all();

		} else if (!strcmp(cmd, "--limitmem")) {
			if (argc <= argi) {
				fputs("missing argument for --limitmem\n", stderr);
				exit(EXIT_FAILURE);
			}
			limit_mem(parse_size(argv[argi++]));

		} else if (!strcmp(cmd, "--help")) {
			help(stdout);

		} else {
			/* Previously ignored silently, so typos "succeeded". */
			fprintf(stderr, "unknown option %s, see --help\n", cmd);
			return EXIT_FAILURE;
		}
	}

	return EXIT_SUCCESS;
}