/***************************************************************************
 $RCSfile: socket.cpp,v $
                             -------------------
    cvs         : $Id: socket.cpp,v 1.14 2003/02/14 16:11:53 aquamaniac Exp $
    begin       : Tue Aug 28 2001
    copyright   : (C) 2001 by Martin Preuss
    email       : martin@aquamaniac.de
*/

/***************************************************************************
 *                                                                         *
 *   This library is free software; you can redistribute it and/or         *
 *   modify it under the terms of the GNU Lesser General Public            *
 *   License as published by the Free Software Foundation; either          *
 *   version 2.1 of the License, or (at your option) any later version.    *
 *                                                                         *
 *   This library is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU     *
 *   Lesser General Public License for more details.                       *
 *                                                                         *
 *   You should have received a copy of the GNU Lesser General Public      *
 *   License along with this library; if not, write to the Free Software   *
 *   Foundation, Inc., 59 Temple Place, Suite 330, Boston,                 *
 *   MA  02111-1307  USA                                                   *
 *                                                                         *
 ***************************************************************************/


#include <errno.h>
#include <unistd.h>

#include "error.h"
#include "socket.h"

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

namespace HBCI {



/* AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 * SocketSet
 * AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 */


/**
 * Creates an empty set: no descriptor is marked and the recorded highest
 * descriptor is 0.
 */
SocketSet::SocketSet()
{
    _highest = 0;
    FD_ZERO(&_set);
}


/**
 * Nothing to release here: the destructor body is intentionally empty.
 */
SocketSet::~SocketSet()
{
}


/**
 * Marks the descriptor of @p s in this set and raises the recorded highest
 * descriptor if necessary.
 *
 * @param s socket to add; must be non-null and hold a descriptor
 * @throws Error if @p s is null, if its descriptor is -1, or if the
 *         descriptor cannot be represented in an fd_set
 */
void SocketSet::addSocket(Socket *s){
    if (!s)
        throw Error("SocketSet::addSocket",
                      ERROR_LEVEL_NORMAL,
                      0,
                      ERROR_ADVISE_DONTKNOW,
                      "no socket");
    if (s->_sock==-1)
        throw Error("SocketSet::addSocket",
                      ERROR_LEVEL_NORMAL,
                      0,
                      ERROR_ADVISE_DONTKNOW,
                      "socket not connected");
    // FD_SET is undefined for descriptors outside [0, FD_SETSIZE); reject
    // them instead of corrupting memory next to the fd_set.
    if (s->_sock<0 || s->_sock>=static_cast<int>(FD_SETSIZE))
        throw Error("SocketSet::addSocket",
                      ERROR_LEVEL_NORMAL,
                      0,
                      ERROR_ADVISE_DONTKNOW,
                      "socket descriptor out of range");
    if (_highest<s->_sock)
        _highest=s->_sock;
    FD_SET(s->_sock,&_set);
}


/**
 * Clears the descriptor of @p s from this set. The recorded highest
 * descriptor is left untouched.
 *
 * @param s socket to remove; must be non-null and hold a descriptor
 * @throws Error if @p s is null or its descriptor is -1
 */
void SocketSet::removeSocket(Socket *s)
{
    const char *problem = 0;

    if (!s)
        problem = "no socket";
    else if (s->_sock == -1)
        problem = "socket not connected";

    if (problem)
        throw Error("SocketSet::removeSocket",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    problem);

    FD_CLR(s->_sock, &_set);
}


/**
 * Tells whether the descriptor of @p s is currently marked in this set.
 *
 * @param s socket to look up; must be non-null and hold a descriptor
 * @return true if the descriptor is marked, false otherwise
 * @throws Error if @p s is null or its descriptor is -1
 */
bool SocketSet::hasSocket(Socket *s)
{
    const char *problem = 0;

    if (!s)
        problem = "no socket";
    else if (s->_sock == -1)
        problem = "socket not connected";

    if (problem)
        throw Error("SocketSet::hasSocket",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    problem);

    return FD_ISSET(s->_sock, &_set) != 0;
}


/**
 * Empties the set: unmarks every descriptor and resets the recorded
 * highest descriptor to 0.
 */
void SocketSet::clear()
{
    _highest = 0;
    FD_ZERO(&_set);
}



/* AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 * Socket
 * AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 */


/**
 * Opens a new PF_INET socket of the requested kind.
 *
 * @param stype SOCKET_TYPE_TCP for a stream socket, SOCKET_TYPE_UDP for a
 *              datagram socket
 * @throws Error if @p stype is neither of the two values, or if the
 *         socket() call fails
 */
Socket::Socket(int stype)
{
    int kind;

    if (stype == SOCKET_TYPE_TCP)
        kind = SOCK_STREAM;
    else if (stype == SOCKET_TYPE_UDP)
        kind = SOCK_DGRAM;
    else
        throw Error("Socket::Socket",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    "socket type");

    _sock = socket(PF_INET, kind, 0);
    if (_sock == -1)
        throw Error("Socket::Socket",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    strerror(errno),
                    "socket type");
}


/**
 * Creates a socket object that holds no descriptor yet (descriptor is -1).
 */
Socket::Socket()
{
    this->_sock = -1;
}


/**
 * Copies the descriptor number of @p s; the descriptor itself is not
 * duplicated.
 *
 * NOTE(review): afterwards both objects refer to the same descriptor, and
 * the destructor of each calls close() — confirm callers account for that.
 */
Socket::Socket(const Socket &s)
{
    this->_sock = s._sock;
}


/**
 * Closes the descriptor, if any. The Error returned by close() is
 * deliberately discarded.
 */
Socket::~Socket()
{
    (void)close();
}


Error Socket::close(){
    int rv;

    if (_sock==-1)
        return Error("Socket::close()",
                         ERROR_LEVEL_NORMAL,
                         0,
                         ERROR_ADVISE_DONTKNOW,
                         "socket not open");
    rv=::close(_sock);
    _sock=-1;
    if (rv==-1)
        return Error("Socket::close()",
                         ERROR_LEVEL_NORMAL,
                         0,
                         ERROR_ADVISE_DONTKNOW,
                         strerror(errno),
                         "error on close");
    return Error();
}


/**
 * Takes over the descriptor number of @p s without duplicating it.
 *
 * NOTE(review): a descriptor previously held by this object is overwritten
 * without being closed, and both objects share one descriptor afterwards —
 * confirm this is what callers expect.
 */
void Socket::operator=(const Socket &s)
{
    this->_sock = s._sock;
}


bool Socket::_waitSocketRead(long timeout) {
    SocketSet s;
    int rv;

    s.addSocket(this);
    rv=select(&s,0,0,timeout);
    if (rv<1)
        return false;
    return true;
}


bool Socket::_waitSocketWrite(long timeout) {
    SocketSet s;
    int rv;

    s.addSocket(this);
    rv=select(0,&s,0,timeout);
    if (rv<1)
        return false;
    return true;
}



/**
 * Thin wrapper around ::select() operating on SocketSet objects.
 *
 * @param rs      set to check for readability, or 0 for none
 * @param ws      set to check for writability, or 0 for none
 * @param xs      set to check for exceptional conditions, or 0 for none
 * @param timeout time to wait in milliseconds; a negative value blocks
 *                until a descriptor becomes ready, 0 polls
 * @return the value returned by ::select() (0 means the timeout expired)
 * @throws Error with HBCI_ERROR_CODE_SOCKET_ERROR_INTERRUPT if the call was
 *         interrupted (EINTR), or with HBCI_ERROR_CODE_SOCKET_ERROR_UNKNOWN
 *         on any other failure
 */
int Socket::select(SocketSet *rs,
                     SocketSet *ws,
                     SocketSet *xs,
                     long timeout){
    int h,h1,h2,h3;
    fd_set *s1,*s2,*s3;
    int rv;
    struct timeval tv;

    s1=s2=s3=0;
    h1=h2=h3=0;

    if (rs) {
        h1=rs->_highest;
        s1=&rs->_set;
    }
    if (ws) {
        h2=ws->_highest;
        s2=&ws->_set;
    }
    if (xs) {
        h3=xs->_highest;
        s3=&xs->_set;
    }
    // ::select() wants the highest descriptor of all sets, plus one
    h=(h1>h2)?h1:h2;
    h=(h>h3)?h:h3;
    if (timeout<0)
        // wait for ever
        rv=::select(h+1,s1,s2,s3,0);
    else {
        // Split the millisecond timeout into seconds and microseconds:
        // tv_usec must stay below one million, otherwise ::select() may
        // reject the call (EINVAL) or the value may overflow.
        tv.tv_sec=timeout/1000;
        tv.tv_usec=(timeout%1000)*1000;
        rv=::select(h+1,s1,s2,s3,&tv);
    }
    if (rv<0) {
        if (errno==EINTR)
            throw Error("Socket::select",
                        ERROR_LEVEL_NORMAL,
                        HBCI_ERROR_CODE_SOCKET_ERROR_INTERRUPT,
                        ERROR_ADVISE_RETRY,
                        "interrupted",
                        "error on select");
        throw Error("Socket::select",
                    ERROR_LEVEL_NORMAL,
                    HBCI_ERROR_CODE_SOCKET_ERROR_UNKNOWN,
                    ERROR_ADVISE_DONTKNOW,
                    strerror(errno),
                    "error on select");
    }
    return rv;
}


/**
 * Switches the socket to non-blocking mode and initiates a connection to
 * the given address and port. A connection that is still in progress
 * (EINPROGRESS) counts as success here.
 *
 * @param addr address to connect to (its port field is overridden)
 * @param port port to connect to, in host byte order
 * @return a default-constructed Error if the connect was started or
 *         completed; an Error describing the failure otherwise
 */
Error Socket::startConnect(const InetAddress &addr,
                           unsigned short port){
    int fl;
    sockaddr_in inaddr;

    inaddr=addr._inaddr;
    inaddr.sin_port=htons(port);

    // get current socket flags
    fl=fcntl(_sock,F_GETFL);
    if (fl==-1)
        return Error("Socket::startConnect",
                     ERROR_LEVEL_NORMAL,
                     0,
                     ERROR_ADVISE_DONTKNOW,
                     strerror(errno),
                     "error on fcntl(1)");

    // set nonblocking
    if (-1==fcntl(_sock,F_SETFL,fl | O_NONBLOCK))
        return Error("Socket::startConnect",
                     ERROR_LEVEL_NORMAL,
                     0,
                     ERROR_ADVISE_DONTKNOW,
                     strerror(errno),
                     "error on fcntl(2)");

    // try to connect; EINPROGRESS is the expected outcome on a
    // non-blocking socket and is not an error
    if (::connect(_sock,(sockaddr*)&inaddr,sizeof(inaddr))==-1 &&
        errno!=EINPROGRESS) {
        // remember the reason now: abortConnect() issues further system
        // calls that may overwrite errno
        const int connectErrno=errno;

        abortConnect();
        return Error("Socket::startConnect",
                     ERROR_LEVEL_NORMAL,
                     0,
                     ERROR_ADVISE_DONTKNOW,
                     strerror(connectErrno),
                     "error on connect");
    }

    return Error();
}


/**
 * Waits until the socket is reported writable, then checks the socket's
 * pending error (SO_ERROR) and finally switches the socket back to
 * blocking mode.
 *
 * @param timeout passed unchanged to select()
 * @return a default-constructed Error on success; an Error with code
 *         HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT if select() timed out, with
 *         code HBCI_ERROR_CODE_SOCKET_NO_CONNECT if SO_ERROR is non-zero,
 *         or another Error describing the failing call
 */
Error Socket::checkConnect(long timeout){
    SocketSet s;
    int fl;
    socklen_t rvs;
    int rv;
    int savedErrno;

    try {
        // wait for connection or timeout
        s.addSocket(this);
        if (0==select(0,&s,0,timeout))
            return Error("Socket::checkConnect",
                         ERROR_LEVEL_NORMAL,
                         HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT,
                         ERROR_ADVISE_DONTKNOW,
                         "select timed out");
    }
    catch (HBCI::Error &xerr) {
        abortConnect();
        return HBCI::Error("Socket::checkConnect",xerr);
    }

    // socket state changed, check for connection
    rvs=sizeof(rv);
    if (-1==getsockopt(_sock,SOL_SOCKET,SO_ERROR,&rv,&rvs)) {
        // keep errno: abortConnect() makes system calls of its own
        savedErrno=errno;
        abortConnect();
        return Error("Socket::checkConnect",
                     ERROR_LEVEL_NORMAL,
                     0,
                     ERROR_ADVISE_DONTKNOW,
                     strerror(savedErrno),
                     "error on getsockopt");
    }
    if (rv!=0) {
        // rv holds the error code reported through SO_ERROR
        abortConnect();
        return Error("Socket::checkConnect",
                     ERROR_LEVEL_NORMAL,
                     HBCI_ERROR_CODE_SOCKET_NO_CONNECT,
                     ERROR_ADVISE_DONTKNOW,
                     strerror(rv),
                     "error on connect");
    }

    fl=fcntl(_sock,F_GETFL);
    if (fl==-1)
        return Error("Socket::checkConnect",
                     ERROR_LEVEL_NORMAL,
                     0,
                     ERROR_ADVISE_DONTKNOW,
                     strerror(errno),
                     "error on fcntl(1)");

    // set blocking again
    if (-1==fcntl(_sock,F_SETFL,fl & ~O_NONBLOCK)) {
        // keep errno: abortConnect() makes system calls of its own
        savedErrno=errno;
        abortConnect();
        return Error("Socket::checkConnect",
                     ERROR_LEVEL_NORMAL,
                     0,
                     ERROR_ADVISE_DONTKNOW,
                     strerror(savedErrno),
                     "error on fcntl(2)");
    }
    return HBCI::Error();
}


/**
 * Gives up a connection attempt: puts the socket back into blocking mode
 * (when its flags can be read) and shuts it down in both directions.
 * Failures of the individual calls are ignored; the descriptor is not
 * closed here.
 */
void Socket::abortConnect()
{
    const int flags = fcntl(_sock, F_GETFL);

    if (flags != -1) {
        // clear O_NONBLOCK, keep every other flag
        fcntl(_sock, F_SETFL, flags & ~O_NONBLOCK);
    }
    ::shutdown(_sock, SHUT_RDWR);
}


/**
 * Binds the socket to the given local address and port.
 *
 * @param addr address to bind to (its port field is overridden)
 * @param port port to bind to, in host byte order
 * @throws Error if ::bind() fails
 */
void Socket::bind(const InetAddress &addr, unsigned short port)
{
    sockaddr_in local = addr._inaddr;

    local.sin_port = htons(port);
    if (::bind(_sock, reinterpret_cast<sockaddr*>(&local), sizeof(local)) != 0)
        throw Error("Socket::bind",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    strerror(errno),
                    "error on bind");
}


/**
 * Puts the socket into listening state.
 *
 * @param backlog passed unchanged to ::listen()
 * @throws Error if ::listen() fails
 */
void Socket::listen(int backlog)
{
    const int result = ::listen(_sock, backlog);

    if (result != 0)
        throw Error("Socket::listen",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    strerror(errno),
                    "error on listen");
}


/**
 * Waits for this socket to become readable and accepts one incoming
 * connection.
 *
 * @param tm passed unchanged to select() as the timeout
 * @return a heap-allocated Socket holding the accepted descriptor; the
 *         object is created with new and handed to the caller
 * @throws Error with HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT if select()
 *         timed out, an Error if ::accept() fails, and whatever
 *         SocketSet::addSocket() or select() throw
 */
Socket *Socket::accept(long tm)
{
    SocketSet readable;

    readable.addSocket(this);
    if (select(&readable, 0, 0, tm) == 0)
        throw Error("Socket::accept",
                    ERROR_LEVEL_NORMAL,
                    HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT,
                    ERROR_ADVISE_DONTKNOW,
                    "select timed out");

    sockaddr_in peer;
    socklen_t peerLen = sizeof(peer);
    const int fd = ::accept(_sock, reinterpret_cast<sockaddr*>(&peer), &peerLen);

    if (fd == -1)
        throw Error("Socket::accept",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    strerror(errno),
                    "error on accept");

    Socket *accepted = new Socket();
    accepted->_sock = fd;
    return accepted;
}


/**
 * Receives up to @p size bytes from the socket with a single recv() call.
 *
 * @param data    receives the bytes read; it is emptied first and stays
 *                empty when recv() returns 0 or fails
 * @param size    maximum number of bytes to read
 * @param timeout 0 skips the readiness wait; any other value is handed to
 *                _waitSocketRead() before reading
 * @return a default-constructed Error on success; an Error with code
 *         HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT if the wait timed out, or an
 *         Error describing the failure
 */
Error Socket::readData(string &data, unsigned int size, long timeout){
    char *buffer;
    int i;
    int recvErrno;

    if (_sock==-1)
        return Error("Socket::readData",
                     ERROR_LEVEL_NORMAL,
                     0,
                     ERROR_ADVISE_DONTKNOW,
                     "no socket");
    data.erase();

    // handle timeout
    if (timeout) {
        try {
            if (!_waitSocketRead(timeout))
                return Error("Socket::readData",
                             ERROR_LEVEL_NORMAL,
                             HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT,
                             ERROR_ADVISE_DONTKNOW,
                             "_waitSocketRead timed out");
        }
        catch (Error &xerr) {
            return Error("Socket::readData",xerr);
        }
    }

    buffer=new char[size];
    i=recv(_sock,buffer,size,0);
    // remember errno before anything else can overwrite it
    recvErrno=errno;
    if (i>0)
        data.assign(buffer,i);
    // the buffer was allocated with new[], so it must be freed with delete[]
    delete[] buffer;
    // check for error
    if (i<0)
        return Error("Socket::readData",
                     ERROR_LEVEL_NORMAL,
                     0,
                     ERROR_ADVISE_DONTKNOW,
                     strerror(recvErrno),
                     "error on recv");
    return Error();
}


/**
 * Sends the whole content of @p data, calling send() repeatedly until
 * every byte has been handed over.
 *
 * @param data    bytes to send
 * @param timeout 0 skips the readiness wait; any other value is handed to
 *                _waitSocketWrite() before the first send()
 * @return a default-constructed Error on success; an Error with code
 *         HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT if the wait timed out, or an
 *         Error describing the failure
 */
Error Socket::writeData(string &data, long timeout){
    size_t remaining;
    ssize_t sent;
    const char *p;

    if (_sock==-1)
        return Error("Socket::writeData",
                     ERROR_LEVEL_NORMAL,
                     0,
                     ERROR_ADVISE_DONTKNOW,
                     "no socket");

    // handle timeout
    if (timeout) {
        try {
            if (!_waitSocketWrite(timeout))
                return Error("Socket::writeData",
                             ERROR_LEVEL_NORMAL,
                             HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT,
                             ERROR_ADVISE_DONTKNOW,
                             "_waitSocketWrite timed out");
        }
        catch (Error &xerr) {
            return Error("Socket::writeData",xerr);
        }
    }

    p=data.c_str();
    remaining=data.length();
#ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0
#endif
    while(remaining) {
        // send() returns a signed count; keeping it signed is what makes
        // the -1 error result detectable
        sent=send(_sock,p,remaining,MSG_NOSIGNAL);
        if (sent<=0)
            return Error("Socket::writeData",
                         ERROR_LEVEL_NORMAL,
                         0,
                         ERROR_ADVISE_DONTKNOW,
                         strerror(errno),
                         "error on send");
        remaining-=static_cast<size_t>(sent);
        p+=sent;
    } // while
    return Error();
}


/**
 * Looks up the address of the peer this socket is connected to.
 *
 * @return an InetAddress built from the textual form of the peer address
 *         as produced by inet_ntoa()
 * @throws Error if getpeername() fails or the textual address is empty
 */
InetAddress Socket::getPeerAddress()
{
    sockaddr_in peer;
    socklen_t peerLen = sizeof(peer);

    if (getpeername(_sock, reinterpret_cast<sockaddr*>(&peer), &peerLen) != 0)
        throw Error("Socket::getPeerAddress",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    strerror(errno),
                    "getpeername");

    string dotted(::inet_ntoa(peer.sin_addr));
    if (dotted.empty())
        throw Error("Socket::getPeerAddress",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    "bad addr");
    return dotted;
}


/**
 * Looks up the port of the peer this socket is connected to.
 *
 * @return the peer port, converted to host byte order
 * @throws Error if getpeername() fails
 */
unsigned short Socket::getPeerPort()
{
    sockaddr_in peer;
    socklen_t peerLen = sizeof(peer);

    if (getpeername(_sock, reinterpret_cast<sockaddr*>(&peer), &peerLen) != 0)
        throw Error("Socket::getPeerPort",
                    ERROR_LEVEL_NORMAL,
                    0,
                    ERROR_ADVISE_DONTKNOW,
                    strerror(errno),
                    "getpeername");

    return ntohs(peer.sin_port);
}


/**
 * Receives up to @p size bytes with a single recvfrom() call and reports
 * where they came from.
 *
 * @param data    receives the bytes read; left unchanged when recvfrom()
 *                returns 0 or fails
 * @param size    maximum number of bytes to read
 * @param timeout 0 skips the readiness wait; any other value is handed to
 *                _waitSocketRead() before reading
 * @param addr    receives the sender's address
 * @param port    receives the sender's port in host byte order (only set
 *                when recvfrom() did not fail)
 * @return a default-constructed Error on success; an Error with code
 *         HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT if the wait timed out, or an
 *         Error describing the failure
 * @throws Error propagated from _waitSocketRead()
 */
Error Socket::readDataFrom(string &data, unsigned int size, long timeout,
                               InetAddress &addr,
                               unsigned short &port){
    int bytesread;
    int recvErrno;
    socklen_t i;
    char *buffer;

    // NOTE(review): the timeout code is reported for a missing descriptor
    // here, whereas readData() uses 0 — confirm which one callers expect.
    if (_sock==-1)
        return Error("Socket::readDataFrom",
                         ERROR_LEVEL_NORMAL,
                         HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT,
                         ERROR_ADVISE_DONTKNOW,
                         "no socket");

    // handle timeout
    if (timeout)
        if (!_waitSocketRead(timeout))
            return Error("Socket::readDataFrom",
                             ERROR_LEVEL_NORMAL,
                             HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT,
                             ERROR_ADVISE_DONTKNOW,
                             "_waitSocketRead timed out");

    // Now read the bytes; the length handed to recvfrom() is the size of
    // the address buffer that is actually passed in
    i=sizeof(addr._inaddr);
    buffer=new char[size];
    bytesread=::recvfrom(_sock,buffer,size,0,(sockaddr*)&addr._inaddr, &i);
    // remember errno before anything else can overwrite it
    recvErrno=errno;
    if (bytesread>0)
        data.assign(buffer,bytesread);
    // the buffer was allocated with new[], so it must be freed with delete[]
    delete[] buffer;

    if (bytesread==-1)
        return Error("Socket::readDataFrom",
                         ERROR_LEVEL_NORMAL,
                         0,
                         ERROR_ADVISE_DONTKNOW,
                         strerror(recvErrno),
                         "recvfrom");
    port=ntohs(addr._inaddr.sin_port);

    return Error();
}


/**
 * Sends the content of @p data to the given address and port with a
 * single sendto() call.
 *
 * @param data    bytes to send
 * @param timeout 0 skips the readiness wait; any other value is handed to
 *                _waitSocketWrite() before sending
 * @param addr    destination address (its port field is overridden)
 * @param port    destination port, in host byte order
 * @return a default-constructed Error if sendto() reported exactly
 *         data.length() bytes; otherwise an Error with code
 *         HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT
 * @throws Error propagated from _waitSocketWrite()
 */
Error Socket::writeDataTo(string &data, long timeout,
                              const InetAddress &addr,
                              unsigned short port)
{
    // wait for the socket to become writable, if requested
    if (timeout && !_waitSocketWrite(timeout))
        return Error("Socket::writeDataTo",
                     ERROR_LEVEL_NORMAL,
                     HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT,
                     ERROR_ADVISE_DONTKNOW,
                     "_waitSocketWrite timed out");

    // destination: caller's address with the requested port filled in
    sockaddr_in target = addr._inaddr;
    target.sin_port = htons(port);

    const socklen_t targetLen = sizeof(sockaddr);
    const int sent = sendto(_sock, data.data(), data.length(), 0,
                            reinterpret_cast<sockaddr*>(&target), targetLen);

    if (sent == static_cast<int>(data.length()))
        return Error();

    return Error("Socket::writeDataTo",
                 ERROR_LEVEL_NORMAL,
                 HBCI_ERROR_CODE_SOCKET_ERROR_TIMEOUT,
                 ERROR_ADVISE_DONTKNOW,
                 strerror(errno),
                 "error on SENDTO");
}


} /* namespace HBCI */
