/*
 * Copyright (C) 2005 -2012 Michael Tuexen
 * Copyright (C) 2011 -2012 Irene Ruengeler
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of the project nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE PROJECT AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.	IN NO EVENT SHALL THE PROJECT OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include <sys/types.h>
#if defined(__Userspace_os_Windows)
#include <WinSock2.h>
#include <WS2tcpip.h>
#include <stdlib.h>
#include <crtdbg.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/time.h>
#include <unistd.h>
#include <pthread.h>
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <signal.h>
#include <errno.h>
#ifdef LINUX
#include <getopt.h>
#endif
#include <usrsctp.h>

/*
 * File-scope state shared between main() and the usrsctp callbacks
 * (send_cb / receive_cb), which have no user argument to carry it.
 */

/* Sender configuration, filled in by main() from the command line. */
static unsigned long number_of_messages;	/* -n; 0 means "no limit" */
static char *buffer;				/* payload of `length` bytes */
static int length;				/* -l; payload size in bytes */
static struct sockaddr_in remote_addr;		/* peer (client) / last accepted peer (server) */
static int unordered;				/* -u; nonzero sends with SCTP_UNORDERED */
uint32_t optval = 1;				/* scratch value for setsockopt(SCTP_NODELAY) */
struct socket *psock = NULL;			/* listening (server) or connecting (client) socket */

/* Counters and timestamps; `messages` counts sent messages on the client side. */
static struct timeval start_time;
static unsigned long messages;
static unsigned int first_length;
static unsigned long long sum;
static unsigned int use_cb;			/* -c; use the callback API */

#ifndef timersub
/*
 * Fallback for platforms whose <sys/time.h> lacks timersub():
 * *vvp = *tvp - *uvp, borrowing one second when the microsecond
 * difference comes out negative. Seconds and microseconds are computed
 * independently, so vvp may alias tvp or uvp.
 */
#define timersub(tvp, uvp, vvp)                                         \
	do {                                                            \
		(vvp)->tv_usec = (tvp)->tv_usec - (uvp)->tv_usec;       \
		(vvp)->tv_sec = (tvp)->tv_sec - (uvp)->tv_sec;          \
		if ((vvp)->tv_usec < 0) {                               \
			(vvp)->tv_usec += 1000000;                      \
			(vvp)->tv_sec -= 1;                             \
		}                                                       \
	} while (0)
#endif


/*
 * Command-line help. It is printed when an option is unknown or an option
 * is missing its argument. Every option accepted by main()'s parser is
 * listed here, including -p.
 */
char Usage[] =
"Usage: tsctp [options] [address]\n"
"Options:\n"
"        -a             set adaptation layer indication\n"
"        -c             use callback API\n"
"        -E             local UDP encapsulation port (default 9899)\n"
"        -f             fragmentation point\n"
"        -l             size of send/receive buffer\n"
"        -n             number of messages sent (0 means infinite)/received\n"
"        -p             port to listen on / connect to (default 5001)\n"
"        -D             turns Nagle off\n"
"        -T             time to send messages\n"
"        -u             use unordered user messages\n"
"        -U             remote UDP encapsulation port\n"
"        -v             verbose\n"
"        -V             very verbose\n"
;

#define DEFAULT_LENGTH             1024
#define DEFAULT_NUMBER_OF_MESSAGES 1024
#define DEFAULT_PORT               5001
#define BUFFERSIZE                 (1<<16)

static int verbose, very_verbose;

/*
 * Sender progress flag.
 *   0 - sending (set by main() before the transfer starts)
 *   1 - stop requested by stop_sender() on SIGALRM (-T runtime)
 *   2 - send_cb() has queued the final SCTP_EOF message
 * Because it is written from a signal handler, it must be
 * volatile sig_atomic_t (C11 7.14.1.1). This does NOT synchronise it
 * with the usrsctp callback thread; that access remains unsynchronised.
 */
static volatile sig_atomic_t done;

/*
 * SIGALRM handler installed by main() for the -T runtime. It asks the
 * sender to wind down by setting the global `done` flag to 1. The signal
 * number is not used.
 */
void stop_sender(int sig)
{
	(void)sig;
	done = 1;
}

/*
 * Per-association receiver thread, started by the server accept loop in
 * main() when the callback API (-c) is not used.
 *
 * arg points to a malloc()ed cell holding the accepted `struct socket *`.
 * This function takes ownership of that cell, which it frees immediately,
 * and of the socket, which it closes before returning. It calls
 * usrsctp_recvv() until that returns 0 or an error. Along the way it counts
 * receive calls, notifications, payload bytes and complete (MSG_EOR)
 * messages, then prints one line:
 *   first_length, messages, recv_calls, notifications, bytes, seconds, rate
 * where rate = first_length * messages / seconds. Always returns NULL.
 */
static void*
handle_connection(void *arg)
{
	ssize_t n;
	unsigned long long sum = 0;
	char *buf;
#if defined(__Userspace_os_Windows)
	HANDLE tid;
#else
	pthread_t tid;
#endif
	struct socket *conn_sock;
	struct timeval start_time, now, diff_time;
	double seconds;
	unsigned long messages = 0;
	unsigned long recv_calls = 0;
	unsigned long notifications = 0;
	unsigned int first_length;
	int flags;
	struct sockaddr_in addr;
	socklen_t len;
	union sctp_notification *snp;
	struct sctp_paddr_change *spc;
	struct timeval note_time;
	unsigned int infotype = 0;
	struct sctp_recvv_rn rn;
	socklen_t infolen = sizeof(struct sctp_recvv_rn);

	conn_sock = *(struct socket **)arg;
	/* The cell only carried the socket pointer across thread creation. */
	free(arg);
#if defined(__Userspace_os_Windows)
	tid = GetCurrentThread();
#else
	tid = pthread_self();
	/* Nobody joins this thread; let its resources be reclaimed on exit. */
	pthread_detach(tid);
#endif

	buf = malloc(BUFFERSIZE);
	if (buf == NULL) {
		perror("malloc");
		usrsctp_close(conn_sock);
		return NULL;
	}
	flags = 0;
	len = (socklen_t)sizeof(struct sockaddr_in);
	/* Start from a defined state instead of copying uninitialised structs into rn. */
	memset(&rn, 0, sizeof(rn));

	n = usrsctp_recvv(conn_sock, buf, BUFFERSIZE, (struct sockaddr *) &addr, &len, (void *)&rn,
	                  &infolen, &infotype, &flags);

	/* Timing starts once the first read has completed. */
#if defined (__Userspace_os_Windows)
	getwintimeofday(&start_time);
#else
	gettimeofday(&start_time, NULL);
#endif
	first_length = 0;
	while (n > 0) {
		recv_calls++;
		if (flags & MSG_NOTIFICATION) {
			notifications++;
#if defined (__Userspace_os_Windows)
			getwintimeofday(&note_time);
#else
			gettimeofday(&note_time, NULL);
#endif
			printf("notification arrived at %f\n", note_time.tv_sec+(double)note_time.tv_usec/1000000.0);
			/* The notification lives in the received data, i.e. *buf, not in the pointer variable. */
			snp = (union sctp_notification *)buf;
			if (snp->sn_header.sn_type == SCTP_PEER_ADDR_CHANGE) {
				spc = &snp->sn_paddr_change;
				printf("SCTP_PEER_ADDR_CHANGE: state=%d, error=%d\n",spc->spc_state, spc->spc_error);
			}
		} else {
			if (very_verbose) {
				printf("Message received\n");
			}
			sum += n;
			if (flags & MSG_EOR) {
				messages++;
				/* Bytes read up to the first MSG_EOR = size of the first message. */
				if (first_length == 0)
					first_length = (unsigned int)sum;
			}
		}
		/* flags, len, infolen and infotype are in/out: reset them before every call. */
		flags = 0;
		len = (socklen_t)sizeof(struct sockaddr_in);
		infolen = sizeof(struct sctp_recvv_rn);
		infotype = 0;
		n = usrsctp_recvv(conn_sock, (void *) buf, BUFFERSIZE, (struct sockaddr *) &addr, &len, (void *)&rn,
		                  &infolen, &infotype, &flags);
	}
	if (n < 0)
		perror("sctp_recvv");
#if defined (__Userspace_os_Windows)
	getwintimeofday(&now);
#else
	gettimeofday(&now, NULL);
#endif
	timersub(&now, &start_time, &diff_time);
	seconds = diff_time.tv_sec + (double)diff_time.tv_usec/1000000.0;
	printf("%u, %lu, %lu, %lu, %llu, %f, %f\n",
	        first_length, messages, recv_calls, notifications, sum, seconds, (double)first_length * (double)messages / seconds);
	fflush(stdout);
	usrsctp_close(conn_sock);
	free(buf);
	return NULL;
}

static int
send_cb(struct socket *sock, uint32_t sb_free) {
	struct sctp_sndinfo sndinfo;
	/*struct sctp_prinfo prinfo;
	struct sctp_sendv_spa spa;*/

	sndinfo.snd_sid = 0;
	sndinfo.snd_flags = 0;
	sndinfo.snd_ppid = 0;
	sndinfo.snd_context = 0;
	sndinfo.snd_assoc_id = 0;

	/*prinfo.pr_policy = SCTP_PR_SCTP_RTX;
	prinfo.pr_value = 2;
	spa.sendv_sndinfo = sndinfo;
	spa.sendv_prinfo = prinfo;
	spa.sendv_flags = SCTP_SEND_SNDINFO_VALID | SCTP_SEND_PRINFO_VALID;*/

	while (!done && ((number_of_messages == 0) || (messages < (number_of_messages - 1)))) {
		if (very_verbose)
			printf("Sending message number %lu.\n", messages + 1);

		if (usrsctp_sendv(psock, buffer, length, (struct sockaddr *) &remote_addr, 1,
				              (void *)&sndinfo, (socklen_t)sizeof(struct sctp_sndinfo), SCTP_SENDV_SNDINFO,
				              unordered?SCTP_UNORDERED:0) < 0) {
			if (errno != EWOULDBLOCK && errno != EAGAIN) {
				perror("usrsctp_sendmsg (cb) returned < 0");
				exit(1);
			} else {
				/* send until EWOULDBLOCK then exit callback. */
				return 1;
			}
		}
		messages++;
	}
	if ((done == 1) || (messages == (number_of_messages - 1))) {
		if (very_verbose)
			printf("Sending final message number %lu.\n", messages + 1);

		if (usrsctp_sendv(psock, buffer, length, (struct sockaddr *) &remote_addr, 1,
				              (void *)&sndinfo, (socklen_t)sizeof(struct sctp_sndinfo), SCTP_SENDV_SNDINFO,
				              unordered?(SCTP_UNORDERED|SCTP_EOF):SCTP_EOF) < 0) {
			if (errno != EWOULDBLOCK && errno != EAGAIN) {
				perror("usrsctp_sendmsg (cb) returned < 0");
				exit(1);
			} else {
				/* send until EWOULDBLOCK then exit callback. */
				return 1;
			}
		}
		messages++;
		done = 2;
	}

	return 1;
}

/*
 * Receive callback used in -c mode.
 *
 * For each delivered chunk it adds datalen to the byte total and bumps the
 * message count. The first chunk also sets first_length and start_time.
 * The chunk's buffer is released here with free(). A NULL data pointer is
 * handled differently (presumably usrsctp's end-of-association signal):
 * the line
 *   first_length, messages, bytes, seconds, rate
 * is printed, the socket is closed and the counters are reset.
 * Always returns 1.
 */
static int
receive_cb(struct socket *sock, union sctp_sockstore addr, void *data,
           size_t datalen, struct sctp_rcvinfo rcv, int flags, void *ulp_info)
{
	struct timeval end, elapsed;
	double secs, rate;

	if (data != NULL) {
		if (!first_length) {
			first_length = (unsigned int)datalen;
#if defined (__Userspace_os_Windows)
			getwintimeofday(&start_time);
#else
			gettimeofday(&start_time, NULL);
#endif
		}
		sum += datalen;
		messages++;
		free(data);
		return 1;
	}

#if defined (__Userspace_os_Windows)
	getwintimeofday(&end);
#else
	gettimeofday(&end, NULL);
#endif
	timersub(&end, &start_time, &elapsed);
	secs = elapsed.tv_sec + (double)elapsed.tv_usec/1000000.0;
	rate = (double)first_length * (double)messages / secs;
	printf("%u, %lu, %llu, %f, %f\n",
		first_length, messages, sum, secs, rate);
	usrsctp_close(sock);
	first_length = 0;
	sum = 0;
	messages = 0;
	return 1;
}


int main(int argc, char **argv)
{
#if !defined (__Userspace_os_Windows)
	char c;
#endif
	socklen_t addr_len;
	struct sockaddr_in local_addr;
	struct timeval start_time, now, diff_time;
	int client;
	uint16_t local_port, remote_port, port, local_udp_port, remote_udp_port;
	double seconds;
	double throughput;
	int nodelay = 0;
	struct sctp_assoc_value av;
	struct sctp_udpencaps encaps;
	struct sctp_sndinfo sndinfo;
#if defined(__Userspace_os_Windows)
	HANDLE tid;
#else
	pthread_t tid;
#endif
	int fragpoint = 0;
	unsigned int runtime = 0;
	struct sctp_setadaptation ind = {0};
#if defined (__Userspace_os_Windows)
	char *opt;
	int optind;
#endif
	unordered = 0;
	/*struct sctp_prinfo prinfo;
	struct sctp_sendv_spa spa;*/

	length = DEFAULT_LENGTH;
	number_of_messages = DEFAULT_NUMBER_OF_MESSAGES;
	port = DEFAULT_PORT;
	remote_udp_port = 0;
	local_udp_port = 9899;
	verbose = 0;
	very_verbose = 0;

	memset((void *) &remote_addr, 0, sizeof(struct sockaddr_in));
	memset((void *) &local_addr, 0, sizeof(struct sockaddr_in));

#if !defined (__Userspace_os_Windows)
	while ((c = getopt(argc, argv, "a:cp:l:E:f:n:T:uU:vVD")) != -1)
		switch(c) {
			case 'a':
				ind.ssb_adaptation_ind = atoi(optarg);
				break;
			case 'c':
				use_cb = 1;
				break;
			case 'l':
				length = atoi(optarg);
				break;
			case 'n':
				number_of_messages = atoi(optarg);
				break;
			case 'p':
				port = atoi(optarg);
				break;
			case 'E':
				local_udp_port = atoi(optarg);
				break;
			case 'f':
				fragpoint = atoi(optarg);
				break;
			case 'T':
				runtime = atoi(optarg);
				number_of_messages = 0;
				break;
			case 'u':
				unordered = 1;
				break;
			case 'U':
				remote_udp_port = atoi(optarg);
				break;
			case 'v':
				verbose = 1;
				break;
			case 'V':
				verbose = 1;
				very_verbose = 1;
				break;
			case 'D':
				nodelay = 1;
				break;
			default:
				fprintf(stderr, "%s", Usage);
				exit(1);
		}
#else
	for (optind = 1; optind < argc; optind++) {
		if (argv[optind][0] == '-') {
			switch (argv[optind][1]) {
				case 'a':
					if (++optind >= argc) {
						printf("%s", Usage);
						exit(1);
					}
					opt = argv[optind];
					ind.ssb_adaptation_ind = atoi(opt);
					break;
				case 'c':
					use_cb = 1;
					break;
				case 'l':
					if (++optind >= argc) {
						printf("%s", Usage);
						exit(1);
					}
					opt = argv[optind];
					length = atoi(opt);
					break;
				case 'p':
					if (++optind >= argc) {
						printf("%s", Usage);
						exit(1);
					}
					opt = argv[optind];
					port = atoi(opt);
					break;
				case 'n':
					if (++optind >= argc) {
						printf("%s", Usage);
						exit(1);
					}
					opt = argv[optind];
					number_of_messages = atoi(opt);
					break;
				case 'f':
					if (++optind >= argc) {
						printf("%s", Usage);
						exit(1);
					}
					opt = argv[optind];
					fragpoint = atoi(opt);
					break;
				case 'U':
					if (++optind >= argc) {
						printf("%s", Usage);
						exit(1);
					}
					opt = argv[optind];
					remote_udp_port = atoi(opt);
					break;
				case 'E':
					if (++optind >= argc) {
						printf("%s", Usage);
						exit(1);
					}
					opt = argv[optind];
					local_udp_port = atoi(opt);
					break;
				case 'T':
					if (++optind >= argc) {
						printf("%s", Usage);
						exit(1);
					}
					opt = argv[optind];
					runtime = atoi(opt);
					number_of_messages = 0;
					break;
				case 'u':
					unordered = 1;
					break;
				case 'v':
					verbose = 1;
					break;
				case 'V':
					verbose = 1;
					very_verbose = 1;
					break;
				case 'D':
					nodelay = 1;
					break;
				default:
					printf("%s", Usage);
					exit(1);
			}
		} else {
			break;
		}
	}
#endif
	if (optind == argc) {
		client = 0;
		local_port = port;
		remote_port = 0;
	} else {
		client = 1;
		local_port = 0;
		remote_port = port;
	}
	local_addr.sin_family = AF_INET;
#ifdef HAVE_SIN_LEN
	local_addr.sin_len = sizeof(struct sockaddr_in);
#endif
	local_addr.sin_port = htons(local_port);
	local_addr.sin_addr.s_addr = htonl(INADDR_ANY);

	usrsctp_init(local_udp_port);
	usrsctp_sysctl_set_sctp_debug_on(0);
	usrsctp_sysctl_set_sctp_blackhole(2);

	if (client) {
		if (use_cb) {
			if (!(psock = usrsctp_socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP, receive_cb, send_cb, length, NULL)) ){
				printf("user_socket() returned NULL\n");
				exit(1);
			}
		} else {
			if (!(psock = usrsctp_socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP, NULL, NULL, 0, NULL)) ){
				printf("user_socket() returned NULL\n");
				exit(1);
			}
		}
	} else {
		if (use_cb) {
			if (!(psock = usrsctp_socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP, receive_cb, NULL, 0, NULL)) ){
				printf("user_socket() returned NULL\n");
				exit(1);
			}
		} else {
			if (!(psock = usrsctp_socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP, NULL, NULL, 0, NULL)) ){
				printf("user_socket() returned NULL\n");
				exit(1);
			}
		}
	}

	if (usrsctp_bind(psock, (struct sockaddr *)&local_addr, sizeof(struct sockaddr_in)) == -1) {
		printf("usrsctp_bind failed.\n");
		exit(1);
	}

	if (usrsctp_setsockopt(psock, IPPROTO_SCTP, SCTP_ADAPTATION_LAYER, (const void*)&ind, (socklen_t)sizeof(struct sctp_setadaptation)) < 0) {
		perror("setsockopt");
	}

	if (!client) {
		if (usrsctp_listen(psock, 1) < 0) {
			printf("usrsctp_listen failed.\n");
			exit(1);
		}

		while (1) {
			memset(&remote_addr, 0, sizeof(struct sockaddr_in));
			addr_len = sizeof(struct sockaddr_in);
			if (use_cb) {
				struct socket *conn_sock;

				if ((conn_sock = usrsctp_accept(psock, (struct sockaddr *) &remote_addr, &addr_len))== NULL) {
					printf("usrsctp_accept failed.  exiting...\n");
					continue;
				}
			} else {
				struct socket **conn_sock;

				conn_sock = (struct socket **)malloc(sizeof(struct socket *));
				if ((*conn_sock = usrsctp_accept(psock, (struct sockaddr *) &remote_addr, &addr_len))== NULL) {
					printf("usrsctp_accept failed.  exiting...\n");
					continue;
				}
#if defined(__Userspace_os_Windows)
				tid = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE)&handle_connection, (void *)conn_sock, 0, NULL);
#else
				pthread_create(&tid, NULL, &handle_connection, (void *)conn_sock);
#endif
			}
			if (verbose)
				printf("Connection accepted from %s:%d\n", inet_ntoa(remote_addr.sin_addr), ntohs(remote_addr.sin_port));
		}
		usrsctp_close(psock);
	} else {
		memset(&encaps, 0, sizeof(struct sctp_udpencaps));
		encaps.sue_address.ss_family = AF_INET;
		encaps.sue_port = htons(remote_udp_port);
		if (usrsctp_setsockopt(psock, IPPROTO_SCTP, SCTP_REMOTE_UDP_ENCAPS_PORT, (const void*)&encaps, (socklen_t)sizeof(struct sctp_udpencaps)) < 0) {
			perror("setsockopt");
		}

		remote_addr.sin_family = AF_INET;
#ifdef HAVE_SIN_LEN
		remote_addr.sin_len = sizeof(struct sockaddr_in);
#endif
		remote_addr.sin_addr.s_addr = inet_addr(argv[optind]);
		remote_addr.sin_port = htons(remote_port);

		/* TODO fragpoint stuff */
		if (nodelay == 1) {
			optval = 1;
		} else {
			optval = 0;
		}
		usrsctp_setsockopt(psock, IPPROTO_SCTP, SCTP_NODELAY, &optval, sizeof(int));

		if (fragpoint) {
			av.assoc_id = 0;
			av.assoc_value = fragpoint;
			if (usrsctp_setsockopt(psock, IPPROTO_SCTP, SCTP_MAXSEG, &av, sizeof(struct sctp_assoc_value)) < 0)
				perror("setsockopt: SCTP_MAXSEG");
		}

		if (usrsctp_connect(psock, (struct sockaddr *) &remote_addr, sizeof(struct sockaddr_in)) == -1 ) {
			printf("usrsctpconnect failed.  exiting...\n");
			exit(1);
		}

		buffer = malloc(length);
		memset(buffer, 'b', length);
#if defined (__Userspace_os_Windows)
		getwintimeofday(&start_time);
#else
		gettimeofday(&start_time, NULL);
#endif
		if (verbose) {
			printf("Start sending %ld messages...\n", (long)number_of_messages);
			fflush(stdout);
		}

		done = 0;

		if (runtime > 0) {
#if !defined (__Userspace_os_Windows)
			signal(SIGALRM, stop_sender);
			alarm(runtime);
#else
			printf("You cannot set the runtime in Windows yet\n");
			exit(-1);
#endif
		}

		messages = 0;

		sndinfo.snd_sid = 0;
		sndinfo.snd_flags = 0;
		sndinfo.snd_ppid = 0;
		sndinfo.snd_context = 0;
		sndinfo.snd_assoc_id = 0;

	/*prinfo.pr_policy = SCTP_PR_SCTP_RTX;
		prinfo.pr_value = 2;
		spa.sendv_sndinfo = sndinfo;
		spa.sendv_prinfo = prinfo;
		spa.sendv_flags = SCTP_SEND_SNDINFO_VALID | SCTP_SEND_PRINFO_VALID;*/

		if (use_cb) {
			if (very_verbose)
				printf("Sending message number %lu.\n", messages);

				if (usrsctp_sendv(psock, buffer, length, (struct sockaddr *) &remote_addr, 1,
				                  (void *)&sndinfo, (socklen_t)sizeof(struct sctp_sndinfo), SCTP_SENDV_SNDINFO,
				                  unordered?SCTP_UNORDERED:0) < 0) {
				perror("usrctp_sendv returned < 0");
				exit(1);
			}
			messages++;
			while (!done && (messages < (number_of_messages - 1))) {
#if defined (__Userspace_os_Windows)
				Sleep(1000);
#else
				sleep(1);
#endif
			}
		} else {
			while (!done && ((number_of_messages == 0) || (messages < (number_of_messages - 1)))) {
				if (very_verbose)
					printf("Sending message number %lu.\n", messages + 1);

				if (usrsctp_sendv(psock, buffer, length, (struct sockaddr *) &remote_addr, 1,
				                  (void *)&sndinfo, (socklen_t)sizeof(struct sctp_sndinfo), SCTP_SENDV_SNDINFO,
				                  unordered?SCTP_UNORDERED:0) < 0) {
					perror("usrsctp_sendv returned < 0");
					exit(1);
				}
				messages++;
			}
			if (very_verbose)
				printf("Sending message number %lu.\n", messages + 1);

			if (usrsctp_sendv(psock, buffer, length, (struct sockaddr *) &remote_addr, 1,
			                  (void *)&sndinfo, (socklen_t)sizeof(struct sctp_sndinfo), SCTP_SENDV_SNDINFO,
			                  unordered?SCTP_UNORDERED:0) < 0) {
				perror("final usrsctp_sendv returned\n");
				exit(1);
			}
			messages++;
		}
		free (buffer);
		if (verbose)
			printf("done.\n");

		usrsctp_close(psock);
#if defined (__Userspace_os_Windows)
		getwintimeofday(&now);
#else
		gettimeofday(&now, NULL);
#endif
		timersub(&now, &start_time, &diff_time);
		seconds = diff_time.tv_sec + (double)diff_time.tv_usec/1000000;
		printf("%s of %ld messages of length %u took %f seconds.\n",
		       "Sending", messages, length, seconds);
		throughput = (double)messages * (double)length / seconds;
		printf("Throughput was %f Byte/sec.\n", throughput);
	}

	while (usrsctp_finish() != 0) {
#if defined (__Userspace_os_Windows)
		Sleep(1000);
#else
		sleep(1);
#endif
	}
	return 0;
}
