Commit a1a8e13f authored by Petr Baudis's avatar Petr Baudis

Rewrite the communication protocol

sendmsg() with only ancilliary message does not work, apparently.
Therefore, to make things cleaner, pass command/reply directly using
sendmsg() instead of newline-terminated strings.
parent c40a98c1
......@@ -10,9 +10,9 @@ tweaking the cgroup limits.
The client compctl interface simply queries the server using
a synchronous protocol over a UNIX socket. First, the client
sends a SCM_CREDENTIALS ancilliary message. Then, it follows
with a CRLF-terminated command string and receives a CRLF-terminated
reply string. Connection is closed immediately on breach of protocol.
sends a command string message coupled with a SCM_CREDENTIALS
ancilliary message. Then, it receives a reply message.
Connection is closed immediately on breach of protocol.
You can tweak some simple compile-time configuration variables
......
#define _GNU_SOURCE /* struct ucred */
#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
......@@ -11,8 +12,8 @@
#include "common.h"
FILE *
connectd(void)
char *
daemon_chat(char *cmd)
{
int s = socket(AF_UNIX, SOCK_STREAM, 0);
struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
......@@ -21,29 +22,67 @@ connectd(void)
exit(EXIT_FAILURE);
}
/* Send message with credentials. */
/* Send command. */
struct iovec iov_cmd = {
.iov_base = cmd,
.iov_len = strlen(cmd),
};
struct msghdr msg = {
.msg_iov = &iov_cmd,
.msg_iovlen = 1,
};
/* Include credentials in the message. */
struct ucred cred = {
.pid = getpid(),
.uid = getuid(),
.gid = getgid(),
};
char cbuf[CMSG_SPACE(sizeof(cred))];
struct msghdr msg = {
.msg_control = cbuf,
.msg_controllen = sizeof(cbuf),
};
msg.msg_control = cbuf;
msg.msg_controllen = sizeof(cbuf);
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_CREDENTIALS;
cmsg->cmsg_len = CMSG_LEN(sizeof(cred));
memcpy(CMSG_DATA(cmsg), &cred, sizeof(cred));
sendmsg(s, &msg, 0);
ssize_t sent = sendmsg(s, &msg, 0);
if (sent < 0) {
perror("sendmsg");
exit(EXIT_FAILURE);
}
if (sent < msg.msg_iov->iov_len) {
fprintf(stderr, "incomplete send %zd < %zu, FIXME\n", sent, msg.msg_iov->iov_len);
exit(EXIT_FAILURE);
}
/* Receive reply. */
char reply[1024];
struct iovec iov_reply = {
.iov_base = reply,
.iov_len = sizeof(reply),
};
msg.msg_iov = &iov_reply;
msg.msg_iovlen = 1;
recvagain:;
int replylen = recvmsg(s, &msg, 0);
if (replylen < 0) {
if (errno == EAGAIN)
goto recvagain;
perror("recvmsg");
exit(EXIT_FAILURE);
}
if (replylen >= 1024) {
fprintf(stderr, "too long reply from the server\n");
exit(EXIT_FAILURE);
}
reply[replylen] = 0;
return fdopen(s, "rw");
return strdup(reply);
}
......@@ -58,15 +97,12 @@ help(FILE *f)
int
run(int argc, char *argv[])
{
FILE *f = connectd();
fputs("blessme\r\n", f);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
char *line = daemon_chat("blessme");
if (line[0] != '1') {
fputs(*line ? line : "unexpected hangup\n", stderr);
return EXIT_FAILURE;
}
free(line);
char *argvx[argc + 1];
for (int i = 0; i < argc; i++)
......@@ -91,45 +127,38 @@ screen(int argc, char *argv[])
void
stop(pid_t pid)
{
FILE *f = connectd();
fprintf(f, "stop %d\r\n", pid);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
char cmd[256]; snprintf(cmd, sizeof(cmd), "stop %d", pid);
char *line = daemon_chat(cmd);
if (line[0] != '1') {
fputs(*line ? line : "unexpected hangup\n", stderr);
exit(EXIT_FAILURE);
}
free(line);
}
void
stop_all(void)
{
FILE *f = connectd();
fputs("stopall\r\n", f);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
char *line = daemon_chat("stopall");
if (line[0] != '1') {
fputs(*line ? line : "unexpected hangup\n", stderr);
exit(EXIT_FAILURE);
}
fputs(line + 2, stdout);
free(line);
}
void
limit_mem(size_t limit)
{
FILE *f = connectd();
fprintf(f, "limitmem %zu\r\n", limit);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
char cmd[256]; snprintf(cmd, sizeof(cmd), "limitmem %zu", limit);
char *line = daemon_chat(cmd);
if (line[0] != '1') {
/* TODO: Error message postprocessing. */
fputs(*line ? line : "unexpected hangup\n", stderr);
exit(EXIT_FAILURE);
}
free(line);
}
void
......
......@@ -2,6 +2,7 @@
#include <assert.h>
#include <errno.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
......@@ -77,6 +78,36 @@ cgroup_init(void)
}
void
mprintf(int fd, char *fmt, ...)
{
char buf[1024];
va_list v;
va_start(v, fmt);
vsnprintf(buf, sizeof(buf), fmt, v);
va_end(v);
struct iovec iov = {
.iov_base = buf,
.iov_len = strlen(buf),
};
struct msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
};
ssize_t sent = sendmsg(fd, &msg, 0);
if (sent < 0) {
logperror("sendmsg");
return;
}
if (sent < iov.iov_len) {
syslog(LOG_INFO, "incomplete send %zd < %zu, FIXME", sent, iov.iov_len);
return;
}
}
int
main(int argc, char *argv[])
{
......@@ -100,8 +131,6 @@ main(int argc, char *argv[])
setsid();
int s = socket(AF_UNIX, SOCK_STREAM, 0);
int on = 1; setsockopt(s, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
/* TODO: Protect against double execution? */
unlink(SOCKFILE);
struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
......@@ -119,17 +148,27 @@ main(int argc, char *argv[])
* the daemon, this is just an attack vector we ignore. */
/* TODO: alarm() to wake from stuck clients. */
/* Decode the message with credentials. */
/* Decode the message with command and credentials. */
int on = 1; setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
struct ucred *cred;
char cbuf[CMSG_SPACE(sizeof(*cred))];
char line[1024];
struct iovec iov = {
.iov_base = line,
.iov_len = sizeof(line),
};
struct msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = cbuf,
.msg_controllen = sizeof(cbuf),
};
char *errmsg;
recvagain:
if (recvmsg(fd, &msg, MSG_WAITALL) < 0) {
recvagain:;
int replylen = recvmsg(fd, &msg, 0);
if (replylen < 0) {
if (errno == EAGAIN)
goto recvagain;
errmsg = "recvmsg";
......@@ -151,24 +190,20 @@ sockerror:
}
cred = (struct ucred *) CMSG_DATA(cmsg);
FILE *f = fdopen(fd, "r");
char line[1024];
fgets(line, sizeof(line), f);
size_t linelen = strlen(line);
if (linelen < 2 || strcmp(&line[linelen - 2], "\r\n")) {
line[replylen] = 0;
if (replylen < 2) {
syslog(LOG_WARNING, "protocol error (%s)", line);
fclose(f);
close(fd);
continue;
}
line[linelen - 2] = 0;
/* Analyze command */
if (!strcmp("blessme", line)) {
syslog(LOG_INFO, "new computation process %d", cred->pid);
if (cgroup_add_task(chier, cgroup, cred->pid) < 0)
fprintf(f, "0 error: %s\r\n", strerror(errno));
mprintf(fd, "0 error: %s", strerror(errno));
else
fputs("1 blessed\r\n", f);
mprintf(fd, "1 blessed");
} else if (begins_with("stop ", line)) {
pid_t pid = atoi(line + sizeof("stop "));
......@@ -176,27 +211,27 @@ sockerror:
/* Sanity check. */
if (pid < 10 || pid > 32768) {
syslog(LOG_WARNING, "stop: invalid pid (%d)", pid);
fputs("0 invalid pid\r\n", f);
fclose(f);
mprintf(fd, "0 invalid pid");
close(fd);
continue;
}
if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) {
fputs("0 task not marked as computation\r\n", f);
fclose(f);
mprintf(fd, "0 task not marked as computation");
close(fd);
continue;
}
syslog(LOG_INFO, "stopping process %d (request by pid %d uid %d)", pid, cred->pid, cred->uid);
kill(pid, SIGTERM);
/* TODO: Grace period and then kill with SIGKILL. */
fputs("1 task stopped\r\n", f);
mprintf(fd, "1 task stopped");
} else if (!strcmp("stopall", line)) {
pid_t *tasks;
int tasks_n = cgroup_task_list(chier, cgroup, &tasks);
if (tasks_n < 0) {
fprintf(f, "0 error: %s\r\n", strerror(errno));
fclose(f);
mprintf(fd, "0 error: %s\r\n", strerror(errno));
close(fd);
continue;
}
for (int i = 0; i < tasks_n; i++) {
......@@ -204,7 +239,7 @@ sockerror:
kill(tasks[i], SIGTERM);
}
/* TODO: Grace period and then kill with SIGKILL. */
fprintf(f, "1 %d tasks stopped\r\n", tasks_n);
mprintf(fd, "1 %d tasks stopped", tasks_n);
free(tasks);
} else if (begins_with("limitmem ", line)) {
......@@ -215,33 +250,33 @@ sockerror:
/* Sanity check. */
if (limit < 1024 || limit > total) {
syslog(LOG_WARNING, "limitmem: invalid limit (%zu)", limit);
fputs("0 invalid limit value\r\n", f);
fclose(f);
mprintf(fd, "0 invalid limit value");
close(fd);
continue;
}
if (limit < mincomp) {
fprintf(f, "-1 at least %zuM must remain available for computations.\r\n", mincomp / 1048576);
fclose(f);
mprintf(fd, "-1 at least %zuM must remain available for computations.", mincomp / 1048576);
close(fd);
continue;
}
if (total - limit < minuser) {
fprintf(f, "-2 at least %zuM must remain available for users.\r\n", minuser / 1048576);
fclose(f);
mprintf(fd, "-2 at least %zuM must remain available for users.", minuser / 1048576);
close(fd);
continue;
}
syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred->pid, cred->uid);
if (cgroup_set_mem_limit(chier, cgroup, limit) < 0)
fprintf(f, "0 error: %s\r\n", strerror(errno));
mprintf(fd, "0 error: %s", strerror(errno));
else
fputs("1 limit set\r\n", f);
mprintf(fd, "1 limit set");
} else {
syslog(LOG_WARNING, "invalid command (%s)", line);
}
fclose(f);
close(fd);
}
return EXIT_FAILURE;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment