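//
// Demonstrates passing an open file descriptor from one process to another
// over a Unix-domain socket (SCM_RIGHTS).
//
// TO USE:
//
//   save file as sendfd.c
//   gcc -o sendfd sendfd.c
//   in one process:      sendfd --send
//   in another process:  sendfd --recieve   (the flag is spelled "recieve")
//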

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>


/*
 * Report where the failure happened, then terminate with a failure status.
 * Wrapped in do/while(0) so that `if (x) FATAL_ERROR(); else ...` parses as
 * a single statement.
 */
#define FATAL_ERROR()                                                    \
    do {                                                                 \
        fprintf(stderr, "%s:%d: fatal error\n", __FILE__, __LINE__);     \
        exit(EXIT_FAILURE);                                              \
    } while (0)

/* Filesystem path of the Unix-domain socket shared by sender and receiver. */
#define SOCKET_NAME "/tmp/example_socket"

/*
 * Create a listening Unix-domain stream socket at `socket_name`, accept a
 * single connection, and receive one file descriptor sent with SCM_RIGHTS
 * together with a 1-byte payload.
 *
 * Any existing file at `socket_name` is unlinked before bind(), and the
 * socket file is unlinked again once a connection has been accepted.
 *
 * Returns the received descriptor, which is a new fd in this process.
 * Calls FATAL_ERROR() on any failure, including a path too long for
 * sun_path, a short read, truncated control data, or a control message
 * that is not exactly one SCM_RIGHTS descriptor.
 */
int receive_fd(const char *socket_name)
{
    struct sockaddr_un sock_addr;
    struct msghdr msg;
    struct iovec iov[1];
    char msg_buf[1];
    /* The union gives ctrl.buf the alignment that struct cmsghdr requires. */
    union {
        char buf[CMSG_SPACE(sizeof(int))];
        struct cmsghdr align;
    } ctrl;
    struct cmsghdr *cmsg;
    int listen_fd;
    int connect_fd;
    int received_fd;
    int path_len;

    /* Reject paths that do not fit instead of silently truncating them. */
    memset(&sock_addr, 0, sizeof sock_addr);
    sock_addr.sun_family = AF_UNIX;
    path_len = snprintf(sock_addr.sun_path, sizeof sock_addr.sun_path,
                        "%s", socket_name);
    if (path_len < 0 || (size_t)path_len >= sizeof sock_addr.sun_path)
        FATAL_ERROR();

    listen_fd = socket(PF_UNIX, SOCK_STREAM, 0);
    if (listen_fd < 0) FATAL_ERROR();

    /* A leftover socket file from an earlier run would make bind() fail. */
    unlink(socket_name);

    if (bind(listen_fd, (const struct sockaddr *)&sock_addr,
             sizeof sock_addr) != 0)
        FATAL_ERROR();

    if (listen(listen_fd, 1) != 0) FATAL_ERROR();

    /* The peer's address is not needed, so none is requested. */
    connect_fd = accept(listen_fd, NULL, NULL);
    close(listen_fd);
    unlink(socket_name);
    if (connect_fd < 0) FATAL_ERROR();

    memset(&msg, 0, sizeof msg);

    iov[0].iov_base = msg_buf;
    iov[0].iov_len = sizeof msg_buf;
    msg.msg_iov = iov;
    msg.msg_iovlen = 1;

    msg.msg_control = ctrl.buf;
    msg.msg_controllen = sizeof ctrl.buf;

    if (recvmsg(connect_fd, &msg, 0) != 1) FATAL_ERROR();

    /* Truncated ancillary data means the descriptor may have been dropped. */
    if (msg.msg_flags & MSG_CTRUNC) FATAL_ERROR();

    cmsg = CMSG_FIRSTHDR(&msg);
    if (cmsg == NULL
        || cmsg->cmsg_level != SOL_SOCKET
        || cmsg->cmsg_type != SCM_RIGHTS
        || cmsg->cmsg_len != CMSG_LEN(sizeof(int)))
        FATAL_ERROR();

    /* memcpy avoids an unaligned / type-punned int read from CMSG_DATA. */
    memcpy(&received_fd, CMSG_DATA(cmsg), sizeof received_fd);

    /* The connection is no longer needed; the received fd stays valid. */
    close(connect_fd);

    return received_fd;
}

/*
 * Connect to the Unix-domain stream socket at `socket_name` and send
 * `fd_to_send` to the peer with SCM_RIGHTS. At least one byte of ordinary
 * data must accompany the control message, so a single 'x' byte is sent.
 *
 * If the receiver is not listening yet (ENOENT / ECONNREFUSED, or the call
 * was interrupted), a waiting message is printed and the connection is
 * retried once per second. Any other connect error, or any socket()/sendmsg()
 * failure, calls FATAL_ERROR(). `fd_to_send` itself is left open.
 */
void send_fd(const char *socket_name, int fd_to_send)
{
    struct sockaddr_un sock_addr;
    struct msghdr msg;
    struct iovec iov[1];
    char payload[1] = { 'x' };
    /* The union gives ctrl.buf the alignment that struct cmsghdr requires. */
    union {
        char buf[CMSG_SPACE(sizeof(int))];
        struct cmsghdr align;
    } ctrl;
    struct cmsghdr *cmsg;
    int sock_fd;
    int path_len;

    /* Reject paths that do not fit instead of silently truncating them. */
    memset(&sock_addr, 0, sizeof sock_addr);
    sock_addr.sun_family = AF_UNIX;
    path_len = snprintf(sock_addr.sun_path, sizeof sock_addr.sun_path,
                        "%s", socket_name);
    if (path_len < 0 || (size_t)path_len >= sizeof sock_addr.sun_path)
        FATAL_ERROR();

    /*
     * POSIX leaves a socket's state unspecified after a failed connect(),
     * so every attempt uses a fresh socket.
     */
    for (;;) {
        int saved_errno;

        sock_fd = socket(PF_UNIX, SOCK_STREAM, 0);
        if (sock_fd < 0) FATAL_ERROR();

        if (connect(sock_fd, (const struct sockaddr *)&sock_addr,
                    sizeof sock_addr) == 0)
            break;

        saved_errno = errno;  /* close() may overwrite errno */
        close(sock_fd);
        if (saved_errno != ENOENT && saved_errno != ECONNREFUSED
            && saved_errno != EINTR)
            FATAL_ERROR();

        printf("Waiting for reciever to listen at %s\n", socket_name);
        sleep(1);
    }

    memset(&msg, 0, sizeof msg);

    iov[0].iov_base = payload;
    iov[0].iov_len = sizeof payload;
    msg.msg_iov = iov;
    msg.msg_iovlen = 1;

    memset(&ctrl, 0, sizeof ctrl);
    msg.msg_control = ctrl.buf;
    msg.msg_controllen = sizeof ctrl.buf;

    cmsg = CMSG_FIRSTHDR(&msg);
    cmsg->cmsg_level = SOL_SOCKET;
    cmsg->cmsg_type = SCM_RIGHTS;
    cmsg->cmsg_len = CMSG_LEN(sizeof(int));
    /* memcpy avoids an unaligned / type-punned int store into CMSG_DATA. */
    memcpy(CMSG_DATA(cmsg), &fd_to_send, sizeof fd_to_send);

    if (sendmsg(sock_fd, &msg, 0) != 1) FATAL_ERROR();

    close(sock_fd);
}





/*
 * Read one byte from descriptor `fd`, getc-style.
 *
 * Returns the byte as an unsigned char value (0..255). On end of file or a
 * read error, it prints a diagnostic to stderr and returns '?'.
 */
int my_fd_getc(int fd)
{
    unsigned char ch;
    ssize_t n = read(fd, &ch, 1);  /* read() returns ssize_t; -1 on error */

    if (n == 1)
        return ch;

    if (n == 0)
        fprintf(stderr, "Error in read: unexpected end of file\n");
    else
        perror("Error in read");
    return '?';
}

/* Scratch file that the sender writes and whose descriptor is passed. */
const char *const TMPFILE = "/tmp/sendfd_tmp_file";
/* Contents written to TMPFILE; both sides check characters from it. */
const char *const TMPTEXT = "test text";

/*
 * Receiver side of the demo: accept a descriptor on SOCKET_NAME and read
 * two characters from it.
 *
 * The sender reads TMPTEXT[0] and TMPTEXT[1] before sending, and the
 * received descriptor refers to the same open file. The reads here are
 * therefore expected to yield TMPTEXT[2] and TMPTEXT[3]. The received
 * descriptor is closed before returning.
 */
void do_receive(void)
{
    int received_fd = receive_fd(SOCKET_NAME);
    int c1 = my_fd_getc(received_fd);
    int c2 = my_fd_getc(received_fd);

    printf("Received received_fd=%d\n", received_fd);

    printf("  char[2] = %c\n", c1);
    printf("  char[3] = %c\n", c2);
    printf("  Above characters should be '%c' and '%c'\n",
           TMPTEXT[2],
           TMPTEXT[3]);

    if (c1 == TMPTEXT[2] && c2 == TMPTEXT[3]) {
        printf("Success! File descriptor recieved\n");
    } else {
        /* Derive the expected characters from TMPTEXT instead of hard-coding them. */
        printf("FAILED! expected characters '%c' and '%c' from file %s\n",
               TMPTEXT[2],
               TMPTEXT[3],
               TMPFILE);
    }

    close(received_fd);
}

// needed for open()
#include <fcntl.h>

/*
 * Sender side of the demo.
 *
 * Writes TMPTEXT to TMPFILE, then reopens the file read-only and consumes
 * its first two characters. It then passes the open descriptor to the
 * receiver listening on SOCKET_NAME, and closes its own copy afterwards.
 * Exits via FATAL_ERROR() if the file cannot be created, written, or
 * reopened.
 */
void do_send(void)
{
    int fd_to_send;
    FILE *fp = fopen(TMPFILE, "w");

    if (fp == NULL) FATAL_ERROR();
    if (fprintf(fp, "%s\n", TMPTEXT) < 0) FATAL_ERROR();
    /* fclose() flushes the buffered text, so a write error may surface here. */
    if (fclose(fp) != 0) FATAL_ERROR();

    /*
     * Occupy two extra descriptors (intentionally never closed). With only
     * stdin/stdout/stderr open beforehand, this makes the sent fd 5 rather
     * than 3.
     */
    open("/dev/null", O_RDONLY);
    open("/dev/null", O_RDONLY);

    fd_to_send = open(TMPFILE, O_RDONLY);
    if (fd_to_send < 0) FATAL_ERROR();

    printf("fd_to_send = %d\n", fd_to_send);
    printf("  char[0] = %c\n", my_fd_getc(fd_to_send));
    printf("  char[1] = %c\n", my_fd_getc(fd_to_send));
    printf("  Above characters should be '%c' and '%c'\n",
           TMPTEXT[0],
           TMPTEXT[1]);

    printf("Sending fd=%d\n", fd_to_send);

    send_fd(SOCKET_NAME, fd_to_send);

    printf("Success! Exiting.\n");
    close(fd_to_send);
}

/*
 * Entry point.
 *
 * With exactly one argument it dispatches to the receiver ("--recieve", or
 * the correctly spelled "--receive") or to the sender ("--send"). Any other
 * arguments print usage. It returns 0 in every case.
 */
int main(int argc, char *argv[])
{
    if (argc == 2 && (strcmp(argv[1], "--recieve") == 0 ||
                      strcmp(argv[1], "--receive") == 0)) {
        do_receive();
    } else if (argc == 2 && strcmp(argv[1], "--send") == 0) {
        do_send();
    } else {
        printf("USAGE: sendfd --recieve\n");
        printf("USAGE: sendfd --send\n");
        printf("  Send or receive a file descriptor.\n");
        printf("  To use: run 'sendfd --recieve' in one process\n");
        printf("          and 'sendfd --send' in another process.\n");
    }
    return 0;
}