/* Star Emcast - Star Emcast Interface
 * Copyright (C) 2001  The Regents of the University of Michigan
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <emcast-protocol.h>
#include <util.h>

#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <fcntl.h>
#include <limits.h>
#include <ctype.h>
#include <string.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/time.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>

/* Fallback for platforms whose headers do not provide socklen_t.
 * NOTE(review): #ifndef only sees preprocessor macros.  Where socklen_t is
 * a typedef (as POSIX specifies for <sys/socket.h>) this test cannot detect
 * it and the #define below still takes effect, turning every later use of
 * socklen_t in this file into size_t -- confirm this is intended. */
#ifndef socklen_t
#define socklen_t size_t
#endif

/* Backlog value passed to listen() in server_new(). */
#define LISTENQ 10

/* Larger of A and B.  Defined only if no header already supplies MAX.
   The selected argument is evaluated twice, so do not pass expressions
   with side effects. */
#ifndef MAX
#  define MAX(A,B) (((A) > (B))?(A):(B))
#endif

/* ******************** */

/* Header that precedes every payload sent over a Star connection.
   sizeof(starpkt) is the size of the header alone; the payload bytes
   follow it directly and are reached through `data'. */
typedef struct _starpkt
{
  uint8_t   id;			/* ID ("*")				*/
  uint8_t   version;		/* Version (0)				*/
  uint16_t  len;		/* Data length (network byte order)	*/
  char	    data[];		/* Data (C99 flexible array member)	*/

} starpkt;

/* Port used by emjoin() when the parsed URL yields port 0. */
#define STAR_PORT 5748


/* ******************** */

/* Largest payload, in bytes, accepted for sending or receiving. */
#define MAX_PKT_SIZE	4098
/* Size of each per-connection buffer: one header plus one maximum-size
   payload. */
#define BUF_SIZE 	(sizeof(starpkt) + MAX_PKT_SIZE)

/* State for one connected socket.  Instances are allocated by
   starconn_new(), linked into the list headed by `conns', and released
   by starconn_delete(). */
typedef struct _starconn
{
  /* Socket descriptor; closed by starconn_delete() */
  int	    sd;

  /* Number of valid bytes currently held in buf_in and buf_out */
  int	    len_in;
  int	    len_out;

  /* Neighbours on the doubly-linked connection list (NULL at the ends) */
  struct _starconn* prev;
  struct _starconn* next;

  /* buf_in:  bytes received but not yet consumed as a packet.
     buf_out: packets queued for sending, drained by starconn_send(). */
  char	    buf_in[BUF_SIZE];
  char	    buf_out[BUF_SIZE];

} starconn;


/* ******************** */

/* Diagnostic helpers.  All three write one line to stderr prefixed with
   `progname'; S must be a string.

   pfail(S): report S together with strerror(errno), then terminate the
             process with EXIT_FAILURE.
   pwarn(S): report S together with strerror(errno) and continue.
   warn(S):  report S alone and continue.

   pfail and pwarn read errno, so invoke them right after the failing
   call, before anything else can overwrite it. */
#define pfail(S) do { \
   fprintf(stderr, "%s: ERROR: %s: %s\n", progname, S, strerror(errno)); \
   exit(EXIT_FAILURE); \
   } while(0)
#define pwarn(S) do { \
   fprintf(stderr, "%s: WARNING: %s: %s\n", progname, S, strerror(errno)); \
   } while(0)
#define warn(S) do { \
   fprintf(stderr, "%s: WARNING: %s\n", progname, S); \
   } while(0)


/* ******************** */

/* Accept one pending connection on the listening socket. */
static void center_accept (void);

/* Callbacks installed in `emfuncs'.  Each returns 0 on success and 1 on
   failure; emgetopt and emsetopt always return 1. */
static int emjoin (char* url);
static int emleave (void);
static int emsend (char* buf, unsigned short len);
static int emgetopt (char* optname, void* optval, unsigned short* optlen);
static int emsetopt (char* optname, void* optval, unsigned short optlen);

/* Socket constructors: return a descriptor, or -1 on failure. */
static int	 server_new (struct sockaddr* sap);
static int       client_new (struct sockaddr* sap);

/* Per-connection bookkeeping and non-blocking I/O. */
static starconn* starconn_new (int sd);
static void 	 starconn_delete (starconn* conn);
static void 	 starconn_recv (starconn* conn);
static void 	 starconn_send (starconn* conn);

/* Callback table passed to emcast_loop_once() from main().
   NOTE(review): the meaning of the leading NULL slot and of the two
   trailing 1 values is defined by Emfuncs in emcast-protocol.h, which is
   not visible here -- check that header before changing them. */
static Emfuncs emfuncs = { NULL,
			   emjoin,
			   emleave,
			   emsend,
			   emgetopt,	
			   emsetopt, 1, 1};


/* ******************** */

/* argv[0]; used as the prefix of every diagnostic printed by
   pfail/pwarn/warn. */
static const char* progname;
/* Descriptor handed by address to emcast_loop_once() and by value to
   emcast_handler_recv().  Presumably filled in by the emcast-protocol
   library -- TODO confirm against emcast-protocol.h. */
static int fd_fifo;

/* Listening socket when this process acts as the center; 0 means there
   is no listening socket. */
static int server_sd = 0;
/* Head of the doubly-linked list of open connections (NULL when empty). */
static starconn* conns = NULL;
/* Connection to the center when this process joined as a client; NULL
   otherwise. */
static starconn* server_conn = NULL;  /* Points into conns */


/* ******************** */


/* Program entry point.
 *
 * Alternates between two phases:
 *   - not joined: call emcast_loop_once() on stdin/stdout repeatedly until
 *     a callback has created either the listening socket (server_sd) or a
 *     connection (conns);
 *   - joined: use select() to multiplex stdin, the listening socket and
 *     every connection, falling back to the first phase once both
 *     server_sd and conns are empty again.
 *
 * The process exits with EXIT_SUCCESS when emcast_loop_once() returns -1
 * and with EXIT_FAILURE when it returns 1 or when select() fails for a
 * reason other than EINTR.  argc is unused.
 */
int
main (int argc, char* argv[])
{
  int rv;

  (void) argc;
  progname = argv[0];

  /* Not-joined loop - block in emcast-protocol */
 not_joined:
  while (1)
    {
      /* Loop once */
      rv = emcast_loop_once (&emfuncs, STDIN_FILENO, STDOUT_FILENO, &fd_fifo);
      if (rv == -1)
        exit (EXIT_SUCCESS);
      else if (rv == 1)
        exit (EXIT_FAILURE);

      /* Check if now connected */
      if (server_sd || conns)
        break;
    }

  /* Joined loop - block here */
  while (1)
    {
      fd_set    read_fdset;
      fd_set    write_fdset;
      int       max_fd = STDIN_FILENO;
      starconn* i;

      /* Clear FD sets */
      FD_ZERO (&read_fdset);
      FD_ZERO (&write_fdset);

      /* Set FD sets: always watch stdin and every connection for input,
         and watch a connection for writability only while it has queued
         output, so select() does not spin on idle sockets. */
      FD_SET (STDIN_FILENO, &read_fdset);
      for (i = conns; i != NULL; i = i->next)
        {
          max_fd = MAX(max_fd, i->sd);
          FD_SET (i->sd, &read_fdset);
          if (i->len_out)
            FD_SET (i->sd, &write_fdset);
        }
      if (server_sd)
        {
          max_fd = MAX(max_fd, server_sd);
          FD_SET (server_sd, &read_fdset);
        }

      /* Select */
      rv = select (max_fd + 1, &read_fdset, &write_fdset, NULL, NULL);
      if (rv == -1)
        {
          /* Interrupted by a signal: the fd sets are not meaningful, so
             rebuild them and wait again instead of aborting. */
          if (errno == EINTR)
            continue;
          pfail ("select() failed");
        }

      /* Check Emcast interface */
      if (FD_ISSET(STDIN_FILENO, &read_fdset))
        {
          rv = emcast_loop_once (&emfuncs, STDIN_FILENO, STDOUT_FILENO,
                                 &fd_fifo);
          if (rv == -1)
            exit (EXIT_SUCCESS);
          else if (rv == 1)
            exit (EXIT_FAILURE);

          if (!server_sd && !conns)
            goto not_joined;
        }

      /* Check for new connection */
      if (server_sd && FD_ISSET(server_sd, &read_fdset))
        center_accept ();

      /* Check for data to read.  starconn_recv() may delete the
         connection it is given, so fetch the successor first. */
      for (i = conns; i != NULL; )
        {
          starconn* conn = i;
          i = i->next;

          if (FD_ISSET(conn->sd, &read_fdset))
            starconn_recv (conn);
        }

      /* Check for data to write.  This deliberately looks at the queue
         length rather than write_fdset so that packets queued during the
         read pass above go out in the same iteration; the send is
         non-blocking, so in the worst case nothing is written.
         starconn_send() may delete the connection, hence the saved
         successor. */
      for (i = conns; i != NULL; )
        {
          starconn* conn = i;
          i = i->next;

          if (conn->len_out)
            starconn_send (conn);
        }
    }

  /* Not reached */
  return 0;
}


/* Accept one pending connection on the listening socket (server_sd) and
 * register it with starconn_new().
 *
 * Transient failures -- no connection pending on the non-blocking socket,
 * the peer having already gone away, or an interrupting signal -- are
 * ignored and the function simply returns.  Any other accept() error is
 * fatal and terminates the process via pfail().
 *
 * The peer address is not needed, so NULL is passed for both address
 * arguments of accept().
 */
static void
center_accept (void)
{
  int clientfd;

  /* Accept the client */
  clientfd = accept (server_sd, NULL, NULL);
  if (clientfd == -1)
    {
      /* Client disappeared? */
      if (errno == EWOULDBLOCK ||
          errno == ECONNABORTED ||
#ifdef EPROTO           /* OpenBSD does not have EPROTO */
          errno == EPROTO ||
#endif
          errno == EINTR)
        return;

      pfail ("accept() failed");
    }

  /* Create new client */
  starconn_new (clientfd);
}




/* ******************** */


/* Join the group named by URL.
 *
 * The host and port are taken from URL (port 0 selects STAR_PORT) and the
 * host is resolved to an IPv4 address.  The function first tries to become
 * the center by binding a listening socket to that address; if that fails
 * it connects to the address as a client instead.
 *
 * On success either server_sd (center) or server_conn/conns (client) is
 * set and 0 is returned.  Returns 1 if the URL cannot be parsed, the port
 * is out of range, the host does not resolve to an IPv4 address, or the
 * client connection cannot be established.
 *
 * NOTE(review): who owns the string emcast_parse_url() stores in
 * `hostname' is not visible from this file; it is not freed here.
 */
static int
emjoin (char* url)
{
  int   rv;
  char* hostname = NULL;
  int   port = 0;

  struct hostent* he;
  struct sockaddr sa;
  struct sockaddr_in* sa_in;

  /* Parse URL */
  rv = emcast_parse_url (url, NULL, &hostname, &port, NULL);
  if (rv != 0 || hostname == NULL)
    return 1;
  if (port == 0)
    port = STAR_PORT;
  /* htons() would silently truncate anything outside 16 bits */
  if (port < 0 || port > 65535)
    return 1;

  /* Lookup address */
  he = gethostbyname (hostname);
  if (!he) return 1;

  /* Only a 4-byte IPv4 address fits in sin_addr; refuse anything else
     rather than letting the copy below run past the field. */
  if (he->h_addrtype != AF_INET ||
      he->h_length != (int) sizeof(sa_in->sin_addr) ||
      he->h_addr_list[0] == NULL)
    return 1;

  memset ((void*) &sa, 0, sizeof(sa));
  sa_in = (struct sockaddr_in*) &sa;
  sa_in->sin_family = AF_INET;
  sa_in->sin_port = htons(port);
  memcpy (&sa_in->sin_addr, he->h_addr_list[0], sizeof(sa_in->sin_addr));

  /* Create server */
  rv = server_new (&sa);
  if (rv >= 0)
    {
      server_sd = rv;
    }

  /* Create client if server creation failed */
  else
    {
      rv = client_new (&sa);
      if (rv < 0) return 1;

      /* Create conn for the server */
      server_conn = starconn_new (rv);
    }

  return 0;
}


static int
emleave (void)
{
  if (server_sd)
    {
      close (server_sd); 
      server_sd = 0;
    }

  while (conns)
    starconn_delete (conns);
  conns = NULL;
  server_conn = NULL;

  return 0;
}


/* Queue LEN bytes from BUF, preceded by a Star header, on every open
 * connection.  Nothing is written to a socket here; the queued bytes are
 * sent later by starconn_send().
 *
 * Returns 1 if we are not joined or LEN exceeds MAX_PKT_SIZE, and 0
 * otherwise.  A connection whose output buffer cannot hold the whole
 * packet is skipped with a message on stderr; that does not affect the
 * return value.
 */
static int
emsend (char* buf, unsigned short len)
{
  starpkt   header;
  ssize_t   total;
  starconn* peer;

  /* Neither center nor client: not a member */
  if (conns == NULL && !server_sd)
    return 1;

  /* Refuse oversized payloads */
  if (len > MAX_PKT_SIZE)
    {
      fprintf (stderr, "Dropping packet (too big)\n");
      return 1;
    }

  /* We got past the membership check, so an empty list here means we are
     the center with nobody connected: nothing to do. */
  if (conns == NULL)
    return 0;

  /* Build the header once; it is identical for every peer */
  header.id = '*';
  header.version = 0;
  header.len = htons(len);
  total = sizeof(starpkt) + len;

  /* Append header and payload to each peer's output queue */
  for (peer = conns; peer != NULL; peer = peer->next)
    {
      char* dst;

      if ((peer->len_out + total) > sizeof(peer->buf_out))
        {
          fprintf (stderr, "Dropping packet (buffer full)\n");
          continue;
        }

      dst = &peer->buf_out[peer->len_out];
      memcpy (dst, &header, sizeof(header));
      memcpy (dst + sizeof(header), buf, len);
      peer->len_out += sizeof(header);
      peer->len_out += len;
    }

  return 0;
}


/* Option query callback.  No options are supported: the arguments are
 * ignored and 1 (failure) is always returned.
 */
static int
emgetopt (char* optname, void* optval, unsigned short* optlen)
{
  (void) optname;
  (void) optval;
  (void) optlen;

  return 1;
}


/* Option update callback.  No options are supported: the arguments are
 * ignored and 1 (failure) is always returned.
 */
static int
emsetopt (char* optname, void* optval, unsigned short optlen)
{
  (void) optname;
  (void) optval;
  (void) optlen;

  return 1;
}



/* ******************** */

/* Create a non-blocking TCP listening socket bound to *SAP.
 *
 * SO_REUSEADDR is set before binding and the socket is switched to
 * non-blocking mode so that accept() cannot stall.  Returns the socket
 * descriptor, or -1 on failure.  Every failure except bind() is reported
 * with pwarn(); a bind() failure is silent.
 */
static int
server_new (struct sockaddr* sap)
{
  const int   reuse = 1;
  const char* failed = NULL;    /* Message for pwarn(); NULL = stay quiet */
  int         mode;
  int         sock;

  /* Create socket */
  sock = socket (AF_INET, SOCK_STREAM, 0);
  if (sock < 0)
    {
      pwarn ("socket() failed");
      return -1;
    }

  /* Set REUSEADDR so we can reuse the port */
  if (setsockopt (sock, SOL_SOCKET, SO_REUSEADDR,
                  (void*) &reuse, sizeof(reuse)) != 0)
    {
      failed = "setsockopt() failed";
      goto bail;
    }

  /* Fetch the current file status flags ... */
  mode = fcntl (sock, F_GETFL, 0);
  if (mode == -1)
    {
      failed = "fcntl() get flags failed";
      goto bail;
    }

  /* ... and add O_NONBLOCK to them (for safe accept()) */
  if (fcntl (sock, F_SETFL, mode | O_NONBLOCK) == -1)
    {
      failed = "fcntl() set flags failed";
      goto bail;
    }

  /* Bind; no warning is printed if this fails */
  if (bind (sock, sap, sizeof(*sap)) != 0)
    goto bail;

  /* Listen */
  if (listen (sock, LISTENQ) != 0)
    {
      failed = "listen() failed";
      goto bail;
    }

  return sock;

  /* Common failure exit.  Nothing runs between the failing call and
     pwarn(), so errno still describes that call. */
 bail:
  if (failed != NULL)
    pwarn (failed);
  close (sock);
  return -1;
}



/* Open a TCP socket and connect it to *SAP.
 *
 * Returns the connected descriptor, or -1 on failure.  A socket() failure
 * is reported with pwarn(); a connect() failure is silent and the socket
 * is closed.
 */
static int
client_new (struct sockaddr* sap)
{
  int sock = socket (AF_INET, SOCK_STREAM, 0);

  if (sock < 0)
    {
      pwarn ("socket() failed");
      return -1;
    }

  /* Connect to server */
  if (connect (sock, sap, sizeof(*sap)) == 0)
    return sock;

  close(sock);
  return -1;
}



static starconn*
starconn_new (int sd)
{
  starconn* conn;

  /* Create conn */
  conn = calloc (1, sizeof(starconn));
  if (conn == NULL) pfail ("calloc() failed");
  conn->sd = sd;

  /* Prepend to list */
  if (conns) conns->prev = conn;
  conn->next = conns;
  conns = conn;

  return conn;

}


/* Unlink CONN from the `conns' list, close its socket and free it.
 * CONN must not be used after this call.  server_conn is not touched
 * here; callers that delete the server connection must reset it.
 */
static void
starconn_delete (starconn* conn)
{
  starconn* before = conn->prev;
  starconn* after  = conn->next;

  /* Splice the neighbours together, moving the list head if CONN was
     the first element. */
  if (conns == conn)
    conns = after;
  if (after != NULL)
    after->prev = before;
  if (before != NULL)
    before->next = after;

  /* Release the socket and the record */
  close (conn->sd);
  free (conn);
}


/* Give up on CONN after a receive error or a protocol violation.
 * Losing the link to the center is fatal (EXIT_FAILURE); any other
 * connection is simply deleted, as starconn_send() does on a send error.
 */
static void
starconn_fail (starconn* conn)
{
  if (conn == server_conn)
    exit (EXIT_FAILURE);

  starconn_delete (conn);
}


/* Read whatever is available on CONN without blocking and process every
 * complete packet now held in its input buffer.
 *
 * Each complete packet is delivered through emcast_handler_recv() and
 * then queued, header included, on every other connection's output
 * buffer.  A connection whose output buffer has no room is skipped with
 * a message on stderr.  Any trailing partial packet stays buffered for
 * the next call.
 *
 * CONN may be deleted by this function (EOF, receive error, or a
 * malformed header), so the caller must not touch it afterwards.
 *
 * The process exits
 *   - with EXIT_SUCCESS when the center closes its connection,
 *   - with EXIT_FAILURE when the center's connection fails or sends a
 *     malformed header, or when emcast_handler_recv() does not accept the
 *     whole payload.
 */
static void
starconn_recv (starconn* conn)
{
  int rv;

  /* Read into buffer */
  rv = recv (conn->sd, &conn->buf_in[conn->len_in],
             sizeof(conn->buf_in) - conn->len_in, MSG_DONTWAIT);

  if (rv == -1)
    {
      if (errno == EWOULDBLOCK || errno == EINTR) /* Nothing there */
        return;

      /* Real error: one broken peer must not take the whole process
         down unless it is our link to the center. */
      starconn_fail (conn);
      return;
    }
  if (rv == 0) /* EOF */
    {
      if (conn == server_conn)
        exit (EXIT_SUCCESS);

      starconn_delete (conn);
      return;
    }

  /* Save bytes */
  conn->len_in += rv;

  /* One recv() can deliver several packets; handle all of them now,
     because select() will not report this socket again until more bytes
     arrive. */
  while (conn->len_in >= (int) sizeof(starpkt))
    {
      starpkt*  pkt = (starpkt*) conn->buf_in;
      int       data_len = ntohs(pkt->len);
      int       pkt_len = sizeof(starpkt) + data_len;
      starconn* node;

      /* Validity checks.  The protocol has no way to resynchronise a
         byte stream after a bad header, so the connection is dropped
         instead of leaving the bad bytes at the front of the buffer. */
      if (pkt->id != '*')
        {
          warn ("Received packet with bad ID");
          starconn_fail (conn);
          return;
        }
      if (pkt->version != 0)
        {
          warn ("Received packet with bad version");
          starconn_fail (conn);
          return;
        }
      if (data_len > MAX_PKT_SIZE)
        {
          warn ("Received packet that's too long");
          starconn_fail (conn);
          return;
        }

      /* Keep reading if we don't have the whole packet yet */
      if (conn->len_in < pkt_len)
        break;

      /* Receive the packet */
      rv = emcast_handler_recv (fd_fifo, pkt->data, data_len, NULL, 0);
      if (rv != data_len)
        exit (EXIT_FAILURE);

      /* Forward to other connections */
      for (node = conns; node != NULL; node = node->next)
        {
          if (node == conn)
            continue;

          if ((node->len_out + pkt_len) > (int) sizeof(node->buf_out))
            {
              fprintf (stderr, "Dropping packet (buffer full)\n");
              continue;
            }

          memcpy (&node->buf_out[node->len_out], pkt, pkt_len);
          node->len_out += pkt_len;
        }

      /* Move the remaining bytes to the front of the buffer */
      memmove (conn->buf_in, &conn->buf_in[pkt_len], conn->len_in - pkt_len);
      conn->len_in -= pkt_len;
    }
}


/* Try to send CONN's queued output without blocking.
 *
 * Whatever the kernel accepts is removed from the front of the output
 * buffer; the rest stays queued.  If the socket is not ready or the call
 * is interrupted, nothing changes.  On any other send error the process
 * exits with EXIT_FAILURE when CONN is the connection to the center;
 * otherwise CONN is deleted and must not be used by the caller again.
 */
static void
starconn_send (starconn* conn)
{
  int sent;
  int remaining;

  /* Nothing queued */
  if (conn->len_out == 0)
    return;

  sent = send (conn->sd, conn->buf_out, conn->len_out, MSG_DONTWAIT);
  if (sent == -1)
    {
      /* Try again on a later pass */
      if (errno == EWOULDBLOCK || errno == EINTR)
        return;

      if (conn == server_conn)
        exit (EXIT_FAILURE);

      starconn_delete (conn);
      return;
    }

  /* Slide the unsent tail down to the start of the buffer */
  remaining = conn->len_out - sent;
  memmove (conn->buf_out, conn->buf_out + sent, remaining);
  conn->len_out = remaining;
}
