예제 #1
0
///gets called when a complete 24 byte header is actually received: uses the UUID within to match up appropriate sockets
///
/// Header layout, as parsed below: TCPStream::STRING_PREFIX() (STRING_PREFIX_LENGTH
/// bytes), then two characters read as a decimal count of sockets the client
/// will open (clamped to data->mMaxSimultaneousSockets), with the stream UUID in
/// the final 16 bytes of the TCPStream::TcpSstHeaderSize-byte header.
///
/// Sockets that share a UUID are collected in sIncompleteStreams. Once the
/// advertised count has arrived, they are combined into one MultiplexedSocket,
/// protocol headers are sent, and data->cb receives the new TCPStream.
///
/// @param buffer             heap-allocated header buffer; always deleted here.
/// @param socket             newly accepted socket. It is stored in the pending
///                           entry for its UUID, or deleted when the connection
///                           is rejected.
/// @param data               listener settings (no-delay flag, socket limit,
///                           send buffer size) and the new-stream callback.
/// @param error              result of the header read.
/// @param bytes_transferred  not inspected. NOTE(review): this assumes the read
///                           always fills the whole header when error is clear;
///                           confirm.
void buildStream(TcpSstHeaderArray *buffer,
                 TCPSocket *socket,
                 std::tr1::shared_ptr<TCPStreamListener::Data> data,
                 const boost::system::error_code &error,
                 std::size_t bytes_transferred) {

    if (error || std::memcmp(buffer->begin(),TCPStream::STRING_PREFIX(),TCPStream::STRING_PREFIX_LENGTH)!=0) {
        SILOG(tcpsst,warning,"Connection received with incomprehensible header");
        // Nothing else references this socket, so it must be freed here.
        delete socket;
    } else {
        boost::asio::ip::tcp::no_delay option(data->mNoDelay);
        socket->set_option(option);
        // The stream UUID occupies the last 16 bytes of the header.
        UUID context=UUID(buffer->begin()+(TCPStream::TcpSstHeaderSize-16),16);
        IncompleteStreamMap::iterator where=sIncompleteStreams.find(context);
        // The two characters right after the prefix form a decimal socket count;
        // the result is clamped to the configured maximum.
        unsigned int numConnections=(((*buffer)[TCPStream::STRING_PREFIX_LENGTH]-'0')%10)*10+(((*buffer)[TCPStream::STRING_PREFIX_LENGTH+1]-'0')%10);
        if (numConnections>data->mMaxSimultaneousSockets) numConnections=data->mMaxSimultaneousSockets;
        if (where==sIncompleteStreams.end()) {
            // First socket seen for this UUID: it fixes the expected count.
            sIncompleteStreams[context].mNumSockets=numConnections;
            where=sIncompleteStreams.find(context);
            assert(where!=sIncompleteStreams.end());
        }
        if ((int)numConnections!=where->second.mNumSockets) {
            SILOG(tcpsst,warning,"Single client disagrees on number of connections to establish: "<<numConnections<<" != "<<where->second.mNumSockets);
            // The map entry does not own its sockets: the success path below hands
            // them to MultiplexedSocket and then erases the entry. Dropping the
            // entry here would orphan them, so free them and this socket too.
            while (!where->second.mSockets.empty()) {
                delete where->second.mSockets.back();
                where->second.mSockets.pop_back();
            }
            sIncompleteStreams.erase(where);
            delete socket;
        } else {
            where->second.mSockets.push_back(socket);
            if (numConnections==(unsigned int)where->second.mSockets.size()) {
                // All advertised sockets have arrived: build the multiplexed stream.
                MultiplexedSocketPtr shared_socket(
                    MultiplexedSocket::construct<MultiplexedSocket>(&data->ios,context,where->second.mSockets,data->cb,data->mSendBufferSize));
                MultiplexedSocket::sendAllProtocolHeaders(shared_socket,UUID::random());
                sIncompleteStreams.erase(where);
                Stream::StreamID newID=Stream::StreamID(1);
                TCPStream * strm=new TCPStream(shared_socket,newID);

                TCPSetCallbacks setCallbackFunctor(&*shared_socket,strm);
                data->cb(strm,setCallbackFunctor);
                if (setCallbackFunctor.mCallbacks==NULL) {
                    SILOG(tcpsst,error,"Client code for stream "<<newID.read()<<" did not set listener on socket");
                    shared_socket->closeStream(shared_socket,newID);
                }
            } else {
                sStaleUUIDs.push_back(context);
            }
        }
    }
    delete buffer;
}
예제 #2
0
/// Completion handler for the read of a new connection's WebSocket opening
/// handshake (the Sec-WebSocket-Key1/Key2 + 8-byte key3 form, as in
/// draft-hixie-76). Unlike the fixed-size TCPSST header, this request has
/// variable length. The stream UUID comes from the request path
/// ("GET /<uuid> ..."). The number of sockets the client will open is the
/// first run of digits in Sec-WebSocket-Protocol (default protocol "wssst1",
/// i.e. 1 socket), clamped to data->mMaxSimultaneousSockets.
///
/// Sockets sharing a UUID are collected in sIncompleteStreams together with
/// the handshake reply computed for each. handleBuildStreamTimeout is
/// scheduled 10 seconds after a UUID is first seen. When the advertised count
/// has arrived, the sockets are combined into one MultiplexedSocket, the
/// handshake replies are sent, and data->cb receives the new TCPStream.
/// A protocol that does not start with "sst" (including the default) selects
/// the base64 stream mode.
///
/// @param buffer             request bytes; always freed here (via buffer_ptr).
/// @param socket             newly accepted socket. It is stored in the pending
///                           entry for its UUID, or deleted when rejected.
/// @param data               listener settings and the new-stream callback.
/// @param error              result of the request read.
/// @param bytes_transferred  number of valid bytes at the start of buffer.
void buildStream(TcpSstHeaderArray *buffer,
                 TCPSocket *socket,
                 std::tr1::shared_ptr<TCPStreamListener::Data> data,
                 const boost::system::error_code &error,
                 std::size_t bytes_transferred)
{
    // Buffer always needs to be cleaned up when we get out of this method
    std::auto_ptr<TcpSstHeaderArray> buffer_ptr(buffer);

    // The request must end with "\r\n\r\n" followed by the 8-byte key3.
    const std::size_t kKey3Length = 8;
    const std::size_t kMinRequestLength = 4 + kKey3Length;

    // Sanity check start. Rejecting anything shorter than kMinRequestLength
    // also keeps the end-of-request offsets below from wrapping around.
    if (error || bytes_transferred < kMinRequestLength || std::string((const char*)buffer->begin(), 5) != std::string("GET /")) {
        SILOG(tcpsst,warning,"Connection received with truncated header: "<<error);
        delete socket;
        return;
    }

    // Sanity check end: 8 bytes from WebSocket spec after headers, then
    // \r\n\r\n before that.
    std::string buffer_str((const char*)buffer->begin(), bytes_transferred);
    const std::size_t terminator_pos = bytes_transferred - kMinRequestLength;
    if (buffer_str.compare(terminator_pos, 4, "\r\n\r\n") != 0)
    {
        SILOG(tcpsst,warning,"Request doesn't end properly:\n" << buffer_str << "\n");
        delete socket;
        return;
    }

    // Keep the "\r\n" that ends the last header line, so every line handed to
    // the parser below is "\r\n"-terminated.
    std::string headers_str = buffer_str.substr(0, terminator_pos + 2);
    // Parse headers
    UUID context;
    std::map<std::string, std::string> headers;
    std::string::size_type offset = 0;
    while(offset < headers_str.size()) {
        std::string::size_type last_offset = offset;
        offset = headers_str.find("\r\n", offset);
        if (offset == std::string::npos) {
            SILOG(tcpsst,warning,"Error parsing headers.");
            delete socket;
            return;
        }

        std::string line = headers_str.substr(last_offset, offset - last_offset);

        // Special case the initial GET line: the path is the stream UUID.
        if (line.substr(0, 5) == "GET /") {
            std::string::size_type uuid_end = line.find(' ', 5);
            if (uuid_end == std::string::npos) {
                SILOG(tcpsst,warning,"Error parsing headers: invalid get line.");
                delete socket;
                return;
            }
            std::string uuid_str = line.substr(5, uuid_end - 5);
            try {
                context = UUID(uuid_str,UUID::HumanReadable());
            } catch(...) {
                SILOG(tcpsst,warning,"Error parsing headers: invalid get uuid.");
                delete socket;
                return;
            }

            offset += 2;
            continue;
        }

        std::string::size_type colon = line.find(':');
        if (colon == std::string::npos) {
            SILOG(tcpsst,warning,"Error parsing headers: missing colon.");
            delete socket;
            return;
        }
        std::string head = line.substr(0, colon);
        // Skip at most one space after the colon. The value may be empty
        // (line ends in ':') or may follow the colon directly.
        std::string::size_type val_start = colon + 1;
        if (val_start < line.size() && line[val_start] == ' ')
            ++val_start;
        std::string val = line.substr(val_start);

        headers[head] = val;

        // Advance to get past the \r\n
        offset += 2;
    }

    if (headers.find("Host") == headers.end() ||
        headers.find("Origin") == headers.end() ||
        headers.find("Sec-WebSocket-Key1") == headers.end() ||
        headers.find("Sec-WebSocket-Key2") == headers.end())
    {
        SILOG(tcpsst,warning,"Connection request didn't specify all required fields.");
        delete socket;
        return;
    }

    std::string host = headers["Host"];
    std::string origin = headers["Origin"];
    std::string protocol = "wssst1";
    if (headers.find("Sec-WebSocket-Protocol") != headers.end())
        protocol = headers["Sec-WebSocket-Protocol"];
    std::string key1 = headers["Sec-WebSocket-Key1"];
    std::string key2 = headers["Sec-WebSocket-Key2"];
    std::string key3 = buffer_str.substr(bytes_transferred - kKey3Length);
    assert(key3.size() == kKey3Length);

    std::string reply_str = getWebSocketSecReply(key1, key2, key3);

    bool binaryStream=protocol.find("sst")==0;
    bool base64Stream=!binaryStream;
    boost::asio::ip::tcp::no_delay option(data->mNoDelay);
    socket->set_option(option);
    IncompleteStreamMap::iterator where=sIncompleteStreams.find(context);

    unsigned int numConnections=1;

    // The first run of digits in the protocol name is the socket count.
    for (std::string::iterator i=protocol.begin(),ie=protocol.end();i!=ie;++i) {
        if (*i>='0'&&*i<='9') {
            char* endptr=NULL;
            const char *start=protocol.c_str();
            size_t digit_pos=(i-protocol.begin());
            start+=digit_pos;
            numConnections=strtol(start,&endptr,10);
            size_t numberlen=endptr-start;
            if (numConnections>data->mMaxSimultaneousSockets) {
                numConnections=data->mMaxSimultaneousSockets;
                // Rewrite the count in the protocol string, so the protocol
                // passed to sendAllProtocolHeaders carries the clamped value.
                char numcon[256];
                sprintf(numcon,"%u",numConnections);
                protocol=protocol.substr(0,digit_pos)+numcon+protocol.substr(digit_pos+numberlen);
            }
            break;
        }
    }

    if (where==sIncompleteStreams.end()){
        sIncompleteStreams[context].mNumSockets=numConnections;
        where=sIncompleteStreams.find(context);
        assert(where!=sIncompleteStreams.end());
        // Setup a timer to clean up the sockets if we don't complete it in time
        data->strand->post(
            Duration::seconds(10),
            std::tr1::bind(&handleBuildStreamTimeout, context)
        );
    }
    if ((int)numConnections!=where->second.mNumSockets) {
        SILOG(tcpsst,warning,"Single client disagrees on number of connections to establish: "<<numConnections<<" != "<<where->second.mNumSockets);
        // The map entry does not own its sockets: the success path below hands
        // them to MultiplexedSocket and then erases the entry. Dropping the
        // entry here would orphan them, so free them and this socket too.
        while (!where->second.mSockets.empty()) {
            delete where->second.mSockets.back();
            where->second.mSockets.pop_back();
        }
        sIncompleteStreams.erase(where);
        delete socket;
    }else {
        where->second.mSockets.push_back(socket);
        where->second.mWebSocketResponses[socket] = reply_str;
        if (numConnections==(unsigned int)where->second.mSockets.size()) {
            // All advertised sockets have arrived: build the multiplexed stream
            // and answer every pending handshake.
            MultiplexedSocketPtr shared_socket(
                MultiplexedSocket::construct<MultiplexedSocket>(data->strand,context,data->cb,base64Stream));
            shared_socket->initFromSockets(where->second.mSockets,data->mSendBufferSize);
            std::string port=shared_socket->getASIOSocketWrapper(0).getLocalEndpoint().getService();
            std::string resource_name='/'+context.toString();
            MultiplexedSocket::sendAllProtocolHeaders(shared_socket,origin,host,port,resource_name,protocol, where->second.mWebSocketResponses);
            sIncompleteStreams.erase(where);

            Stream::StreamID newID=Stream::StreamID(1);
            TCPStream * strm=new TCPStream(shared_socket,newID);

            TCPSetCallbacks setCallbackFunctor(&*shared_socket,strm);
            data->cb(strm,setCallbackFunctor);
            if (setCallbackFunctor.mCallbacks==NULL) {
                SILOG(tcpsst,error,"Client code for stream "<<newID.read()<<" did not set listener on socket");
                shared_socket->closeStream(shared_socket,newID);
            }
        }else{
            sStaleUUIDs.push_back(context);
        }
    }
}