vendredi 29 mai 2015

POSIX req'ts for sendmsg() SCM_RIGHTS with some but not all fds invalid

Does POSIX, or any other relevant standard, mandate any particular behavior when one or more (but not all) of the file descriptors being sent across an AF_UNIX socket with sendmsg() are invalid?

Concretely, the test program at the end of this question (apologies for its length) attempts to send an array of file descriptors over a socket, one of which may be deliberately initialized to -1 rather than a valid file descriptor. Linux (kernel 3.13), NetBSD (6.1.5), and OSX (10.10) all agree that if any of the descriptors are invalid, the sendmsg call fails, with errno set to EBADF. I would like to know if this is actually standard-required behavior, and (regardless) whether any other Unix implementations behave differently. The only thing I was able to find in the online copy of SUSv7 regarding SCM_RIGHTS messages was a requirement that the constant be defined:

The <sys/socket.h> header shall define the following symbolic constant for use as the cmsg_type value when cmsg_level is SOL_SOCKET: SCM_RIGHTS: Indicates that the data array contains the access rights to be sent or received.

(Abstractly, it would seem more useful for the call to succeed when at least one of the file descriptors is valid, and for the receiving process to get an array of the same length and ordering as was sent, with all invalid entries forced to -1.)

(When invoked with no arguments, the test program sends an array of two fds, both of which are valid. When invoked with a single argument, which must be the single character 0, 1, or 2, the test program sends an array of three fds, and the slot in the array corresponding to the value of the argument is set to -1, whereas the other two are set to valid fds. This confirms that the receive-side code works, and demonstrates that the kernel behavior does not depend on the position of the invalid fd within the array.)


#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <unistd.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <signal.h>

/*
 * Print a single line to stdout describing descriptors FDS[0..NFD-1],
 * prefixed with TAG and the NFD value exactly as given.  Entries are
 * separated by "; ".  Each entry is probed with fstat(): on success its
 * inode number is shown, otherwise the strerror() text for errno.
 * A negative NFD is echoed in the header but treated as zero entries.
 */
static void
describe_fds(int nfd, int fds[], const char *tag)
{
  int count = (nfd < 0) ? 0 : nfd;
  const char *sep = "";
  int k;

  printf("%s: nfd=%d: ", tag, nfd);

  for (k = 0; k < count; k++, sep = "; ") {
    struct stat info;

    fputs(sep, stdout);
    if (fstat(fds[k], &info) != 0) {
      printf("[%d] = %d (%s)", k, fds[k], strerror(errno));
    } else {
      printf("[%d] = %d (%ld)", k, fds[k], (long) info.st_ino);
    }
  }
  putchar('\n');
}

/*
 * Child side of the test: receive one byte of ordinary data plus any
 * ancillary data from socket SK and report what arrived on stdout, with
 * every line prefixed "R:".
 *
 * XNFD is the number of descriptors the sender is expected to pass.  The
 * control buffer is sized for that many, and an SCM_RIGHTS message shorter
 * than that is reported as a short read.  A negative XNFD is treated as 0,
 * as do_send_fds() does for its count.
 *
 * Each SCM_RIGHTS payload is copied out and passed to describe_fds().  The
 * received descriptors are then closed, because the kernel installed them
 * in this process.  Errors are reported on stdout; nothing is returned.
 */
static void
do_receive_fds(int sk, int xnfd)
{
  char dummy[1];
  int *fds, nfd, fd, i, got_fds = 0;
  ssize_t n;
  size_t controllen, nbytes, avail;
  struct iovec iov[1];
  struct msghdr msg;
  struct cmsghdr *cmsg;
  unsigned char *data, *end;
  char *cmsgbuf;

  if (xnfd < 0) xnfd = 0;

  controllen = CMSG_SPACE(sizeof(int) * xnfd);
  cmsgbuf = malloc(controllen);
  if (!cmsgbuf) {
    printf("R: malloc: %s\n", strerror(errno));
    return;
  }

  iov[0].iov_base    = dummy;
  iov[0].iov_len     = 1;

  msg.msg_name       = 0;
  msg.msg_namelen    = 0;
  msg.msg_iov        = iov;
  msg.msg_iovlen     = 1;
  msg.msg_control    = cmsgbuf;
  msg.msg_controllen = controllen;
  msg.msg_flags      = 0;

  n = recvmsg(sk, &msg, MSG_WAITALL);
  if (n < 0) {
    printf("R: recvmsg: %s\n", strerror(errno));
    free(cmsgbuf);   /* this path used to leak the control buffer */
    return;
  }
  printf("R: recieve flags = %x\n", (unsigned) msg.msg_flags);
  if (n == 0)
    puts("R: short read ordinary data");
  else
    printf("R: ordinary data = %02x\n", (unsigned char)dummy[0]);

  /* recvmsg() updated msg_controllen to the amount of control data stored. */
  end = (unsigned char *) cmsgbuf + msg.msg_controllen;

  for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
    if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) {
      printf("R: unexpected cmsg %d/%d\n",
             cmsg->cmsg_level, cmsg->cmsg_type);
      continue;
    }

    /* Payload length, guarded against a cmsg_len smaller than the header
       (unsigned underflow) or larger than the bytes actually present in our
       buffer (e.g. after MSG_CTRUNC), so the copies below stay in bounds. */
    data = CMSG_DATA(cmsg);
    nbytes = cmsg->cmsg_len > CMSG_LEN(0) ? cmsg->cmsg_len - CMSG_LEN(0) : 0;
    avail = data < end ? (size_t) (end - data) : 0;
    if (nbytes > avail)
      nbytes = avail;
    nfd = (int) (nbytes / sizeof(int));

    /* cmsg_len is size_t on some systems and socklen_t on others; cast both
       lengths to unsigned long so the format string is portable. */
    if (cmsg->cmsg_len < CMSG_LEN(sizeof(int) * xnfd))
      printf("R: short read ancillary data (exp %lu/%d got %lu/%d)\n",
             (unsigned long) CMSG_LEN(sizeof(int) * xnfd), xnfd,
             (unsigned long) cmsg->cmsg_len, nfd);

    /* Copy the payload out rather than casting CMSG_DATA(), which is not
       guaranteed to be suitably aligned for int.  With no descriptors there
       is nothing to allocate, so calloc(0) cannot produce a false failure. */
    fds = NULL;
    if (nfd > 0 && !(fds = calloc((size_t) nfd, sizeof(int)))) {
      printf("R: calloc: %s\n", strerror(errno));
    } else {
      if (nfd > 0)
        memcpy(fds, data, (size_t) nfd * sizeof(int));
      got_fds = 1;
      describe_fds(nfd, fds, "R");
      free(fds);
    }

    /* The received descriptors now belong to this process.  Close them,
       even if the copy above failed, so they are not leaked. */
    for (i = 0; i < nfd; i++) {
      memcpy(&fd, data + (size_t) i * sizeof(int), sizeof fd);
      if (fd >= 0)
        close(fd);
    }
  }
  if (!got_fds)
    puts("R: no fds received");

  free(cmsgbuf);
}

/*
 * Send one byte of ordinary data ('X') on socket SK, together with an
 * SCM_RIGHTS control message carrying the NFD descriptors in FDS.
 * A negative NFD is treated as 0, and FDS may then be NULL.  Descriptors
 * are passed through unchecked, so the kernel decides how to handle any
 * invalid entry (e.g. -1).
 *
 * Returns 0 on success.  On failure, prints a diagnostic prefixed "S:" to
 * stdout and returns 1.
 */
static int
do_send_fds(int sk, int nfd, int *fds)
{
  char dummy[1];
  ssize_t n;
  int err;
  size_t controllen;
  struct iovec iov[1];
  struct msghdr msg;
  struct cmsghdr *cmsg;
  char *cmsgbuf;

  if (nfd < 0) nfd = 0;

  controllen = CMSG_SPACE(sizeof(int) * nfd);
  /* calloc rather than malloc: sendmsg() is handed all msg_controllen bytes,
     including the alignment padding after the payload.  Those bytes should
     not be left uninitialized. */
  cmsgbuf = calloc(1, controllen);
  if (!cmsgbuf) {
    printf("S: calloc: %s\n", strerror(errno));
    return 1;
  }

  dummy[0]           = 'X';
  iov[0].iov_base    = dummy;
  iov[0].iov_len     = 1;

  msg.msg_name       = 0;
  msg.msg_namelen    = 0;
  msg.msg_iov        = iov;
  msg.msg_iovlen     = 1;
  msg.msg_control    = cmsgbuf;
  msg.msg_controllen = controllen;
  msg.msg_flags      = 0;

  /* controllen >= CMSG_SPACE(0), so the first header always fits. */
  cmsg               = CMSG_FIRSTHDR(&msg);
  cmsg->cmsg_level   = SOL_SOCKET;
  cmsg->cmsg_type    = SCM_RIGHTS;
  cmsg->cmsg_len     = CMSG_LEN(sizeof(int) * nfd);

  /* Skip the copy for an empty list, so FDS may legitimately be NULL. */
  if (nfd > 0)
    memcpy(CMSG_DATA(cmsg), fds, sizeof(int) * nfd);

  n = sendmsg(sk, &msg, 0);
  err = errno;   /* free() below may clobber errno */
  free(cmsgbuf);

  if (n < 0) {
    printf("S: sendmsg: %s\n", strerror(err));
    return 1;
  }
  if (n == 0) {
    puts("S: sendmsg: short write");
    return 1;
  }
  return 0;
}

/*
 * Body of the forked child: receive and report the descriptors the parent
 * sends on socket SK (XNFD of them are expected), then hand back the value
 * main() uses as the child's exit status.  Problems during the receive are
 * only reported by do_receive_fds(); they never change that status.
 */
static int
child_task(int sk, int xnfd)
{
  const int exit_status = 0;

  do_receive_fds(sk, xnfd);
  return exit_status;
}

/*
 * Parent side of the test.  Create two temporary files, build the
 * descriptor array selected by LOC, report it, send it to the child over
 * socket SK, and wait for the child PID to exit.
 *
 * LOC selects the layout:
 *   0 - two valid fds (nfd = 2)
 *   1 - three fds, fds[2] = -1
 *   2 - three fds, fds[1] = -1
 *   3 - three fds, fds[0] = -1
 *
 * Returns 0 if the send succeeded and the child exited with status 0,
 * otherwise 1.  If setup or the send fails, the child is killed with
 * SIGKILL and reaped.  The temporary files are closed on every path.
 */
static int
parent_task(int sk, pid_t pid, int loc)
{
  FILE *a = NULL, *b = NULL;
  int fds[3];
  int status;
  int nfd;
  int ret = 1;

  a = tmpfile();
  if (!a) {
    printf("S: tmpfile A: %s\n", strerror(errno));
    goto fail;
  }
  b = tmpfile();
  if (!b) {
    printf("S: tmpfile B: %s\n", strerror(errno));
    goto fail;
  }

  switch (loc) {
  case 0: nfd = 2; fds[0] = fileno(a); fds[1] = fileno(b); fds[2] = -1; break;
  case 1: nfd = 3; fds[0] = fileno(a); fds[1] = fileno(b); fds[2] = -1; break;
  case 2: nfd = 3; fds[0] = fileno(a); fds[1] = -1; fds[2] = fileno(b); break;
  case 3: nfd = 3; fds[0] = -1; fds[1] = fileno(a); fds[2] = fileno(b); break;
  default:
    printf("S: impossible: loc=%d\n", loc);
    goto fail;
  }

  describe_fds(nfd, fds, "S");
  fflush(0);

  if (do_send_fds(sk, nfd, fds))
    goto fail;

  if (waitpid(pid, &status, 0) != pid) {
    printf("S: waitpid: %s\n", strerror(errno));
    goto out;
  }
  if (status != 0) {
    printf("S: abnormal child exit status %04x\n", (unsigned short)status);
    goto out;
  }
  ret = 0;
  goto out;

 fail:
  /* Nothing was delivered, so the child may still be waiting in recvmsg();
     kill and reap it rather than leaving it behind. */
  kill(pid, SIGKILL);
  waitpid(pid, &status, 0);
 out:
  /* Release the temporary files on every path (B could fail after A). */
  if (b)
    fclose(b);
  if (a)
    fclose(a);
  return ret;
}

/*
 * Test driver.  With no arguments, send two valid descriptors.  With a
 * single argument "0", "1" or "2", send three descriptors with that slot
 * set to -1.  The parent sends; the forked child receives and reports.
 *
 * Exit status: 2 on bad usage, 1 if socketpair()/fork() fails, otherwise
 * the result of child_task() in the child or parent_task() in the parent.
 */
int
main(int argc, char **argv)
{
  int skp[2];
  pid_t pid;
  int loc = -1, xnfd = 0;

  if (argc == 1) {
    xnfd = 2;
    loc = 0;
  } else if (argc == 2 && argv[1][0] != '\0' && argv[1][1] == '\0') {
    /* argv[1][0] is tested first so an empty argument never leads us to
       read argv[1][1], one past its terminating NUL. */
    xnfd = 3;
    switch (argv[1][0]) {
    case '0': loc = 3; break;   /* parent_task sets fds[0] = -1 */
    case '1': loc = 2; break;   /* fds[1] = -1 */
    case '2': loc = 1; break;   /* fds[2] = -1 */
    default: break;
    }
  }
  if (loc < 0) {
    fprintf(stderr, "usage: %s [0|1|2]\n", argv[0]);
    return 2;
  }

  if (socketpair(PF_LOCAL, SOCK_STREAM, 0, skp)) {
    perror("socketpair");
    return 1;
  }

  fflush(0);   /* keep buffered output from being duplicated by fork() */
  pid = fork();
  if (pid == -1) {
    perror("fork");
    return 1;
  }
  if (pid == 0) {
    /* Drop the child's copy of the parent's end.  Then, if the parent dies
       before sending, recvmsg() in the child sees EOF instead of blocking
       forever on a peer that the child itself keeps open. */
    close(skp[1]);
    return child_task(skp[0], xnfd);
  }
  return parent_task(skp[1], pid, loc);
}

Aucun commentaire:

Enregistrer un commentaire