/* Re-evaluate whether a packet can be delivered from websocket->incoming and
 * arm/disarm the "readable" hook and the read trap accordingly.
 *
 * Each packet starts with a 9-byte header: header[0] is the packet type and
 * header[1..8] is the payload length, parsed with dsk_uint64be_parse().
 *
 *  - Bytes still owed to websocket->to_discard are dropped first.
 *  - If a whole packet is buffered, or a too-long packet must be reported in
 *    RETURN_ERROR mode, the readable hook is armed and the read trap removed.
 *  - Otherwise the readable hook is disarmed and a read trap is ensured.
 *  - A too-long packet in SHUTDOWN mode shuts the websocket down.
 *
 * Does nothing once the websocket is shut down or a shutdown is deferred.
 */
void
update_hooks_with_buffer (DskWebsocket *websocket)
{
  uint8_t header[9];
  uint64_t len;

  if (websocket->is_shutdown || websocket->is_deferred_shutdown)
    return;

restart_processing:
  /* First drop any bytes left over from a packet that is being skipped. */
  if (websocket->to_discard > 0)
    {
      if (websocket->to_discard > websocket->incoming.size)
        {
          /* The skipped packet extends past the buffered data:
             drop everything we have and wait for more. */
          websocket->to_discard -= websocket->incoming.size;
          dsk_buffer_reset (&websocket->incoming);
          goto unset_packet_readable;
        }
      dsk_buffer_discard (&websocket->incoming, websocket->to_discard);
      websocket->to_discard = 0;
      goto restart_processing;
    }

  if (websocket->incoming.size < 9)
    goto unset_packet_readable;                 /* header incomplete */
  dsk_buffer_peek (&websocket->incoming, 9, header);
  len = dsk_uint64be_parse (header + 1);

  if (len > websocket->max_length)
    {
      switch (websocket->too_long_mode)
        {
        case DSK_WEBSOCKET_MODE_DROP:
          /* Skip header + payload.  This must be an assignment: to_discard
             is 0 here, so subtracting would wrap to a huge count and swallow
             all subsequent data.  Saturate so that a hostile length near
             UINT64_MAX cannot wrap len + 9 to a tiny discard count. */
          websocket->to_discard = len > UINT64_MAX - 9 ? UINT64_MAX : len + 9;
          goto restart_processing;
        case DSK_WEBSOCKET_MODE_RETURN_ERROR:
          /* Become readable so dsk_websocket_receive() reports the error. */
          goto set_packet_readable;
        case DSK_WEBSOCKET_MODE_SHUTDOWN:
          goto error_do_shutdown;
        }
      /* Any other mode value falls through to set_packet_readable. */
    }
  else if (websocket->incoming.size - 9 >= len)  /* size >= 9 checked above */
    goto set_packet_readable;
  else
    goto unset_packet_readable;

set_packet_readable:
  dsk_hook_set_idle_notify (&websocket->readable, DSK_TRUE);
  if (websocket->read_trap != NULL)
    {
      dsk_hook_trap_destroy (websocket->read_trap);
      websocket->read_trap = NULL;
    }
  return;

unset_packet_readable:
  dsk_hook_set_idle_notify (&websocket->readable, DSK_FALSE);
  ensure_has_read_trap (websocket);
  return;

error_do_shutdown:
  dsk_hook_set_idle_notify (&websocket->readable, DSK_FALSE);
  if (websocket->read_trap != NULL)
    {
      dsk_hook_trap_destroy (websocket->read_trap);
      websocket->read_trap = NULL;
    }
  dsk_websocket_shutdown (websocket);
}
/* Self-test for DskBuffer.  It exercises:
 *  - append/read,
 *  - bulk count/decount,
 *  - line reading,
 *  - foreign (caller-owned) data,
 *  - placeholders.
 * Every check goes through dsk_assert(); returns 0 when all checks pass.
 *
 * count(), decount(), generate_str() and try_initial_remove() are helpers
 * defined elsewhere in this file (not visible in this chunk).
 */
int main(int argc, char** argv)
{
  DskBuffer gskbuffer;
  char buf[1024];
  char *str;

  dsk_cmdline_init ("test dsk-buffer code", "test DskBuffer", NULL, 0);
  dsk_cmdline_process_args (&argc, &argv);

  /* Basic append followed by a read that drains the buffer. */
  dsk_buffer_init (&gskbuffer);
  dsk_assert (gskbuffer.size == 0);
  dsk_buffer_append (&gskbuffer, 5, "hello");
  dsk_assert (gskbuffer.size == 5);
  dsk_assert (dsk_buffer_read (&gskbuffer, sizeof (buf), buf) == 5);
  dsk_assert (memcmp (buf, "hello", 5) == 0);
  dsk_assert (gskbuffer.size == 0);
  dsk_buffer_clear (&gskbuffer);

  /* Bulk round trip through count()/decount() with bounds 1 and 100000;
     the buffer must be empty afterwards.  NOTE(review): presumably count()
     writes that range and decount() reads/verifies it -- confirm against
     their definitions. */
  dsk_buffer_init (&gskbuffer);
  count (&gskbuffer, 1, 100000);
  decount (&gskbuffer, 1, 100000);
  dsk_assert (gskbuffer.size == 0);
  dsk_buffer_clear (&gskbuffer);

  /* Line reading.  Complete lines come back without their '\n' and are freed
     with dsk_free().  The unterminated trailing "b" stays buffered (size 1)
     and read_line returns NULL until a '\n' is appended. */
  dsk_buffer_init (&gskbuffer);
  dsk_buffer_append_string (&gskbuffer, "hello\na\nb");
  str = dsk_buffer_read_line (&gskbuffer);
  dsk_assert (str);
  dsk_assert (strcmp (str, "hello") == 0);
  dsk_free (str);
  str = dsk_buffer_read_line (&gskbuffer);
  dsk_assert (str);
  dsk_assert (strcmp (str, "a") == 0);
  dsk_free (str);
  dsk_assert (gskbuffer.size == 1);
  dsk_assert (dsk_buffer_read_line (&gskbuffer) == NULL);
  dsk_buffer_append_byte (&gskbuffer, '\n');
  str = dsk_buffer_read_line (&gskbuffer);
  dsk_assert (str);
  dsk_assert (strcmp (str, "b") == 0);
  dsk_free (str);
  dsk_assert (gskbuffer.size == 0);
  dsk_buffer_clear (&gskbuffer);

  /* Foreign data interleaved with copied data reads back in order. */
  dsk_buffer_init (&gskbuffer);
  dsk_buffer_append (&gskbuffer, 5, "hello");
  dsk_buffer_append_foreign (&gskbuffer, 4, "test", NULL, NULL);
  dsk_buffer_append (&gskbuffer, 5, "hello");
  dsk_assert (gskbuffer.size == 14);
  dsk_assert (dsk_buffer_read (&gskbuffer, sizeof (buf), buf) == 14);
  dsk_assert (memcmp (buf, "hellotesthello", 14) == 0);
  dsk_assert (gskbuffer.size == 0);
  /* NOTE(review): gskbuffer is re-initialised below without a
     dsk_buffer_clear(); it is drained (size 0) here -- confirm that
     dsk_buffer_read() releases fully-consumed fragments. */

  /* Test that the foreign data really is not being stored in the DskBuffer:
     modifying test_str after it was appended (peek first, then mutate) must
     be visible in the subsequent read. */
  {
    char test_str[5];
    strcpy (test_str, "test");
    dsk_buffer_init (&gskbuffer);
    dsk_buffer_append (&gskbuffer, 5, "hello");
    dsk_buffer_append_foreign (&gskbuffer, 4, test_str, NULL, NULL);
    dsk_buffer_append (&gskbuffer, 5, "hello");
    dsk_assert (gskbuffer.size == 14);
    dsk_assert (dsk_buffer_peek (&gskbuffer, sizeof (buf), buf) == 14);
    dsk_assert (memcmp (buf, "hellotesthello", 14) == 0);
    test_str[1] = '3';
    dsk_assert (gskbuffer.size == 14);
    dsk_assert (dsk_buffer_read (&gskbuffer, sizeof (buf), buf) == 14);
    dsk_assert (memcmp (buf, "hellot3sthello", 14) == 0);
    dsk_buffer_clear (&gskbuffer);
  }

  /* Test str_index_of.  NOTE: the assertions are compiled out with #if 0,
     so this block currently only builds and clears a three-fragment buffer. */
  {
    DskBuffer buffer = DSK_BUFFER_INIT;
    dsk_buffer_append_foreign (&buffer, 3, "abc", NULL, NULL);
    dsk_buffer_append_foreign (&buffer, 3, "def", NULL, NULL);
    dsk_buffer_append_foreign (&buffer, 3, "gad", NULL, NULL);
#if 0
    dsk_assert (dsk_buffer_str_index_of (&buffer, "cdefg") == 2);
    dsk_assert (dsk_buffer_str_index_of (&buffer, "ad") == 7);
    dsk_assert (dsk_buffer_str_index_of (&buffer, "ab") == 0);
    dsk_assert (dsk_buffer_str_index_of (&buffer, "a") == 0);
    dsk_assert (dsk_buffer_str_index_of (&buffer, "g") == 6);
#endif
    dsk_buffer_clear (&buffer);
  }

  /* Placeholder test.  For every (before, placeholder, after) combination:
     append "before", reserve strlen(placeholder) bytes, append "after", then
     fill the placeholder in.  try_initial_remove() must then succeed for each
     of the three strings in order, leaving the buffer empty.
     The NULL slots are filled by generate_str() with the given bounds
     (presumably min/max length -- defined elsewhere).
     NOTE(review): the generated strings are never freed. */
  static const char *before_strs[] = { "", "foo", NULL, NULL };
  before_strs[2] = generate_str (100, 1000);
  before_strs[3] = generate_str (10000, 100000);
  static const char *placeholder_strs[] = { "", "bar", NULL, NULL, NULL };
  placeholder_strs[2] = generate_str (100, 1000);
  placeholder_strs[3] = generate_str (10000, 100000);
  placeholder_strs[4] = generate_str (100000, 1000000);
  static const char *after_strs[] = { "", "foo", NULL, NULL };
  after_strs[2] = generate_str (100, 1000);
  after_strs[3] = generate_str (10000, 100000);
  unsigned bi, pi, ai;
  for (bi = 0; bi < DSK_N_ELEMENTS (before_strs); bi++)
    for (pi = 0; pi < DSK_N_ELEMENTS (placeholder_strs); pi++)
      for (ai = 0; ai < DSK_N_ELEMENTS (after_strs); ai++)
        {
          DskBuffer buffer = DSK_BUFFER_INIT;
          const char *pi_str = placeholder_strs[pi];
          DskBufferPlaceholder placeholder;
          dsk_buffer_append_string (&buffer, before_strs[bi]);
          dsk_buffer_append_placeholder (&buffer, strlen (pi_str), &placeholder);
          dsk_buffer_append_string (&buffer, after_strs[ai]);
          dsk_buffer_placeholder_set (&placeholder, pi_str);
          dsk_assert (try_initial_remove (&buffer, before_strs[bi]));
          dsk_assert (try_initial_remove (&buffer, pi_str));
          dsk_assert (try_initial_remove (&buffer, after_strs[ai]));
          dsk_assert (buffer.size == 0);
        }
  return 0;
}
DskIOResult dsk_websocket_receive (DskWebsocket *websocket, unsigned *length_out, uint8_t **data_out, DskError **error) { uint8_t header[9]; uint64_t length; restart: maybe_discard_data (websocket); if (websocket->incoming.size < 9) return DSK_IO_RESULT_AGAIN; dsk_buffer_peek (&websocket->incoming, 9, header); length = dsk_uint64be_parse (header + 1); if (length > websocket->max_length) { switch (websocket->too_long_mode) { case DSK_WEBSOCKET_MODE_DROP: websocket->to_discard = length + 9; goto restart; case DSK_WEBSOCKET_MODE_SHUTDOWN: do_deferred_shutdown (websocket); return DSK_IO_RESULT_ERROR; case DSK_WEBSOCKET_MODE_RETURN_ERROR: websocket->to_discard = length + 9; maybe_discard_data (websocket); dsk_set_error (error, "packet too long (%"PRIu64" bytes)", length); return DSK_IO_RESULT_ERROR; } } switch (header[0]) { case 0x00: /* uh oh - shutdown packet */ dsk_buffer_discard (&websocket->incoming, 9 + length); do_deferred_shutdown (websocket); return DSK_IO_RESULT_EOF; case 0xff: if (websocket->incoming.size - 9 < length) return DSK_IO_RESULT_AGAIN; *length_out = length; *data_out = dsk_malloc (length); dsk_buffer_discard (&websocket->incoming, 9); *data_out = dsk_malloc (length); dsk_buffer_read (&websocket->incoming, length, *data_out); update_hooks_with_buffer (websocket); return DSK_IO_RESULT_SUCCESS; default: /* error */ switch (websocket->bad_packet_type_mode) { case DSK_WEBSOCKET_MODE_SHUTDOWN: do_deferred_shutdown (websocket); break; case DSK_WEBSOCKET_MODE_RETURN_ERROR: dsk_set_error (error, "packet had bad type: 0x%02x", header[0]); websocket->to_discard = length + 9; maybe_discard_data (websocket); return DSK_IO_RESULT_ERROR; case DSK_WEBSOCKET_MODE_DROP: websocket->to_discard = length + 9; goto restart; } } dsk_assert_not_reached (); return DSK_IO_RESULT_ERROR; }