Commit a1a8e13f authored by Petr Baudis's avatar Petr Baudis
Browse files

Rewrite the communication protocol

sendmsg() with only ancilliary message does not work, apparently.
Therefore, to make things cleaner, pass command/reply directly using
sendmsg() instead of newline-terminated strings.
parent c40a98c1
......@@ -10,9 +10,9 @@ tweaking the cgroup limits.
The client compctl interface simply queries the server using
a synchronous protocol over a UNIX socket. First, the client
sends a SCM_CREDENTIALS ancilliary message. Then, it follows
with a CRLF-terminated command string and receives a CRLF-terminated
reply string. Connection is closed immediately on breach of protocol.
sends a command string message coupled with a SCM_CREDENTIALS
ancilliary message. Then, it receives a reply message.
Connection is closed immediately on breach of protocol.
You can tweak some simple compile-time configuration variables
......
#define _GNU_SOURCE /* struct ucred */
#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
......@@ -11,8 +12,8 @@
#include "common.h"
FILE *
connectd(void)
char *
daemon_chat(char *cmd)
{
int s = socket(AF_UNIX, SOCK_STREAM, 0);
struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
......@@ -21,29 +22,67 @@ connectd(void)
exit(EXIT_FAILURE);
}
/* Send message with credentials. */
/* Send command. */
struct iovec iov_cmd = {
.iov_base = cmd,
.iov_len = strlen(cmd),
};
struct msghdr msg = {
.msg_iov = &iov_cmd,
.msg_iovlen = 1,
};
/* Include credentials in the message. */
struct ucred cred = {
.pid = getpid(),
.uid = getuid(),
.gid = getgid(),
};
char cbuf[CMSG_SPACE(sizeof(cred))];
struct msghdr msg = {
.msg_control = cbuf,
.msg_controllen = sizeof(cbuf),
};
msg.msg_control = cbuf;
msg.msg_controllen = sizeof(cbuf);
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_CREDENTIALS;
cmsg->cmsg_len = CMSG_LEN(sizeof(cred));
memcpy(CMSG_DATA(cmsg), &cred, sizeof(cred));
sendmsg(s, &msg, 0);
ssize_t sent = sendmsg(s, &msg, 0);
if (sent < 0) {
perror("sendmsg");
exit(EXIT_FAILURE);
}
if (sent < msg.msg_iov->iov_len) {
fprintf(stderr, "incomplete send %zd < %zu, FIXME\n", sent, msg.msg_iov->iov_len);
exit(EXIT_FAILURE);
}
/* Receive reply. */
char reply[1024];
struct iovec iov_reply = {
.iov_base = reply,
.iov_len = sizeof(reply),
};
msg.msg_iov = &iov_reply;
msg.msg_iovlen = 1;
recvagain:;
int replylen = recvmsg(s, &msg, 0);
if (replylen < 0) {
if (errno == EAGAIN)
goto recvagain;
perror("recvmsg");
exit(EXIT_FAILURE);
}
if (replylen >= 1024) {
fprintf(stderr, "too long reply from the server\n");
exit(EXIT_FAILURE);
}
reply[replylen] = 0;
return fdopen(s, "rw");
return strdup(reply);
}
......@@ -58,15 +97,12 @@ help(FILE *f)
int
run(int argc, char *argv[])
{
FILE *f = connectd();
fputs("blessme\r\n", f);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
char *line = daemon_chat("blessme");
if (line[0] != '1') {
fputs(*line ? line : "unexpected hangup\n", stderr);
return EXIT_FAILURE;
}
free(line);
char *argvx[argc + 1];
for (int i = 0; i < argc; i++)
......@@ -91,45 +127,38 @@ screen(int argc, char *argv[])
void
stop(pid_t pid)
{
FILE *f = connectd();
fprintf(f, "stop %d\r\n", pid);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
char cmd[256]; snprintf(cmd, sizeof(cmd), "stop %d", pid);
char *line = daemon_chat(cmd);
if (line[0] != '1') {
fputs(*line ? line : "unexpected hangup\n", stderr);
exit(EXIT_FAILURE);
}
free(line);
}
void
stop_all(void)
{
FILE *f = connectd();
fputs("stopall\r\n", f);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
char *line = daemon_chat("stopall");
if (line[0] != '1') {
fputs(*line ? line : "unexpected hangup\n", stderr);
exit(EXIT_FAILURE);
}
fputs(line + 2, stdout);
free(line);
}
void
limit_mem(size_t limit)
{
FILE *f = connectd();
fprintf(f, "limitmem %zu\r\n", limit);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
char cmd[256]; snprintf(cmd, sizeof(cmd), "limitmem %zu", limit);
char *line = daemon_chat(cmd);
if (line[0] != '1') {
/* TODO: Error message postprocessing. */
fputs(*line ? line : "unexpected hangup\n", stderr);
exit(EXIT_FAILURE);
}
free(line);
}
void
......
......@@ -2,6 +2,7 @@
#include <assert.h>
#include <errno.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
......@@ -77,6 +78,36 @@ cgroup_init(void)
}
void
mprintf(int fd, char *fmt, ...)
{
char buf[1024];
va_list v;
va_start(v, fmt);
vsnprintf(buf, sizeof(buf), fmt, v);
va_end(v);
struct iovec iov = {
.iov_base = buf,
.iov_len = strlen(buf),
};
struct msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
};
ssize_t sent = sendmsg(fd, &msg, 0);
if (sent < 0) {
logperror("sendmsg");
return;
}
if (sent < iov.iov_len) {
syslog(LOG_INFO, "incomplete send %zd < %zu, FIXME", sent, iov.iov_len);
return;
}
}
int
main(int argc, char *argv[])
{
......@@ -100,8 +131,6 @@ main(int argc, char *argv[])
setsid();
int s = socket(AF_UNIX, SOCK_STREAM, 0);
int on = 1; setsockopt(s, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
/* TODO: Protect against double execution? */
unlink(SOCKFILE);
struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
......@@ -119,17 +148,27 @@ main(int argc, char *argv[])
* the daemon, this is just an attack vector we ignore. */
/* TODO: alarm() to wake from stuck clients. */
/* Decode the message with credentials. */
/* Decode the message with command and credentials. */
int on = 1; setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
struct ucred *cred;
char cbuf[CMSG_SPACE(sizeof(*cred))];
char line[1024];
struct iovec iov = {
.iov_base = line,
.iov_len = sizeof(line),
};
struct msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = cbuf,
.msg_controllen = sizeof(cbuf),
};
char *errmsg;
recvagain:
if (recvmsg(fd, &msg, MSG_WAITALL) < 0) {
recvagain:;
int replylen = recvmsg(fd, &msg, 0);
if (replylen < 0) {
if (errno == EAGAIN)
goto recvagain;
errmsg = "recvmsg";
......@@ -151,24 +190,20 @@ sockerror:
}
cred = (struct ucred *) CMSG_DATA(cmsg);
FILE *f = fdopen(fd, "r");
char line[1024];
fgets(line, sizeof(line), f);
size_t linelen = strlen(line);
if (linelen < 2 || strcmp(&line[linelen - 2], "\r\n")) {
line[replylen] = 0;
if (replylen < 2) {
syslog(LOG_WARNING, "protocol error (%s)", line);
fclose(f);
close(fd);
continue;
}
line[linelen - 2] = 0;
/* Analyze command */
if (!strcmp("blessme", line)) {
syslog(LOG_INFO, "new computation process %d", cred->pid);
if (cgroup_add_task(chier, cgroup, cred->pid) < 0)
fprintf(f, "0 error: %s\r\n", strerror(errno));
mprintf(fd, "0 error: %s", strerror(errno));
else
fputs("1 blessed\r\n", f);
mprintf(fd, "1 blessed");
} else if (begins_with("stop ", line)) {
pid_t pid = atoi(line + sizeof("stop "));
......@@ -176,27 +211,27 @@ sockerror:
/* Sanity check. */
if (pid < 10 || pid > 32768) {
syslog(LOG_WARNING, "stop: invalid pid (%d)", pid);
fputs("0 invalid pid\r\n", f);
fclose(f);
mprintf(fd, "0 invalid pid");
close(fd);
continue;
}
if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) {
fputs("0 task not marked as computation\r\n", f);
fclose(f);
mprintf(fd, "0 task not marked as computation");
close(fd);
continue;
}
syslog(LOG_INFO, "stopping process %d (request by pid %d uid %d)", pid, cred->pid, cred->uid);
kill(pid, SIGTERM);
/* TODO: Grace period and then kill with SIGKILL. */
fputs("1 task stopped\r\n", f);
mprintf(fd, "1 task stopped");
} else if (!strcmp("stopall", line)) {
pid_t *tasks;
int tasks_n = cgroup_task_list(chier, cgroup, &tasks);
if (tasks_n < 0) {
fprintf(f, "0 error: %s\r\n", strerror(errno));
fclose(f);
mprintf(fd, "0 error: %s\r\n", strerror(errno));
close(fd);
continue;
}
for (int i = 0; i < tasks_n; i++) {
......@@ -204,7 +239,7 @@ sockerror:
kill(tasks[i], SIGTERM);
}
/* TODO: Grace period and then kill with SIGKILL. */
fprintf(f, "1 %d tasks stopped\r\n", tasks_n);
mprintf(fd, "1 %d tasks stopped", tasks_n);
free(tasks);
} else if (begins_with("limitmem ", line)) {
......@@ -215,33 +250,33 @@ sockerror:
/* Sanity check. */
if (limit < 1024 || limit > total) {
syslog(LOG_WARNING, "limitmem: invalid limit (%zu)", limit);
fputs("0 invalid limit value\r\n", f);
fclose(f);
mprintf(fd, "0 invalid limit value");
close(fd);
continue;
}
if (limit < mincomp) {
fprintf(f, "-1 at least %zuM must remain available for computations.\r\n", mincomp / 1048576);
fclose(f);
mprintf(fd, "-1 at least %zuM must remain available for computations.", mincomp / 1048576);
close(fd);
continue;
}
if (total - limit < minuser) {
fprintf(f, "-2 at least %zuM must remain available for users.\r\n", minuser / 1048576);
fclose(f);
mprintf(fd, "-2 at least %zuM must remain available for users.", minuser / 1048576);
close(fd);
continue;
}
syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred->pid, cred->uid);
if (cgroup_set_mem_limit(chier, cgroup, limit) < 0)
fprintf(f, "0 error: %s\r\n", strerror(errno));
mprintf(fd, "0 error: %s", strerror(errno));
else
fputs("1 limit set\r\n", f);
mprintf(fd, "1 limit set");
} else {
syslog(LOG_WARNING, "invalid command (%s)", line);
}
fclose(f);
close(fd);
}
return EXIT_FAILURE;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment