Skip to content
Snippets Groups Projects
Commit a1a8e13f authored by Petr Baudis's avatar Petr Baudis
Browse files

Rewrite the communication protocol

sendmsg() with only ancilliary message does not work, apparently.
Therefore, to make things cleaner, pass command/reply directly using
sendmsg() instead of newline-terminated strings.
parent c40a98c1
Branches
No related tags found
No related merge requests found
...@@ -10,9 +10,9 @@ tweaking the cgroup limits. ...@@ -10,9 +10,9 @@ tweaking the cgroup limits.
The client compctl interface simply queries the server using The client compctl interface simply queries the server using
a synchronous protocol over a UNIX socket. First, the client a synchronous protocol over a UNIX socket. First, the client
sends a SCM_CREDENTIALS ancilliary message. Then, it follows sends a command string message coupled with a SCM_CREDENTIALS
with a CRLF-terminated command string and receives a CRLF-terminated ancilliary message. Then, it receives a reply message.
reply string. Connection is closed immediately on breach of protocol. Connection is closed immediately on breach of protocol.
You can tweak some simple compile-time configuration variables You can tweak some simple compile-time configuration variables
... ...
......
#define _GNU_SOURCE /* struct ucred */ #define _GNU_SOURCE /* struct ucred */
#include <assert.h> #include <assert.h>
#include <errno.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
...@@ -11,8 +12,8 @@ ...@@ -11,8 +12,8 @@
#include "common.h" #include "common.h"
FILE * char *
connectd(void) daemon_chat(char *cmd)
{ {
int s = socket(AF_UNIX, SOCK_STREAM, 0); int s = socket(AF_UNIX, SOCK_STREAM, 0);
struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE }; struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
...@@ -21,29 +22,67 @@ connectd(void) ...@@ -21,29 +22,67 @@ connectd(void)
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
/* Send message with credentials. */ /* Send command. */
struct iovec iov_cmd = {
.iov_base = cmd,
.iov_len = strlen(cmd),
};
struct msghdr msg = {
.msg_iov = &iov_cmd,
.msg_iovlen = 1,
};
/* Include credentials in the message. */
struct ucred cred = { struct ucred cred = {
.pid = getpid(), .pid = getpid(),
.uid = getuid(), .uid = getuid(),
.gid = getgid(), .gid = getgid(),
}; };
char cbuf[CMSG_SPACE(sizeof(cred))]; char cbuf[CMSG_SPACE(sizeof(cred))];
struct msghdr msg = { msg.msg_control = cbuf;
.msg_control = cbuf, msg.msg_controllen = sizeof(cbuf);
.msg_controllen = sizeof(cbuf),
};
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_CREDENTIALS; cmsg->cmsg_type = SCM_CREDENTIALS;
cmsg->cmsg_len = CMSG_LEN(sizeof(cred)); cmsg->cmsg_len = CMSG_LEN(sizeof(cred));
memcpy(CMSG_DATA(cmsg), &cred, sizeof(cred)); memcpy(CMSG_DATA(cmsg), &cred, sizeof(cred));
sendmsg(s, &msg, 0); ssize_t sent = sendmsg(s, &msg, 0);
if (sent < 0) {
perror("sendmsg");
exit(EXIT_FAILURE);
}
if (sent < msg.msg_iov->iov_len) {
fprintf(stderr, "incomplete send %zd < %zu, FIXME\n", sent, msg.msg_iov->iov_len);
exit(EXIT_FAILURE);
}
/* Receive reply. */
char reply[1024];
struct iovec iov_reply = {
.iov_base = reply,
.iov_len = sizeof(reply),
};
msg.msg_iov = &iov_reply;
msg.msg_iovlen = 1;
recvagain:;
int replylen = recvmsg(s, &msg, 0);
if (replylen < 0) {
if (errno == EAGAIN)
goto recvagain;
perror("recvmsg");
exit(EXIT_FAILURE);
}
if (replylen >= 1024) {
fprintf(stderr, "too long reply from the server\n");
exit(EXIT_FAILURE);
}
reply[replylen] = 0;
return fdopen(s, "rw"); return strdup(reply);
} }
...@@ -58,15 +97,12 @@ help(FILE *f) ...@@ -58,15 +97,12 @@ help(FILE *f)
int int
run(int argc, char *argv[]) run(int argc, char *argv[])
{ {
FILE *f = connectd(); char *line = daemon_chat("blessme");
fputs("blessme\r\n", f);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
if (line[0] != '1') { if (line[0] != '1') {
fputs(*line ? line : "unexpected hangup\n", stderr); fputs(*line ? line : "unexpected hangup\n", stderr);
return EXIT_FAILURE; return EXIT_FAILURE;
} }
free(line);
char *argvx[argc + 1]; char *argvx[argc + 1];
for (int i = 0; i < argc; i++) for (int i = 0; i < argc; i++)
...@@ -91,45 +127,38 @@ screen(int argc, char *argv[]) ...@@ -91,45 +127,38 @@ screen(int argc, char *argv[])
void void
stop(pid_t pid) stop(pid_t pid)
{ {
FILE *f = connectd(); char cmd[256]; snprintf(cmd, sizeof(cmd), "stop %d", pid);
fprintf(f, "stop %d\r\n", pid); char *line = daemon_chat(cmd);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
if (line[0] != '1') { if (line[0] != '1') {
fputs(*line ? line : "unexpected hangup\n", stderr); fputs(*line ? line : "unexpected hangup\n", stderr);
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
free(line);
} }
void void
stop_all(void) stop_all(void)
{ {
FILE *f = connectd(); char *line = daemon_chat("stopall");
fputs("stopall\r\n", f);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
if (line[0] != '1') { if (line[0] != '1') {
fputs(*line ? line : "unexpected hangup\n", stderr); fputs(*line ? line : "unexpected hangup\n", stderr);
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
fputs(line + 2, stdout); fputs(line + 2, stdout);
free(line);
} }
void void
limit_mem(size_t limit) limit_mem(size_t limit)
{ {
FILE *f = connectd(); char cmd[256]; snprintf(cmd, sizeof(cmd), "limitmem %zu", limit);
fprintf(f, "limitmem %zu\r\n", limit); char *line = daemon_chat(cmd);
char line[1024];
fgets(line, sizeof(line), f);
fclose(f);
if (line[0] != '1') { if (line[0] != '1') {
/* TODO: Error message postprocessing. */ /* TODO: Error message postprocessing. */
fputs(*line ? line : "unexpected hangup\n", stderr); fputs(*line ? line : "unexpected hangup\n", stderr);
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
free(line);
} }
void void
... ...
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <assert.h> #include <assert.h>
#include <errno.h> #include <errno.h>
#include <signal.h> #include <signal.h>
#include <stdarg.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
...@@ -77,6 +78,36 @@ cgroup_init(void) ...@@ -77,6 +78,36 @@ cgroup_init(void)
} }
void
mprintf(int fd, char *fmt, ...)
{
char buf[1024];
va_list v;
va_start(v, fmt);
vsnprintf(buf, sizeof(buf), fmt, v);
va_end(v);
struct iovec iov = {
.iov_base = buf,
.iov_len = strlen(buf),
};
struct msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
};
ssize_t sent = sendmsg(fd, &msg, 0);
if (sent < 0) {
logperror("sendmsg");
return;
}
if (sent < iov.iov_len) {
syslog(LOG_INFO, "incomplete send %zd < %zu, FIXME", sent, iov.iov_len);
return;
}
}
int int
main(int argc, char *argv[]) main(int argc, char *argv[])
{ {
...@@ -100,8 +131,6 @@ main(int argc, char *argv[]) ...@@ -100,8 +131,6 @@ main(int argc, char *argv[])
setsid(); setsid();
int s = socket(AF_UNIX, SOCK_STREAM, 0); int s = socket(AF_UNIX, SOCK_STREAM, 0);
int on = 1; setsockopt(s, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
/* TODO: Protect against double execution? */ /* TODO: Protect against double execution? */
unlink(SOCKFILE); unlink(SOCKFILE);
struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE }; struct sockaddr_un sun = { .sun_family = AF_UNIX, .sun_path = SOCKFILE };
...@@ -119,17 +148,27 @@ main(int argc, char *argv[]) ...@@ -119,17 +148,27 @@ main(int argc, char *argv[])
* the daemon, this is just an attack vector we ignore. */ * the daemon, this is just an attack vector we ignore. */
/* TODO: alarm() to wake from stuck clients. */ /* TODO: alarm() to wake from stuck clients. */
/* Decode the message with credentials. */ /* Decode the message with command and credentials. */
int on = 1; setsockopt(fd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));
struct ucred *cred; struct ucred *cred;
char cbuf[CMSG_SPACE(sizeof(*cred))]; char cbuf[CMSG_SPACE(sizeof(*cred))];
char line[1024];
struct iovec iov = {
.iov_base = line,
.iov_len = sizeof(line),
};
struct msghdr msg = { struct msghdr msg = {
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = cbuf, .msg_control = cbuf,
.msg_controllen = sizeof(cbuf), .msg_controllen = sizeof(cbuf),
}; };
char *errmsg; char *errmsg;
recvagain: recvagain:;
if (recvmsg(fd, &msg, MSG_WAITALL) < 0) { int replylen = recvmsg(fd, &msg, 0);
if (replylen < 0) {
if (errno == EAGAIN) if (errno == EAGAIN)
goto recvagain; goto recvagain;
errmsg = "recvmsg"; errmsg = "recvmsg";
...@@ -151,24 +190,20 @@ sockerror: ...@@ -151,24 +190,20 @@ sockerror:
} }
cred = (struct ucred *) CMSG_DATA(cmsg); cred = (struct ucred *) CMSG_DATA(cmsg);
FILE *f = fdopen(fd, "r"); line[replylen] = 0;
char line[1024]; if (replylen < 2) {
fgets(line, sizeof(line), f);
size_t linelen = strlen(line);
if (linelen < 2 || strcmp(&line[linelen - 2], "\r\n")) {
syslog(LOG_WARNING, "protocol error (%s)", line); syslog(LOG_WARNING, "protocol error (%s)", line);
fclose(f); close(fd);
continue; continue;
} }
line[linelen - 2] = 0;
/* Analyze command */ /* Analyze command */
if (!strcmp("blessme", line)) { if (!strcmp("blessme", line)) {
syslog(LOG_INFO, "new computation process %d", cred->pid); syslog(LOG_INFO, "new computation process %d", cred->pid);
if (cgroup_add_task(chier, cgroup, cred->pid) < 0) if (cgroup_add_task(chier, cgroup, cred->pid) < 0)
fprintf(f, "0 error: %s\r\n", strerror(errno)); mprintf(fd, "0 error: %s", strerror(errno));
else else
fputs("1 blessed\r\n", f); mprintf(fd, "1 blessed");
} else if (begins_with("stop ", line)) { } else if (begins_with("stop ", line)) {
pid_t pid = atoi(line + sizeof("stop ")); pid_t pid = atoi(line + sizeof("stop "));
...@@ -176,27 +211,27 @@ sockerror: ...@@ -176,27 +211,27 @@ sockerror:
/* Sanity check. */ /* Sanity check. */
if (pid < 10 || pid > 32768) { if (pid < 10 || pid > 32768) {
syslog(LOG_WARNING, "stop: invalid pid (%d)", pid); syslog(LOG_WARNING, "stop: invalid pid (%d)", pid);
fputs("0 invalid pid\r\n", f); mprintf(fd, "0 invalid pid");
fclose(f); close(fd);
continue; continue;
} }
if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) { if (!cgroup_is_task_in_cgroup(chier, cgroup, pid)) {
fputs("0 task not marked as computation\r\n", f); mprintf(fd, "0 task not marked as computation");
fclose(f); close(fd);
continue; continue;
} }
syslog(LOG_INFO, "stopping process %d (request by pid %d uid %d)", pid, cred->pid, cred->uid); syslog(LOG_INFO, "stopping process %d (request by pid %d uid %d)", pid, cred->pid, cred->uid);
kill(pid, SIGTERM); kill(pid, SIGTERM);
/* TODO: Grace period and then kill with SIGKILL. */ /* TODO: Grace period and then kill with SIGKILL. */
fputs("1 task stopped\r\n", f); mprintf(fd, "1 task stopped");
} else if (!strcmp("stopall", line)) { } else if (!strcmp("stopall", line)) {
pid_t *tasks; pid_t *tasks;
int tasks_n = cgroup_task_list(chier, cgroup, &tasks); int tasks_n = cgroup_task_list(chier, cgroup, &tasks);
if (tasks_n < 0) { if (tasks_n < 0) {
fprintf(f, "0 error: %s\r\n", strerror(errno)); mprintf(fd, "0 error: %s\r\n", strerror(errno));
fclose(f); close(fd);
continue; continue;
} }
for (int i = 0; i < tasks_n; i++) { for (int i = 0; i < tasks_n; i++) {
...@@ -204,7 +239,7 @@ sockerror: ...@@ -204,7 +239,7 @@ sockerror:
kill(tasks[i], SIGTERM); kill(tasks[i], SIGTERM);
} }
/* TODO: Grace period and then kill with SIGKILL. */ /* TODO: Grace period and then kill with SIGKILL. */
fprintf(f, "1 %d tasks stopped\r\n", tasks_n); mprintf(fd, "1 %d tasks stopped", tasks_n);
free(tasks); free(tasks);
} else if (begins_with("limitmem ", line)) { } else if (begins_with("limitmem ", line)) {
...@@ -215,33 +250,33 @@ sockerror: ...@@ -215,33 +250,33 @@ sockerror:
/* Sanity check. */ /* Sanity check. */
if (limit < 1024 || limit > total) { if (limit < 1024 || limit > total) {
syslog(LOG_WARNING, "limitmem: invalid limit (%zu)", limit); syslog(LOG_WARNING, "limitmem: invalid limit (%zu)", limit);
fputs("0 invalid limit value\r\n", f); mprintf(fd, "0 invalid limit value");
fclose(f); close(fd);
continue; continue;
} }
if (limit < mincomp) { if (limit < mincomp) {
fprintf(f, "-1 at least %zuM must remain available for computations.\r\n", mincomp / 1048576); mprintf(fd, "-1 at least %zuM must remain available for computations.", mincomp / 1048576);
fclose(f); close(fd);
continue; continue;
} }
if (total - limit < minuser) { if (total - limit < minuser) {
fprintf(f, "-2 at least %zuM must remain available for users.\r\n", minuser / 1048576); mprintf(fd, "-2 at least %zuM must remain available for users.", minuser / 1048576);
fclose(f); close(fd);
continue; continue;
} }
syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred->pid, cred->uid); syslog(LOG_INFO, "setting limit %zu (request by pid %d uid %d)", limit, cred->pid, cred->uid);
if (cgroup_set_mem_limit(chier, cgroup, limit) < 0) if (cgroup_set_mem_limit(chier, cgroup, limit) < 0)
fprintf(f, "0 error: %s\r\n", strerror(errno)); mprintf(fd, "0 error: %s", strerror(errno));
else else
fputs("1 limit set\r\n", f); mprintf(fd, "1 limit set");
} else { } else {
syslog(LOG_WARNING, "invalid command (%s)", line); syslog(LOG_WARNING, "invalid command (%s)", line);
} }
fclose(f); close(fd);
} }
return EXIT_FAILURE; return EXIT_FAILURE;
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment