///Erase all sockets and callbacks since the refcount is now zero.
///
///Teardown happens in this order:
/// 1. mNewSubstreamCallback is replaced by Stream::ignoreSubstreamCallback, then the old
///    callback is invoked once with a NULL stream and a TCPSetCallbacks bound to no stream;
/// 2. every socket in mSockets is shut down and closed;
/// 3. with sConnectingMutex held (until the destructor returns), every socket is destroyed
///    and mSockets is emptied;
/// 4. pending callback registrations, requests still queued in mNewRequests (including each
///    request's data) and all registered per-stream callbacks are deleted.
MultiplexedSocket::~MultiplexedSocket() {
    // Swap in the ignore callback before invoking the old one; presumably so that anything
    // triggered during teardown no longer reaches the old callback.
    Stream::SubstreamCallback callbackToBeDeleted=mNewSubstreamCallback;
    mNewSubstreamCallback=&Stream::ignoreSubstreamCallback;
    // NOTE(review): a NULL stream appears to be the signal to the callback's owner that this
    // socket is going away -- confirm against the SubstreamCallback contract.
    TCPSetCallbacks setCallbackFunctor(this,NULL);
    callbackToBeDeleted(NULL,setCallbackFunctor);
    for (unsigned int i=0;i<(unsigned int)mSockets.size();++i){
        mSockets[i].shutdownAndClose();
    }
    // Held until the end of the destructor, so socket destruction and the cleanup below are
    // serialized with any other code that takes sConnectingMutex.
    boost::lock_guard<boost::mutex> connecting_mutex(sConnectingMutex);
    for (unsigned int i=0;i<(unsigned int)mSockets.size();++i){
        mSockets[i].destroySocket();
    }
    mSockets.clear();

    // Free the callback objects of registrations that were never committed.
    while (!mCallbackRegistration.empty()){
        delete mCallbackRegistration.front().mCallback;
        mCallbackRegistration.pop_front();
    }
    // Drain every request still queued, free each payload, then free the queue itself.
    if (mNewRequests) {
        std::deque<RawRequest> newRequests;
        mNewRequests->popAll(&newRequests);
        for (std::deque<RawRequest>::iterator i=newRequests.begin(),ie=newRequests.end();i!=ie;++i) {
            delete i->data;
        }
        delete mNewRequests;
        mNewRequests=NULL;
    }
    // Delete every registered per-stream callback object.
    while(!mCallbacks.empty()) {
        delete mCallbacks.begin()->second;
        mCallbacks.erase(mCallbacks.begin());
    }
}
///Gets called when a complete TCPStream::TcpSstHeaderSize-byte header has been read from a newly
///accepted socket. The header starts with TCPStream::STRING_PREFIX(), followed by two decimal
///digits giving the number of parallel connections the client opens; its last 16 bytes hold a
///UUID naming the logical stream. Sockets sharing a UUID are collected in sIncompleteStreams and,
///once the advertised number has arrived, wrapped in one MultiplexedSocket and handed to data->cb
///as stream 1.
///
///@param buffer            heap-allocated header buffer; always deleted by this function
///@param socket            the accepted socket; stored in sIncompleteStreams when accepted,
///                         deleted when the connection is rejected
///@param data              listener settings and the new-stream callback
///@param error             outcome of the asynchronous header read
///@param bytes_transferred unused
void buildStream(TcpSstHeaderArray *buffer,
                 TCPSocket *socket,
                 std::tr1::shared_ptr<TCPStreamListener::Data> data,
                 const boost::system::error_code &error,
                 std::size_t bytes_transferred) {

    if (error || std::memcmp(buffer->begin(),TCPStream::STRING_PREFIX(),TCPStream::STRING_PREFIX_LENGTH)!=0) {
        SILOG(tcpsst,warning,"Connection received with incomprehensible header");
        delete buffer;
        // Nothing else references a rejected socket, so it must be freed here.
        delete socket;
        return;
    }

    // Extract everything needed from the header up front so the buffer can be released now.
    UUID context=UUID(buffer->begin()+(TCPStream::TcpSstHeaderSize-16),16);
    // Kept signed: a non-digit byte below '0' yields a negative term, which must not wrap
    // around to a huge unsigned connection count.
    int requested=(((*buffer)[TCPStream::STRING_PREFIX_LENGTH]-'0')%10)*10+(((*buffer)[TCPStream::STRING_PREFIX_LENGTH+1]-'0')%10);
    delete buffer;

    unsigned int numConnections=(requested>0)?(unsigned int)requested:0;
    if (numConnections>data->mMaxSimultaneousSockets) numConnections=data->mMaxSimultaneousSockets;
    if (numConnections==0) {
        // A stream expecting zero sockets could never complete (every accepted socket makes
        // the collected count at least 1), so it would linger forever.
        SILOG(tcpsst,warning,"Connection requested an invalid number of connections");
        delete socket;
        return;
    }

    boost::asio::ip::tcp::no_delay option(data->mNoDelay);
    socket->set_option(option);

    IncompleteStreamMap::iterator where=sIncompleteStreams.find(context);
    if (where==sIncompleteStreams.end()) {
        sIncompleteStreams[context].mNumSockets=numConnections;
        where=sIncompleteStreams.find(context);
        assert(where!=sIncompleteStreams.end());
    }
    if ((int)numConnections!=where->second.mNumSockets) {
        SILOG(tcpsst,warning,"Single client disagrees on number of connections to establish: "<<numConnections<<" != "<<where->second.mNumSockets);
        // The stream is abandoned: free this socket and those already collected for it.
        // The list is cleared before erasing so the erased entry never refers to freed sockets.
        delete socket;
        for (std::size_t i=0;i<where->second.mSockets.size();++i) {
            delete where->second.mSockets[i];
        }
        where->second.mSockets.clear();
        sIncompleteStreams.erase(where);
        return;
    }

    where->second.mSockets.push_back(socket);
    if (numConnections==(unsigned int)where->second.mSockets.size()) {
        MultiplexedSocketPtr shared_socket(
            MultiplexedSocket::construct<MultiplexedSocket>(&data->ios,context,where->second.mSockets,data->cb,data->mSendBufferSize));
        MultiplexedSocket::sendAllProtocolHeaders(shared_socket,UUID::random());
        sIncompleteStreams.erase(where);
        Stream::StreamID newID=Stream::StreamID(1);
        TCPStream * strm=new TCPStream(shared_socket,newID);

        TCPSetCallbacks setCallbackFunctor(&*shared_socket,strm);
        data->cb(strm,setCallbackFunctor);
        if (setCallbackFunctor.mCallbacks==NULL) {
            SILOG(tcpsst,error,"Client code for stream "<<newID.read()<<" did not set listener on socket");
            shared_socket->closeStream(shared_socket,newID);
        }
    } else {
        sStaleUUIDs.push_back(context);
    }
}
///Gets called when a complete TCPStream::TcpSstHeaderSize-byte header has been read from a newly
///accepted socket (variant taking the io service and new-stream callback directly). The header
///starts with TCPStream::STRING_PREFIX(), followed by two decimal digits giving the number of
///parallel connections the client opens; its last 16 bytes hold a UUID naming the logical
///stream. Sockets sharing a UUID are collected in sIncompleteStreams and, once the advertised
///number has arrived, wrapped in one MultiplexedSocket and handed to @p callback as stream 1.
///
///@param buffer            heap-allocated header buffer; always deleted by this function
///@param socket            the accepted socket; stored in sIncompleteStreams when accepted,
///                         deleted when the connection is rejected
///@param ioService         service the MultiplexedSocket is constructed on
///@param callback          receives the new stream (and, later, further substreams)
///@param error             outcome of the asynchronous header read
///@param bytes_transferred unused
void buildStream(Array<uint8,TCPStream::TcpSstHeaderSize> *buffer,
                 TCPSocket *socket,
                 IOService *ioService,
                 Stream::SubstreamCallback callback,
                 const boost::system::error_code &error,
                 std::size_t bytes_transferred) {
    if (error || std::memcmp(buffer->begin(),TCPStream::STRING_PREFIX(),TCPStream::STRING_PREFIX_LENGTH)!=0) {
        SILOG(tcpsst,warning,"Connection received with incomprehensible header");
        delete buffer;
        // Nothing else references a rejected socket, so it must be freed here.
        delete socket;
        return;
    }

    // Extract everything needed from the header up front so the buffer can be released now.
    UUID context=UUID(buffer->begin()+(TCPStream::TcpSstHeaderSize-16),16);
    // Kept signed: a non-digit byte below '0' yields a negative term, which must not wrap
    // around to a huge unsigned connection count.
    int requested=(((*buffer)[TCPStream::STRING_PREFIX_LENGTH]-'0')%10)*10+(((*buffer)[TCPStream::STRING_PREFIX_LENGTH+1]-'0')%10);
    delete buffer;

    unsigned int numConnections=(requested>0)?(unsigned int)requested:0;
    if (numConnections>99) numConnections=99;//FIXME: some option in options
    if (numConnections==0) {
        // A stream expecting zero sockets could never complete (every accepted socket makes
        // the collected count at least 1), so it would linger forever.
        SILOG(tcpsst,warning,"Connection requested an invalid number of connections");
        delete socket;
        return;
    }

    IncompleteStreamMap::iterator where=sIncompleteStreams.find(context);
    if (where==sIncompleteStreams.end()){
        sIncompleteStreams[context].mNumSockets=numConnections;
        where=sIncompleteStreams.find(context);
        assert(where!=sIncompleteStreams.end());
    }
    if ((int)numConnections!=where->second.mNumSockets) {
        SILOG(tcpsst,warning,"Single client disagrees on number of connections to establish: "<<numConnections<<" != "<<where->second.mNumSockets);
        // The stream is abandoned: free this socket and those already collected for it.
        // The list is cleared before erasing so the erased entry never refers to freed sockets.
        delete socket;
        for (std::size_t i=0;i<where->second.mSockets.size();++i) {
            delete where->second.mSockets[i];
        }
        where->second.mSockets.clear();
        sIncompleteStreams.erase(where);
        return;
    }

    where->second.mSockets.push_back(socket);
    if (numConnections==(unsigned int)where->second.mSockets.size()) {
        std::tr1::shared_ptr<MultiplexedSocket> shared_socket(
            MultiplexedSocket::construct<MultiplexedSocket>(ioService,context,where->second.mSockets,callback));
        MultiplexedSocket::sendAllProtocolHeaders(shared_socket,UUID::random());
        sIncompleteStreams.erase(where);
        Stream::StreamID newID=Stream::StreamID(1);
        TCPStream * strm=new TCPStream(shared_socket,newID);

        TCPSetCallbacks setCallbackFunctor(&*shared_socket,strm);
        callback(strm,setCallbackFunctor);
        if (setCallbackFunctor.mCallbacks==NULL) {
            SILOG(tcpsst,error,"Client code for stream "<<newID.read()<<" did not set listener on socket");
            shared_socket->closeStream(shared_socket,newID);
        }
    }else{
        sStaleUUIDs.push_back(context);
    }
}
Beispiel #4
0
///Gets called once a complete WebSocket-style opening request has been read from a newly
///accepted socket. The request line must be "GET /<uuid> ...", where <uuid> (human-readable
///form) names the logical stream. The headers must include Host, Origin, Sec-WebSocket-Key1 and
///Sec-WebSocket-Key2, and the request ends with "\r\n\r\n" followed by 8 key bytes.
///
///An optional Sec-WebSocket-Protocol (default "wssst1") selects binary framing (name starts
///with "sst") or base64 framing; its first run of digits is the number of parallel connections.
///
///Sockets sharing a UUID are collected in sIncompleteStreams, with a 10 second cleanup timer
///posted when the entry is created. Once the advertised number has arrived, they are wrapped in
///one MultiplexedSocket and handed to data->cb as stream 1.
///
///@param buffer            heap-allocated request buffer; always deleted by this function
///@param socket            the accepted socket; stored in sIncompleteStreams when accepted,
///                         deleted when the request is rejected
///@param data              listener settings, strand and the new-stream callback
///@param error             outcome of the asynchronous read
///@param bytes_transferred number of request bytes in @p buffer
void buildStream(TcpSstHeaderArray *buffer,
                 TCPSocket *socket,
                 std::tr1::shared_ptr<TCPStreamListener::Data> data,
                 const boost::system::error_code &error,
                 std::size_t bytes_transferred)
{
    // Buffer always needs to be cleaned up when we get out of this method
    std::auto_ptr<TcpSstHeaderArray> buffer_ptr(buffer);

    // Sanity check start. The end-of-request check below reads byte bytes_transferred-12, so
    // at least 12 bytes (the 4-byte blank line plus 8 key bytes) are required; shorter reads
    // would otherwise index far outside buffer_str.
    if (error || bytes_transferred < 12 || std::string((const char*)buffer->begin(), 5) != std::string("GET /")) {
        SILOG(tcpsst,warning,"Connection received with truncated header: "<<error);
        delete socket;
        return;
    }

    // Sanity check end: 8 bytes from WebSocket spec after headers, then
    // \r\n\r\n before that.
    std::string buffer_str((const char*)buffer->begin(), bytes_transferred);
    if (buffer_str[ bytes_transferred - 12] != '\r' ||
        buffer_str[ bytes_transferred - 11] != '\n' ||
        buffer_str[ bytes_transferred - 10] != '\r' ||
        buffer_str[ bytes_transferred - 9] != '\n')
    {
        SILOG(tcpsst,warning,"Request doesn't end properly:\n" << buffer_str << "\n");
        delete socket;
        return;
    }

    // Everything up to and including the "\r\n" that ends the last header line.
    std::string headers_str = buffer_str.substr(0, bytes_transferred - 10);
    // Parse headers
    UUID context;
    std::map<std::string, std::string> headers;
    std::string::size_type offset = 0;
    while(offset < headers_str.size()) {
        std::string::size_type last_offset = offset;
        offset = headers_str.find("\r\n", offset);
        if (offset == std::string::npos) {
            SILOG(tcpsst,warning,"Error parsing headers.");
            delete socket;
            return;
        }

        std::string line = headers_str.substr(last_offset, offset - last_offset);

        // Special case the initial GET line
        if (line.substr(0, 5) == "GET /") {
            std::string::size_type uuid_end = line.find(' ', 5);
            if (uuid_end == std::string::npos) {
                SILOG(tcpsst,warning,"Error parsing headers: invalid get line.");
                delete socket;
                return;
            }
            std::string uuid_str = line.substr(5, uuid_end - 5);
            try {
                context = UUID(uuid_str,UUID::HumanReadable());
            } catch(...) {
                SILOG(tcpsst,warning,"Error parsing headers: invalid get uuid.");
                delete socket;
                return;
            }

            offset += 2;
            continue;
        }

        std::string::size_type colon = line.find(":");
        if (colon == std::string::npos) {
            SILOG(tcpsst,warning,"Error parsing headers: missing colon.");
            delete socket;
            return;
        }
        std::string head = line.substr(0, colon);
        // Skip the single optional space after the colon. A line ending right at the colon
        // (e.g. "Origin:") yields an empty value instead of an out-of-range substr, which
        // would throw out of this handler on untrusted input.
        std::string::size_type value_start = colon + 1;
        if (value_start < line.size() && line[value_start] == ' ')
            ++value_start;
        std::string val = line.substr(value_start);

        headers[head] = val;

        // Advance to get past the \r\n
        offset += 2;
    }

    if (headers.find("Host") == headers.end() ||
        headers.find("Origin") == headers.end() ||
        headers.find("Sec-WebSocket-Key1") == headers.end() ||
        headers.find("Sec-WebSocket-Key2") == headers.end())
    {
        SILOG(tcpsst,warning,"Connection request didn't specify all required fields.");
        delete socket;
        return;
    }

    std::string host = headers["Host"];
    std::string origin = headers["Origin"];
    std::string protocol = "wssst1";
    if (headers.find("Sec-WebSocket-Protocol") != headers.end())
        protocol = headers["Sec-WebSocket-Protocol"];
    std::string key1 = headers["Sec-WebSocket-Key1"];
    std::string key2 = headers["Sec-WebSocket-Key2"];
    std::string key3 = buffer_str.substr(bytes_transferred - 8);
    assert(key3.size() == 8);

    std::string reply_str = getWebSocketSecReply(key1, key2, key3);

    bool binaryStream=protocol.find("sst")==0;
    bool base64Stream=!binaryStream;

    // The first run of digits in the protocol name is the requested number of parallel
    // connections. If it exceeds the configured maximum it is clamped, and the protocol
    // string (later echoed back to the client) is rewritten to the clamped value.
    unsigned int numConnections=1;
    for (std::string::iterator i=protocol.begin(),ie=protocol.end();i!=ie;++i) {
        if (*i>='0'&&*i<='9') {
            char* endptr=NULL;
            const char *start=protocol.c_str();
            size_t digit_offset=(i-protocol.begin());
            start+=digit_offset;
            // Compared as long before narrowing, so very long digit runs cannot be truncated
            // into a small unsigned value that slips past the maximum.
            long requested=strtol(start,&endptr,10);
            size_t numberlen=endptr-start;
            if ((unsigned long)requested>(unsigned long)data->mMaxSimultaneousSockets) {
                numConnections=data->mMaxSimultaneousSockets;
                char numcon[256];
                sprintf(numcon,"%u",numConnections);
                protocol=protocol.substr(0,digit_offset)+numcon+protocol.substr(digit_offset+numberlen);
            } else {
                numConnections=(unsigned int)requested;
            }
            // protocol may have been reassigned, invalidating i; stop iterating immediately.
            break;
        }
    }

    if (numConnections==0) {
        // A stream expecting zero sockets could never complete (every accepted socket makes
        // the collected count at least 1).
        SILOG(tcpsst,warning,"Connection requested an invalid number of connections");
        delete socket;
        return;
    }

    boost::asio::ip::tcp::no_delay option(data->mNoDelay);
    socket->set_option(option);
    IncompleteStreamMap::iterator where=sIncompleteStreams.find(context);

    if (where==sIncompleteStreams.end()){
        sIncompleteStreams[context].mNumSockets=numConnections;
        where=sIncompleteStreams.find(context);
        assert(where!=sIncompleteStreams.end());
        // Setup a timer to clean up the sockets if we don't complete it in time
        data->strand->post(
            Duration::seconds(10),
            std::tr1::bind(&handleBuildStreamTimeout, context)
        );
    }
    if ((int)numConnections!=where->second.mNumSockets) {
        SILOG(tcpsst,warning,"Single client disagrees on number of connections to establish: "<<numConnections<<" != "<<where->second.mNumSockets);
        // The stream is abandoned: free this socket and those already collected for it.
        // The list is cleared before erasing so the erased entry never refers to freed sockets.
        delete socket;
        for (std::size_t s=0;s<where->second.mSockets.size();++s) {
            delete where->second.mSockets[s];
        }
        where->second.mSockets.clear();
        sIncompleteStreams.erase(where);
        return;
    }

    where->second.mSockets.push_back(socket);
    where->second.mWebSocketResponses[socket] = reply_str;
    if (numConnections==(unsigned int)where->second.mSockets.size()) {
        MultiplexedSocketPtr shared_socket(
            MultiplexedSocket::construct<MultiplexedSocket>(data->strand,context,data->cb,base64Stream));
        shared_socket->initFromSockets(where->second.mSockets,data->mSendBufferSize);
        std::string port=shared_socket->getASIOSocketWrapper(0).getLocalEndpoint().getService();
        std::string resource_name='/'+context.toString();
        MultiplexedSocket::sendAllProtocolHeaders(shared_socket,origin,host,port,resource_name,protocol, where->second.mWebSocketResponses);
        sIncompleteStreams.erase(where);


        Stream::StreamID newID=Stream::StreamID(1);
        TCPStream * strm=new TCPStream(shared_socket,newID);

        TCPSetCallbacks setCallbackFunctor(&*shared_socket,strm);
        data->cb(strm,setCallbackFunctor);
        if (setCallbackFunctor.mCallbacks==NULL) {
            SILOG(tcpsst,error,"Client code for stream "<<newID.read()<<" did not set listener on socket");
            shared_socket->closeStream(shared_socket,newID);
        }
    }else{
        sStaleUUIDs.push_back(context);
    }
}
/**
 * Dispatches one fully reassembled chunk.
 *
 * Chunks addressed to the null StreamID are control packets. A close or ack-close packet
 * carries the id of the affected stream. The close is acted on only after one such packet
 * has been counted per underlying socket (mSockets.size()). At that point the stream is shut
 * down via shutDownClosedStream and, for a close request, an ack-close is sent back.
 *
 * Chunks for any other id go to that stream's bytes-received callback. An id with no
 * callback that is not in mOneSidedClosingStreams is treated as a new substream and offered
 * to mNewSubstreamCallback. If the client installs no callbacks, the substream is closed
 * again.
 *
 * @param whichSocket index of the socket the chunk arrived on (not used here)
 * @param id          stream the chunk is addressed to; the null id marks a control packet
 * @param newChunk    the chunk payload
 * @returns the receiving callback's response, or Stream::AcceptedData for control packets
 *          and ignored data
 */
Stream::ReceivedResponse MultiplexedSocket::receiveFullChunk(unsigned int whichSocket, Stream::StreamID id,Chunk&newChunk){
    Stream::ReceivedResponse retval = Stream::AcceptedData;
    if (id==Stream::StreamID()) {//control packet
        if(newChunk.size()) {
            unsigned int controlCode=*newChunk.begin();
            switch (controlCode) {
              case TCPStream::TCPStreamCloseStream:
              case TCPStream::TCPStreamAckCloseStream:
                if (newChunk.size()>1) {
                    unsigned int avail_len=newChunk.size()-1;
                    id.unserialize((const uint8*)&(newChunk[1]),avail_len);
                    if (avail_len+1>newChunk.size()) {
                        SILOG(tcpsst,warning,"Control Chunk too short");
                    }
                }
                if (id!=Stream::StreamID()) {
                    // The close only completes once it has been seen on every underlying
                    // socket (presumably so data still in flight on another socket is not lost).
                    bool closeComplete=false;
                    std::tr1::unordered_map<Stream::StreamID,unsigned int>::iterator where=mAckedClosingStreams.find(id);
                    if (where!=mAckedClosingStreams.end()){
                        where->second++;
                        if (where->second==mSockets.size()) {
                            mAckedClosingStreams.erase(where);
                            closeComplete=true;
                        }
                    }else if (mSockets.size()==1) {
                        closeComplete=true;
                    }else {
                        mAckedClosingStreams[id]=1;
                    }
                    if (closeComplete) {
                        shutDownClosedStream(controlCode,id);
                        if (controlCode==TCPStream::TCPStreamCloseStream) {
                            closeStream(getSharedPtr(),id,TCPStream::TCPStreamAckCloseStream);
                        }
                    }
                }
                break;
              default:
                break;
            }
        }
    }else {
        std::deque<StreamIDCallbackPair> registrations;
        CommitCallbacks(registrations,CONNECTED,false);
        CallbackMap::iterator where=mCallbacks.find(id);
        if (where!=mCallbacks.end()) {
            retval=where->second->mBytesReceivedCallback(newChunk);
        }else if (mOneSidedClosingStreams.find(id)==mOneSidedClosingStreams.end()) {
            //new substream
            TCPStream*newStream=new TCPStream(getSharedPtr(),id);
            TCPSetCallbacks setCallbackFunctor(this,newStream);
            mNewSubstreamCallback(newStream,setCallbackFunctor);
            if (setCallbackFunctor.mCallbacks != NULL) {
                CommitCallbacks(registrations,CONNECTED,false);//make sure bytes are received
                retval=setCallbackFunctor.mCallbacks->mBytesReceivedCallback(newChunk);
            }else {
                closeStream(getSharedPtr(),id);
            }
        }else {
            //IGNORED MESSAGE
        }
    }
    return retval;
}