// process input data int receiveData(SSL& ssl, Data& data, bool peek) { if (ssl.GetError() == YasslError(SSL_ERROR_WANT_READ)) ssl.SetError(no_error); ssl.verfiyHandShakeComplete(); if (ssl.GetError()) return -1; if (!ssl.HasData()) processReply(ssl); if (peek) ssl.PeekData(data); else ssl.fillData(data); ssl.useLog().ShowData(data.get_length()); if (ssl.GetError()) return -1; if (data.get_length() == 0 && ssl.getSocket().WouldBlock()) { ssl.SetError(YasslError(SSL_ERROR_WANT_READ)); return SSL_WOULD_BLOCK; } return data.get_length(); }
// do process input requests, return 0 is done, 1 is call again to complete int DoProcessReply(SSL& ssl) { // wait for input if blocking if (!ssl.useSocket().wait()) { ssl.SetError(receive_error); return 0; } uint ready = ssl.getSocket().get_ready(); if (!ready) return 1; // add buffered data if its there input_buffer* buffered = ssl.useBuffers().TakeRawInput(); uint buffSz = buffered ? buffered->get_size() : 0; input_buffer buffer(buffSz + ready); if (buffSz) { buffer.assign(buffered->get_buffer(), buffSz); ysDelete(buffered); buffered = 0; } // add new data uint read = ssl.useSocket().receive(buffer.get_buffer() + buffSz, ready); if (read == static_cast<uint>(-1)) { ssl.SetError(receive_error); return 0; } buffer.add_size(read); uint offset = 0; const MessageFactory& mf = ssl.getFactory().getMessage(); // old style sslv2 client hello? if (ssl.getSecurity().get_parms().entity_ == server_end && ssl.getStates().getServer() == clientNull) if (buffer.peek() != handshake) { ProcessOldClientHello(buffer, ssl); if (ssl.GetError()) return 0; } while(!buffer.eof()) { // each record RecordLayerHeader hdr; bool needHdr = false; if (static_cast<uint>(RECORD_HEADER) > buffer.get_remaining()) needHdr = true; else { buffer >> hdr; ssl.verifyState(hdr); } // make sure we have enough input in buffer to process this record if (needHdr || hdr.length_ > buffer.get_remaining()) { // put header in front for next time processing uint extra = needHdr ? 0 : RECORD_HEADER; uint sz = buffer.get_remaining() + extra; ssl.useBuffers().SetRawInput(NEW_YS input_buffer(sz, buffer.get_buffer() + buffer.get_current() - extra, sz)); return 1; } while (buffer.get_current() < hdr.length_ + RECORD_HEADER + offset) { // each message in record, can be more than 1 if not encrypted if (ssl.getSecurity().get_parms().pending_ == false) // cipher on decrypt_message(ssl, buffer, hdr.length_); mySTL::auto_ptr<Message> msg(mf.CreateObject(hdr.type_)); if (!msg.get()) { ssl.SetError(factory_error); return 0; } buffer >> *msg; msg->Process(buffer, ssl); if (ssl.GetError()) return 0; } offset += hdr.length_ + RECORD_HEADER; } return 0; }
// send data int sendData(SSL& ssl, const void* buffer, int sz) { int sent = 0; if (ssl.GetError() == YasslError(SSL_ERROR_WANT_READ)) ssl.SetError(no_error); if (ssl.GetError() == YasslError(SSL_ERROR_WANT_WRITE)) { ssl.SetError(no_error); ssl.SendWriteBuffered(); if (!ssl.GetError()) { // advance sent to prvevious sent + plain size just sent sent = ssl.useBuffers().prevSent + ssl.useBuffers().plainSz; } } ssl.verfiyHandShakeComplete(); if (ssl.GetError()) return -1; for (;;) { int len = min(sz - sent, MAX_RECORD_SIZE); output_buffer out; input_buffer tmp; Data data; if (sent == sz) break; if (ssl.CompressionOn()) { if (Compress(static_cast<const opaque*>(buffer) + sent, len, tmp) == -1) { ssl.SetError(compress_error); return -1; } data.SetData(tmp.get_size(), tmp.get_buffer()); } else data.SetData(len, static_cast<const opaque*>(buffer) + sent); buildMessage(ssl, out, data); ssl.Send(out.get_buffer(), out.get_size()); if (ssl.GetError()) { if (ssl.GetError() == YasslError(SSL_ERROR_WANT_WRITE)) { ssl.useBuffers().plainSz = len; ssl.useBuffers().prevSent = sent; } return -1; } sent += len; } ssl.useLog().ShowData(sent, true); return sent; }
// some clients still send sslv2 client hello void ProcessOldClientHello(input_buffer& input, SSL& ssl) { if (input.get_remaining() < 2) { ssl.SetError(bad_input); return; } byte b0 = input[AUTO]; byte b1 = input[AUTO]; uint16 sz = ((b0 & 0x7f) << 8) | b1; if (sz > input.get_remaining()) { ssl.SetError(bad_input); return; } // hashHandShake manually const opaque* buffer = input.get_buffer() + input.get_current(); ssl.useHashes().use_MD5().update(buffer, sz); ssl.useHashes().use_SHA().update(buffer, sz); b1 = input[AUTO]; // does this value mean client_hello? ClientHello ch; ch.client_version_.major_ = input[AUTO]; ch.client_version_.minor_ = input[AUTO]; byte len[2]; input.read(len, sizeof(len)); ato16(len, ch.suite_len_); input.read(len, sizeof(len)); uint16 sessionLen; ato16(len, sessionLen); ch.id_len_ = sessionLen; input.read(len, sizeof(len)); uint16 randomLen; ato16(len, randomLen); if (ch.suite_len_ > MAX_SUITE_SZ || sessionLen > ID_LEN || randomLen > RAN_LEN) { ssl.SetError(bad_input); return; } int j = 0; for (uint16 i = 0; i < ch.suite_len_; i += 3) { byte first = input[AUTO]; if (first) // sslv2 type input.read(len, SUITE_LEN); // skip else { input.read(&ch.cipher_suites_[j], SUITE_LEN); j += SUITE_LEN; } } ch.suite_len_ = j; if (ch.id_len_) input.read(ch.session_id_, ch.id_len_); if (randomLen < RAN_LEN) memset(ch.random_, 0, RAN_LEN - randomLen); input.read(&ch.random_[RAN_LEN - randomLen], randomLen); ch.Process(input, ssl); }