/*
 * Incrementally parses WebSocket traffic on an upgraded connection
 * (Buffer::iterator variant).
 *
 * Progress is kept in wsState_ across calls:
 *  - ws_start: validate req.webSocketVersion (0, 7, 8 or 13 are accepted),
 *    fill in the 101 Switching Protocols reply and return Complete.
 *  - ws00_hand_shake (version 00 only): collect the 8 bytes the client sends
 *    after the GET, check them with doWebSocketHandshake00() and pass the
 *    16 bytes held in ws00_buf_ to the reply.
 *  - otherwise: consume frame bytes from [begin, end) until one message is
 *    complete or the input runs out, then hand the payload bytes seen in this
 *    call to reply->consumeWebSocketMessage().
 *
 * @param req    the upgraded request (protocol version, headers, URL parts)
 * @param reply  receives the handshake reply data and message payloads
 * @param begin  in/out: advanced past every byte consumed; version 13
 *               payload bytes are unmasked in place
 * @param end    end of the available input
 * @return Request::Complete when a handshake step or a message finished,
 *         Request::Partial when more input is needed, Request::Error on a
 *         protocol violation (unsupported version, failed handshake,
 *         oversized or malformed length, unmasked client frame, reserved
 *         bits set, unknown opcode).
 */
Request::State RequestParser::parseWebSocketMessage(Request& req,
                                                    ReplyPtr reply,
                                                    Buffer::iterator& begin,
                                                    Buffer::iterator end)
{
  switch (wsState_) {
  case ws_start:
    {
      if (req.webSocketVersion != 0
          && req.webSocketVersion != 7
          && req.webSocketVersion != 8
          && req.webSocketVersion != 13) {
        LOG_ERROR("ws: unsupported protocol version " << req.webSocketVersion);
        // FIXME add Sec-WebSocket-Version fields
        return Request::Error;
      }

      if (req.webSocketVersion == 0) {
        LOG_INFO("ws: connect with protocol version 0");
        /*
         * In version 00, the hand shake is two-staged, we first need to
         * send the 101 to be able to access the part of the handshake
         * that is sent after the GET
         */
        const Request::Header *host = req.getHeader("Host");
        if (!host || host->value.empty()) {
          LOG_ERROR("ws: missing Host field");
          return Request::Error;
        }

        wsState_ = ws00_hand_shake;
        wsCount_ = 0;

        reply->setStatus(Reply::switching_protocols);
        reply->addHeader("Connection", "Upgrade");
        reply->addHeader("Upgrade", "WebSocket");

        const Request::Header *origin = req.getHeader("Origin");
        if (origin && !origin->value.empty())
          reply->addHeader("Sec-WebSocket-Origin", origin->value.str());

        std::string location = std::string(req.urlScheme) + "://"
          + host->value.str() + req.request_path + "?" + req.request_query;
        reply->addHeader("Sec-WebSocket-Location", location);

        reply->consumeData(begin, begin, Request::Partial);

        return Request::Complete;
      } else {
        LOG_INFO("ws: connect with protocol version " << req.webSocketVersion);
        std::string accept = doWebSocketHandshake13(req);

        if (accept.empty()) {
          LOG_ERROR("ws: error computing handshake result");
          return Request::Error;
        } else {
          wsState_ = ws13_frame_start;

          reply->setStatus(Reply::switching_protocols);
          reply->addHeader("Connection", "Upgrade");
          reply->addHeader("Upgrade", "WebSocket");
          reply->addHeader("Sec-WebSocket-Accept", accept);
          reply->consumeData(begin, begin, Request::Complete);

          return Request::Complete;
        }
      }
    }
    break;
  case ws00_hand_shake:
    {
      // Copy as many of the 8 handshake bytes as are available.
      unsigned thisSize = std::min((::int64_t)(end - begin),
                                   (::int64_t)(8 - wsCount_));

      memcpy(ws00_buf_ + wsCount_, begin, thisSize);
      wsCount_ += thisSize;
      begin += thisSize;

      if (wsCount_ == 8) {
        bool okay = doWebSocketHandshake00(req);

        if (okay) {
          wsState_ = ws00_frame_start;
          reply->consumeData(ws00_buf_, ws00_buf_ + 16, Request::Complete);
          return Request::Complete;
        } else {
          LOG_ERROR("ws: invalid client hand-shake");
          return Request::Error;
        }
      } else
        return Request::Partial;
    }
  default:
    break;
  }

  Buffer::iterator dataBegin = begin;
  Buffer::iterator dataEnd = begin; // Initially assume no data

  Request::State state = Request::Partial;

  while (begin < end && state == Request::Partial) {
    switch (wsState_) {
    case ws00_frame_start:
      wsFrameType_ = *begin;

      if (wsFrameType_ & 0x80) {
        wsState_ = ws00_binary_length;
        remainder_ = 0;
      } else {
        wsState_ = ws00_text_data;
        dataBegin = begin + 1;
        remainder_ = 0;
      }

      ++begin;

      break;
    case ws00_binary_length:
      // remainder_ is a signed 64-bit value: refuse lengths whose next
      // 7-bit shift would overflow it (signed overflow is undefined).
      if (remainder_ >= (((::int64_t)0x01) << 56)) {
        LOG_ERROR("ws: oversized binary frame: overflows 64-bit signed integer");
        return Request::Error;
      }

      remainder_ = remainder_ << 7 | (*begin & 0x7F);
      if ((*begin & 0x80) == 0) {
        if (remainder_ == 0 || remainder_ >= MAX_WEBSOCKET_MESSAGE_LENGTH) {
          LOG_ERROR("ws: oversized binary frame of length " << remainder_);
          return Request::Error;
        }
        wsState_ = ws00_binary_data;
      }

      ++begin;

      break;
    case ws00_text_data:
      if (static_cast<unsigned char>(*begin) == 0xFF) {
        state = Request::Complete;
        wsState_ = ws00_frame_start;
        dataEnd = begin;
      } else {
        ++remainder_;

        if (remainder_ >= MAX_WEBSOCKET_MESSAGE_LENGTH) {
          LOG_ERROR("ws: oversized text frame of length " << remainder_);
          return Request::Error;
        }
      }

      ++begin;

      break;
    case ws00_binary_data:
      {
        ::int64_t thisSize = std::min((::int64_t)(end - begin), remainder_);

        dataBegin = begin;
        begin = begin + thisSize;
        dataEnd = begin;
        remainder_ -= thisSize;

        if (remainder_ == 0) {
          state = Request::Complete;
          wsState_ = ws00_frame_start;
        }

        break;
      }
    case ws13_frame_start:
      {
        unsigned char frameType = *begin;

        LOG_DEBUG("ws: new frame, opcode byte=" << (int)frameType);

        /* RSV1-3 must be 0 */
        if (frameType & 0x70)
          return Request::Error;

        switch (frameType & 0x0F) {
        case 0x0: // Continuation frame of a fragmented message
          if (frameType & 0x80)
            wsFrameType_ |= 0x80; // mark the end-of-frame

          break;
        case 0x1: // Text frame
        case 0x2: // Binary frame
        case 0x8: // Close
        case 0x9: // Ping
        case 0xA: // Pong
          wsFrameType_ = frameType;

          break;
        default:
          LOG_ERROR("ws: unknown opcode");
          return Request::Error;
        }

        wsState_ = ws13_payload_length;
        wsCount_ = 0;

        ++begin;

        break;
      }
    case ws13_payload_length:
      /* Client frame must be masked */
      if ((*begin & 0x80) == 0) {
        LOG_ERROR("ws: client frame not masked");
        return Request::Error;
      }

      remainder_ = *begin & 0x7F;

      if (remainder_ < 126) {
        LOG_DEBUG("ws: new frame length " << remainder_);
        wsMask_ = 0;
        wsState_ = ws13_mask;
        wsCount_ = 4;
      } else if (remainder_ == 126) {
        wsState_ = ws13_extended_payload_length;
        wsCount_ = 2;
        remainder_ = 0;
      } else if (remainder_ == 127) {
        wsState_ = ws13_extended_payload_length;
        wsCount_ = 8;
        remainder_ = 0;
      }

      ++begin;

      break;
    case ws13_extended_payload_length:
      // RFC 6455 5.2: the most significant bit of a 64-bit length must be 0.
      // Without this check remainder_ (signed) could turn negative, pass the
      // size limit below, and make ws13_payload move begin backwards.
      if (wsCount_ == 8 && (unsigned char)(*begin) > 127) {
        LOG_ERROR("ws: malformed 8-byte frame length, MSB is not 0");
        return Request::Error;
      }

      remainder_ <<= 8;
      remainder_ |= (unsigned char)(*begin);
      --wsCount_;

      if (wsCount_ == 0) {
        LOG_DEBUG("ws: new frame length " << remainder_);

        if (remainder_ >= MAX_WEBSOCKET_MESSAGE_LENGTH) {
          LOG_ERROR("ws: oversized frame of length " << remainder_);
          return Request::Error;
        }

        wsMask_ = 0;
        wsState_ = ws13_mask;
        wsCount_ = 4;
      }

      ++begin;

      break;
    case ws13_mask:
      wsMask_ <<= 8;
      wsMask_ |= (unsigned char)(*begin);
      --wsCount_;

      if (wsCount_ == 0) {
        LOG_DEBUG("ws: new frame read mask");

        if (remainder_ != 0) {
          wsState_ = ws13_payload;
        } else {
          // Frame without data (like pong)
          if (wsFrameType_ & 0x80)
            state = Request::Complete;
          wsState_ = ws13_frame_start;
        }
      }

      ++begin;

      break;
    case ws13_payload:
      {
        ::int64_t thisSize = std::min((::int64_t)(end - begin), remainder_);

        dataBegin = begin;
        begin = begin + thisSize;
        dataEnd = begin;
        remainder_ -= thisSize;

        /* Unmask dataBegin to dataEnd, mask offset in wsCount_ */
        for (Buffer::iterator i = dataBegin; i != dataEnd; ++i) {
          unsigned char d = *i;
          unsigned char m = (unsigned char)(wsMask_ >> ((3 - wsCount_) * 8));
          d = d ^ m;
          *i = d;
          wsCount_ = (wsCount_ + 1) % 4;
        }

        LOG_DEBUG("ws: reading payload, remains = " << remainder_);

        if (remainder_ == 0) {
          if (wsFrameType_ & 0x80)
            state = Request::Complete;

          wsState_ = ws13_frame_start;
        }

        break;
      }
    default:
      assert(false);
    }
  }

  LOG_DEBUG("ws: " << (dataEnd - dataBegin) << "," << state);

  if (dataBegin < dataEnd || state == Request::Complete) {
    if (wsState_ < ws13_frame_start) {
      // Version 00: only frames whose type byte is 0x00 (text) are delivered.
      if (wsFrameType_ == 0x00)
        reply->consumeWebSocketMessage(Reply::text_frame,
                                       dataBegin, dataEnd, state);
    } else {
      Reply::ws_opcode opcode = (Reply::ws_opcode)(wsFrameType_ & 0x0F);
      reply->consumeWebSocketMessage(opcode, dataBegin, dataEnd, state);
    }
  }

  return state;
}
Request::State RequestParser::parseWebSocketMessage(Request& req, ReplyPtr reply, char *& begin, char * end) { switch (wsState_) { case ws_start: { if (req.webSocketVersion != 0 && req.webSocketVersion != 7 && req.webSocketVersion != 8 && req.webSocketVersion != 13) { LOG_ERROR("ws: unsupported protocol version " << req.webSocketVersion); // FIXME add Sec-WebSocket-Version fields return Request::Error; } if (req.webSocketVersion == 0) { LOG_INFO("ws: connect with protocol version 0"); /* * In version 00, the hand shake is two-staged, we first need to * send the 101 to be able to access the part of the handshake * that is sent after the GET */ const Request::Header *host = req.getHeader("Host"); if (!host || host->value.empty()) { LOG_ERROR("ws: missing Host field"); return Request::Error; } wsState_ = ws00_hand_shake; wsCount_ = 0; reply->setStatus(Reply::switching_protocols); reply->addHeader("Connection", "Upgrade"); reply->addHeader("Upgrade", "WebSocket"); const Request::Header *origin = req.getHeader("Origin"); if (origin && !origin->value.empty()) reply->addHeader("Sec-WebSocket-Origin", origin->value.str()); std::string location = std::string(req.urlScheme) + "://" + host->value.str() + req.request_path + "?" 
+ req.request_query; reply->addHeader("Sec-WebSocket-Location", location); reply->consumeData(&*begin, &*begin, Request::Partial); return Request::Complete; } else { LOG_INFO("ws: connect with protocol version " << req.webSocketVersion); std::string accept = doWebSocketHandshake13(req); if (accept.empty()) { LOG_ERROR("ws: error computing handshake result"); return Request::Error; } else { wsState_ = ws13_frame_start; reply->setStatus(Reply::switching_protocols); reply->addHeader("Connection", "Upgrade"); reply->addHeader("Upgrade", "WebSocket"); reply->addHeader("Sec-WebSocket-Accept", accept); #ifdef WTHTTP_WITH_ZLIB std::string compressHeader; if(!doWebSocketPerMessageDeflateNegotiation(req, compressHeader)) return Request::Error; if(!compressHeader.empty()) { // We can use per message deflate if(initInflate()) { LOG_DEBUG("Extension per_message_deflate requested"); reply->addHeader("Sec-WebSocket-Extensions", compressHeader); }else { // Decompress not available disable req.pmdState_.enabled = false; } } #endif reply->consumeData(&*begin, &*begin, Request::Complete); return Request::Complete; } } break; } case ws00_hand_shake: { unsigned thisSize = std::min((::int64_t)(end - begin), (::int64_t)(8 - wsCount_)); memcpy(ws00_buf_ + wsCount_, &*begin, thisSize); wsCount_ += thisSize; begin += thisSize; if (wsCount_ == 8) { bool okay = doWebSocketHandshake00(req); if (okay) { wsState_ = ws00_frame_start; reply->consumeData(ws00_buf_, ws00_buf_ + 16, Request::Complete); return Request::Complete; } else { LOG_ERROR("ws: invalid client hand-shake"); return Request::Error; } } else return Request::Partial; } default: break; } char *dataBegin = begin; char *dataEnd = begin; // Initially assume no data Request::State state = Request::Partial; while (begin < end && state == Request::Partial) { switch (wsState_) { case ws00_frame_start: wsFrameType_ = *begin; if (wsFrameType_ & 0x80) { wsState_ = ws00_binary_length; remainder_ = 0; } else { wsState_ = ws00_text_data; 
dataBegin = begin + 1; remainder_ = 0; } ++begin; break; case ws00_binary_length: if (remainder_ >= (((int64_t)0x01) << 56)) { LOG_ERROR("ws: oversized binary frame: overflows 64-bit signed integer"); return Request::Error; } remainder_ = remainder_ << 7 | (*begin & 0x7F); if ((*begin & 0x80) == 0) { if (remainder_ == 0 || remainder_ >= MAX_WEBSOCKET_MESSAGE_LENGTH) { LOG_ERROR("ws: oversized binary frame of length " << remainder_); return Request::Error; } wsState_ = ws00_binary_data; } ++begin; break; case ws00_text_data: if (static_cast<unsigned char>(*begin) == 0xFF) { state = Request::Complete; wsState_ = ws00_frame_start; dataEnd = begin; } else { ++remainder_; if (remainder_ >= MAX_WEBSOCKET_MESSAGE_LENGTH) { LOG_ERROR("ws: oversized text frame of length " << remainder_); return Request::Error; } } ++begin; break; case ws00_binary_data: { ::int64_t thisSize = std::min((::int64_t)(end - begin), remainder_); dataBegin = begin; begin = begin + thisSize; dataEnd = begin; remainder_ -= thisSize; if (remainder_ == 0) { state = Request::Complete; wsState_ = ws00_frame_start; } break; } case ws13_frame_start: { unsigned char frameType = *begin; LOG_DEBUG("ws: new frame, opcode byte=" << (int)frameType); #ifdef WTHTTP_WITH_ZLIB /* RSV1-3 must be 0 */ if (frameType & 0x70 && (!req.pmdState_.enabled && frameType & 0x30)) return Request::Error; frameCompressed_ = frameType & 0x40; #else if(frameType & 0x70) return Request::Error; #endif switch (frameType & 0x0F) { case 0x0: // Continuation frame of a fragmented message if (frameType & 0x80) wsFrameType_ |= 0x80; // mark the end-of-frame break; case 0x1: // Text frame case 0x2: // Binary frame case 0x8: // Close case 0x9: // Ping case 0xA: // Pong wsFrameType_ = frameType; break; default: LOG_ERROR("ws: unknown opcode"); return Request::Error; } wsState_ = ws13_payload_length; wsCount_ = 0; ++begin; break; } case ws13_payload_length: /* Client frame must be masked */ if ((*begin & 0x80) == 0) { LOG_ERROR("ws: client 
frame not masked"); return Request::Error; } remainder_ = *begin & 0x7F; if (remainder_ < 126) { LOG_DEBUG("ws: new frame length " << remainder_); wsMask_ = 0; wsState_ = ws13_mask; wsCount_ = 4; } else if (remainder_ == 126) { wsState_ = ws13_extended_payload_length; wsCount_ = 2; remainder_ = 0; } else if (remainder_ == 127) { wsState_ = ws13_extended_payload_length; wsCount_ = 8; remainder_ = 0; } ++begin; break; case ws13_extended_payload_length: if (wsCount_ == 8 && (unsigned char)(*begin) > 127) { LOG_ERROR("ws: malformed 8-byte frame length, MSB is not 0"); return Request::Error; } remainder_ <<= 8; remainder_ |= (unsigned char)(*begin); --wsCount_; if (wsCount_ == 0) { LOG_DEBUG("ws: new frame length " << remainder_); if (remainder_ >= MAX_WEBSOCKET_MESSAGE_LENGTH) { LOG_ERROR("ws: oversized frame of length " << remainder_); return Request::Error; } wsMask_ = 0; wsState_ = ws13_mask; wsCount_ = 4; } ++begin; break; case ws13_mask: wsMask_ <<= 8; wsMask_ |= (unsigned char)(*begin); --wsCount_; if (wsCount_ == 0) { LOG_DEBUG("ws: new frame read mask"); if (remainder_ != 0) { wsState_ = ws13_payload; } else { // Frame without data (like pong) if (wsFrameType_ & 0x80) state = Request::Complete; wsState_ = ws13_frame_start; } } ++begin; break; case ws13_payload: { ::int64_t thisSize = std::min((::int64_t)(end - begin), remainder_); dataBegin = begin; begin = begin + thisSize; dataEnd = begin; remainder_ -= thisSize; /* Unmask dataBegin to dataEnd, mask offset in wsCount_ */ for (char *i = dataBegin; i != dataEnd; ++i) { unsigned char d = *i; unsigned char m = (unsigned char)(wsMask_ >> ((3 - wsCount_) * 8)); d = d ^ m; *i = d; wsCount_ = (wsCount_ + 1) % 4; } LOG_DEBUG("ws: reading payload, remains = " << remainder_); if (remainder_ == 0) { if (wsFrameType_ & 0x80) state = Request::Complete; wsState_ = ws13_frame_start; } break; } default: assert(false); } } LOG_DEBUG("ws: " << (dataEnd - dataBegin) << "," << state); if (dataBegin < dataEnd || state == 
Request::Complete) { char* beg = &*dataBegin; char* end = &*dataEnd; #ifdef WTHTTP_WITH_ZLIB if (frameCompressed_) { Reply::ws_opcode opcode = (Reply::ws_opcode)(wsFrameType_ & 0x0F); if (wsState_ < ws13_frame_start) { if (wsFrameType_ == 0x00) opcode = Reply::text_frame; } unsigned char appendBlock[] = { 0x00, 0x00, 0xff, 0xff }; bool hasMore = false; char buffer[16 * 1024]; do { read_ = 0; bool ret1 = inflate(reinterpret_cast<unsigned char*>(&*beg), end - beg, reinterpret_cast<unsigned char*>(buffer), hasMore); if(!ret1) return Request::Error; reply->consumeWebSocketMessage(opcode, &buffer[0], &buffer[read_], hasMore ? Request::Partial : state); } while (hasMore); if (state == Request::Complete) if(!inflate(appendBlock, 4, reinterpret_cast<unsigned char*>(buffer), hasMore)) return Request::Error; return state; } #endif // handle uncompressed frame if (wsState_ < ws13_frame_start) { if (wsFrameType_ == 0x00) reply->consumeWebSocketMessage(Reply::text_frame, beg, end, state); } else { Reply::ws_opcode opcode = (Reply::ws_opcode)(wsFrameType_ & 0x0F); reply->consumeWebSocketMessage(opcode, beg, end, state); } } return state; }