qemu-devel
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

Re: [Qemu-devel] [PATCH v5 4/7] Add domain socket communication for vhost-user backend


From: Michael S. Tsirkin
Subject: Re: [Qemu-devel] [PATCH v5 4/7] Add domain socket communication for vhost-user backend
Date: Thu, 9 Jan 2014 17:31:33 +0200

On Thu, Jan 09, 2014 at 03:59:58PM +0100, Antonios Motakis wrote:
> Add structures for passing vhost-user messages over a unix domain socket.
> This is the equivalent of the existing vhost-kernel ioctls.
> 
> Connect to the named unix domain socket. The system call sendmsg
> is used for communication. To be able to pass file descriptors
> between processes - we use SCM_RIGHTS type in the message control header.
> 
> Signed-off-by: Antonios Motakis <address@hidden>
> Signed-off-by: Nikolay Nikolaev <address@hidden>
> ---
>  hw/virtio/vhost-backend.c | 268 
> +++++++++++++++++++++++++++++++++++++++++++++-
>  1 file changed, 263 insertions(+), 5 deletions(-)
> 
> diff --git a/hw/virtio/vhost-backend.c b/hw/virtio/vhost-backend.c
> index 1d83e1d..b33d35f 100644
> --- a/hw/virtio/vhost-backend.c
> +++ b/hw/virtio/vhost-backend.c
> @@ -11,34 +11,292 @@
>  #include "hw/virtio/vhost.h"
>  #include "hw/virtio/vhost-backend.h"
>  #include "qemu/error-report.h"
> +#include "qemu/sockets.h"
>  
>  #include <fcntl.h>
>  #include <unistd.h>
>  #include <sys/ioctl.h>
> +#include <sys/socket.h>
> +#include <sys/un.h>
> +#include <linux/vhost.h>
> +
> +#define VHOST_MEMORY_MAX_NREGIONS    8
> +#define VHOST_USER_SOCKTO            (300) /* msec */
> +
> +typedef enum VhostUserRequest {
> +    VHOST_USER_NONE = 0,
> +    VHOST_USER_GET_FEATURES = 1,
> +    VHOST_USER_SET_FEATURES = 2,
> +    VHOST_USER_SET_OWNER = 3,
> +    VHOST_USER_RESET_OWNER = 4,
> +    VHOST_USER_SET_MEM_TABLE = 5,
> +    VHOST_USER_SET_LOG_BASE = 6,
> +    VHOST_USER_SET_LOG_FD = 7,
> +    VHOST_USER_SET_VRING_NUM = 8,
> +    VHOST_USER_SET_VRING_ADDR = 9,
> +    VHOST_USER_SET_VRING_BASE = 10,
> +    VHOST_USER_GET_VRING_BASE = 11,
> +    VHOST_USER_SET_VRING_KICK = 12,
> +    VHOST_USER_SET_VRING_CALL = 13,
> +    VHOST_USER_SET_VRING_ERR = 14,
> +    VHOST_USER_NET_SET_BACKEND = 15,
> +    VHOST_USER_ECHO = 16,
> +    VHOST_USER_MAX
> +} VhostUserRequest;
> +
> +typedef struct VhostUserMemoryRegion {
> +    __u64 guest_phys_addr;
> +    __u64 memory_size;
> +    __u64 userspace_addr;

Don't use the __uXX types like this.
Use the POSIX uintXX_t types instead.

> +} VhostUserMemoryRegion;
> +
> +typedef struct VhostUserMemory {
> +    __u32 nregions;
> +    __u32 padding;
> +    VhostUserMemoryRegion regions[VHOST_MEMORY_MAX_NREGIONS];
> +} VhostUserMemory;
> +
> +typedef struct VhostUserMsg {
> +    VhostUserRequest request;
> +
> +#define VHOST_USER_VERSION_MASK     (0x3)
> +#define VHOST_USER_REPLY_MASK       (0x1<<2)
> +    __u32 flags;
> +    __u32 size; /* the following payload size */
> +    union {
> +        __u64 u64;
> +        int fd;
> +        struct vhost_vring_state state;
> +        struct vhost_vring_addr addr;
> +        struct vhost_vring_file file;

Hmm, I don't get it. The fd is passed in ancillary data, right?
What are the fd numbers doing in the data portion?

> +        VhostUserMemory memory;
> +    };
> +} QEMU_PACKED VhostUserMsg;
> +
> +#define MEMB_SIZE(t, m)     (sizeof(((t *)0)->m))

That's unnecessarily tricky.
Just do

static VhostUserMsg m;

Then you can use sizeof(m.request) without any issues and without macros.

> +#define VHOST_USER_HDR_SIZE (MEMB_SIZE(VhostUserMsg, request) \
> +                            + MEMB_SIZE(VhostUserMsg, flags) \
> +                            + MEMB_SIZE(VhostUserMsg, size))

Or just split the header into a separate structure and use sizeof on that.

> +
> +/* The version of the protocol we support */
> +#define VHOST_USER_VERSION    (0x1)
> +
> +static int vhost_user_cleanup(struct vhost_dev *dev);

Don't forward-declare static functions.
Order them logically instead.

> +
> +static int vhost_user_recv(int fd, VhostUserMsg *msg)
> +{
> +    ssize_t r;
> +    __u8 *p = (__u8 *)msg;
> +
> +    /* read the header */
> +    do {
> +        r = read(fd, p, VHOST_USER_HDR_SIZE);
> +    } while (r < 0 && errno == EINTR);

What if you get a partial read?
You use SOCK_STREAM, so I think that can happen, right?


> +
> +    if (r > 0) {
> +        p += VHOST_USER_HDR_SIZE;
> +        /* read the payload */
> +        do {
> +            r = read(fd, p, msg->size);
> +        } while (r < 0 && errno == EINTR);

That's a security bug.
What if msg->size is bigger than sizeof *msg?
You will corrupt memory.
You must validate the size before reading the payload.

> +    }
> +
> +    if (r < 0) {
> +        error_report("Failed to read msg, reason: %s\n", strerror(errno));
> +    }
> +
> +    return (r < 0) ? -1 : 0;
> +}
> +
> +static int vhost_user_send_fds(int fd, const VhostUserMsg *msg, int *fds,
> +        size_t fd_num)
> +{
> +    int r;
> +
> +    struct msghdr msgh;
> +    struct iovec iov[1];

Isn't this the same as "struct iovec iov", only longer and more confusing?
> +
> +    size_t fd_size = fd_num * sizeof(int);
> +    char control[CMSG_SPACE(fd_size)];
> +    struct cmsghdr *cmsg;
> +
> +    memset(&msgh, 0, sizeof(msgh));
> +    memset(control, 0, sizeof(control));
> +
> +    /* set the payload */
> +    iov[0].iov_base = (void *)msg;

it's void * already, no need for cast

> +    iov[0].iov_len = VHOST_USER_HDR_SIZE + msg->size;
> +
> +    msgh.msg_iov = iov;
> +    msgh.msg_iovlen = 1;
> +
> +    if (fd_num) {
> +        msgh.msg_control = control;
> +        msgh.msg_controllen = sizeof(control);
> +
> +        cmsg = CMSG_FIRSTHDR(&msgh);
> +
> +        cmsg->cmsg_len = CMSG_LEN(fd_size);
> +        cmsg->cmsg_level = SOL_SOCKET;
> +        cmsg->cmsg_type = SCM_RIGHTS;
> +        memcpy(CMSG_DATA(cmsg), fds, fd_size);
> +    } else {
> +        msgh.msg_control = 0;
> +        msgh.msg_controllen = 0;
> +    }
> +
> +    do {
> +        r = sendmsg(fd, &msgh, 0);
> +    } while (r < 0 && errno == EINTR);

Again, I think this can write the message only partially with SOCK_STREAM.

> +
> +    if (r < 0) {
> +        error_report("Failed to send msg(%d), reason: %s\n",
> +                msg->request, strerror(errno));
> +    } else {
> +        r = 0;
> +    }
> +
> +    return r;
> +}
> +
> +static int vhost_user_echo(struct vhost_dev *dev)
> +{
> +    VhostUserMsg msg;
> +    int fd = dev->control;
> +
> +    if (fd < 0) {
> +        return -1;
> +    }
> +
> +    /* check connection */
> +    msg.request = VHOST_USER_ECHO;
> +    msg.flags = VHOST_USER_VERSION;
> +    msg.size = 0;
> +
> +    if (vhost_user_send_fds(fd, &msg, 0, 0) < 0) {
> +        error_report("ECHO failed\n");
> +        return -1;
> +    }
> +
> +    if (vhost_user_recv(fd, &msg) < 0) {
> +        error_report("ECHO failed\n");
> +        return -1;
> +    }
> +
> +    if ((msg.flags & VHOST_USER_REPLY_MASK) == 0 ||
> +        (msg.flags & VHOST_USER_VERSION_MASK) != VHOST_USER_VERSION) {
> +        error_report("ECHO failed\n");
> +        return -1;
> +    }
> +
> +    return 0;
> +}
>  
>  static int vhost_user_call(struct vhost_dev *dev, unsigned long int request,
>          void *arg)
>  {
> +    int fd = dev->control;
> +    VhostUserMsg msg;
> +    int result = 0;
> +    int fds[VHOST_MEMORY_MAX_NREGIONS];
> +    size_t fd_num = 0;
> +
>      assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);
> -    error_report("vhost_user_call not implemented\n");
>  
> -    return -1;
> +    if (fd < 0) {
> +        return -1;
> +    }
> +
> +    msg.request = VHOST_USER_NONE;
> +    msg.flags = VHOST_USER_VERSION;
> +    msg.size = 0;
> +
> +    switch (request) {
> +    default:
> +        error_report("vhost-user trying to send unhandled ioctl\n");
> +        return -1;
> +        break;
> +    }
> +
> +    result = vhost_user_send_fds(fd, &msg, fds, fd_num);
> +
> +    return result;
>  }
>  
>  static int vhost_user_init(struct vhost_dev *dev, const char *devpath)
>  {
> +    int fd = -1;
> +    struct sockaddr_un un;
> +    struct timeval tv;
> +    size_t len;
> +
>      assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);
> -    error_report("vhost_user_init not implemented\n");
>  
> +    /* Create the socket */
> +    fd = qemu_socket(AF_UNIX, SOCK_STREAM, 0);
> +    if (fd == -1) {
> +        error_report("socket: %s", strerror(errno));
> +        return -1;
> +    }
> +
> +    /* Set socket options */
> +    tv.tv_sec = VHOST_USER_SOCKTO / 1000;
> +    tv.tv_usec = (VHOST_USER_SOCKTO % 1000) * 1000;
> +
> +    if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(struct timeval))
> +            == -1) {
> +        error_report("setsockopt SO_SNDTIMEO: %s", strerror(errno));
> +        goto fail;
> +    }
> +
> +    if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(struct timeval))
> +            == -1) {
> +        error_report("setsockopt SO_RCVTIMEO: %s", strerror(errno));
> +        goto fail;
> +    }
> +
> +    /* Connect */
> +    un.sun_family = AF_UNIX;
> +    strcpy(un.sun_path, devpath);
> +
> +    len = sizeof(un.sun_family) + strlen(devpath);
> +
> +    if (connect(fd, (struct sockaddr *) &un, len) == -1) {
> +        error_report("connect: %s", strerror(errno));
> +        goto fail;
> +    }
> +
> +    /* Cleanup if there is previous connection left */
> +    if (dev->control >= 0) {
> +        vhost_user_cleanup(dev);
> +    }
> +    dev->control = fd;
> +
> +    if (vhost_user_echo(dev) < 0) {
> +        dev->control = -1;
> +        goto fail;
> +    }
> +
> +    return fd;
> +
> +fail:
> +    close(fd);
>      return -1;
> +
>  }
>  
>  static int vhost_user_cleanup(struct vhost_dev *dev)
>  {
> +    int r = 0;
> +
>      assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);
> -    error_report("vhost_user_cleanup not implemented\n");
>  
> -    return -1;
> +    if (dev->control >= 0) {
> +        r = close(dev->control);
> +    }
> +    dev->control = -1;
> +
> +    return r;
>  }
>  
>  static const VhostOps user_ops = {
> -- 
> 1.8.3.2
> 



reply via email to

[Prev in Thread] Current Thread [Next in Thread]