示例#1
0
/*
 * Encrypt padded_message for the session with cipher->remote_address.
 *
 * Steps, in order:
 *   1. Load the session record for cipher->remote_address from cipher->store.
 *   2. Derive message keys from the current sender chain key.
 *   3. Encrypt the bytes with session_cipher_get_ciphertext().
 *   4. Wrap the ciphertext in a signal_message.
 *   5. If the session state still has an unacknowledged pre key message,
 *      wrap that signal_message in a pre_key_signal_message as well.
 *   6. Advance the sender chain key and store the updated record back.
 *
 * The whole operation runs between signal_lock() and signal_unlock() on
 * cipher->global_context.
 *
 * @param cipher              session cipher; asserted non-NULL
 * @param padded_message      bytes to encrypt, passed unchanged to
 *                            session_cipher_get_ciphertext()
 * @param padded_message_len  length of padded_message in bytes
 * @param encrypted_message   on success, receives the pre_key_signal_message
 *                            if one was built, otherwise the signal_message,
 *                            cast to (ciphertext_message *). Not written on
 *                            failure. The reference is not released here,
 *                            so the caller presumably owns it (TODO confirm).
 * @return a negative error code on failure, otherwise the non-negative
 *         result of signal_protocol_session_store_session(). Failures are:
 *           - SG_ERR_INVAL if cipher->inside_callback == 1;
 *           - SG_ERR_UNKNOWN if any required session-state item is missing;
 *           - the callee's code, propagated unchanged, for any other step.
 */
int session_cipher_encrypt(session_cipher *cipher,
        const uint8_t *padded_message, size_t padded_message_len,
        ciphertext_message **encrypted_message)
{
    int result = 0;
    session_record *record = 0;
    session_state *state = 0;
    ratchet_chain_key *chain_key = 0;
    ratchet_chain_key *next_chain_key = 0;
    /* Filled by ratchet_chain_key_get_message_keys(); always wiped at cleanup. */
    ratchet_message_keys message_keys;
    ec_public_key *sender_ephemeral = 0;
    uint32_t previous_counter = 0;
    uint32_t session_version = 0;
    signal_buffer *ciphertext = 0;
    uint32_t chain_key_index = 0;
    ec_public_key *local_identity_key = 0;
    ec_public_key *remote_identity_key = 0;
    signal_message *message = 0;
    pre_key_signal_message *pre_key_message = 0;
    uint8_t *ciphertext_data = 0;
    size_t ciphertext_len = 0;

    assert(cipher);
    signal_lock(cipher->global_context);

    /* Refuse to run while this cipher is flagged as inside a callback
     * (presumably to prevent re-entrant use from a callback). */
    if(cipher->inside_callback == 1) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    result = signal_protocol_session_load_session(cipher->store, &record, cipher->remote_address);
    if(result < 0) {
        goto complete;
    }

    state = session_record_get_state(record);
    if(!state) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    chain_key = session_state_get_sender_chain_key(state);
    if(!chain_key) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    /* Derive the per-message keys (cipher and MAC material) from the current sender chain key. */
    result = ratchet_chain_key_get_message_keys(chain_key, &message_keys);
    if(result < 0) {
        goto complete;
    }

    sender_ephemeral = session_state_get_sender_ratchet_key(state);
    if(!sender_ephemeral) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    previous_counter = session_state_get_previous_counter(state);
    session_version = session_state_get_session_version(state);

    result = session_cipher_get_ciphertext(cipher,
            &ciphertext,
            session_version, &message_keys,
            padded_message, padded_message_len);
    if(result < 0) {
        goto complete;
    }
    /* Borrowed views into the ciphertext buffer; the buffer is freed at cleanup. */
    ciphertext_data = signal_buffer_data(ciphertext);
    ciphertext_len = signal_buffer_len(ciphertext);

    /* Index of the chain key that produced message_keys, written into the message header. */
    chain_key_index = ratchet_chain_key_get_index(chain_key);

    local_identity_key = session_state_get_local_identity_key(state);
    if(!local_identity_key) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    remote_identity_key = session_state_get_remote_identity_key(state);
    if(!remote_identity_key) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    result = signal_message_create(&message,
            session_version,
            message_keys.mac_key, sizeof(message_keys.mac_key),
            sender_ephemeral,
            chain_key_index, previous_counter,
            ciphertext_data, ciphertext_len,
            local_identity_key, remote_identity_key,
            cipher->global_context);
    if(result < 0) {
        goto complete;
    }

    /* Until the peer acknowledges the session, each outgoing message is
     * wrapped in a pre key message that carries the setup parameters. */
    if(session_state_has_unacknowledged_pre_key_message(state) == 1) {
        uint32_t local_registration_id = session_state_get_local_registration_id(state);
        int has_pre_key_id = 0;
        uint32_t pre_key_id = 0;
        uint32_t signed_pre_key_id;
        ec_public_key *base_key;
        
        /* The one-time pre key ID is optional; a NULL pointer is passed below when absent. */
        if(session_state_unacknowledged_pre_key_message_has_pre_key_id(state)) {
            has_pre_key_id = 1;
            pre_key_id = session_state_unacknowledged_pre_key_message_get_pre_key_id(state);
        }
        signed_pre_key_id = session_state_unacknowledged_pre_key_message_get_signed_pre_key_id(state);
        base_key = session_state_unacknowledged_pre_key_message_get_base_key(state);

        if(!base_key) {
            result = SG_ERR_UNKNOWN;
            goto complete;
        }

        result = pre_key_signal_message_create(&pre_key_message,
                session_version, local_registration_id, (has_pre_key_id ? &pre_key_id : 0),
                signed_pre_key_id, base_key, local_identity_key,
                message,
                cipher->global_context);
        if(result < 0) {
            goto complete;
        }
        /* Drop our reference to message and clear the pointer. pre_key_message
         * presumably holds its own reference (TODO confirm). Clearing it keeps
         * the cleanup below from unref'ing it again, and makes the success path
         * return pre_key_message. */
        SIGNAL_UNREF(message);
        message = 0;
    }

    /* Advance the sender chain so the next encryption uses fresh message keys. */
    result = ratchet_chain_key_create_next(chain_key, &next_chain_key);
    if(result < 0) {
        goto complete;
    }

    /* next_chain_key is still unref'd at cleanup, so the state presumably
     * takes its own reference here (TODO confirm). */
    result = session_state_set_sender_chain_key(state, next_chain_key);
    if(result < 0) {
        goto complete;
    }

    /* Persist the advanced state; this call's result becomes the function's return value. */
    result = signal_protocol_session_store_session(cipher->store, cipher->remote_address, record);

complete:
    if(result >= 0) {
        /* Success: hand the outermost message to the caller. */
        if(pre_key_message) {
            *encrypted_message = (ciphertext_message *)pre_key_message;
        }
        else {
            *encrypted_message = (ciphertext_message *)message;
        }
    }
    else {
        /* Failure: release whatever message objects were created. */
        SIGNAL_UNREF(pre_key_message);
        SIGNAL_UNREF(message);
    }
    signal_buffer_free(ciphertext);
    SIGNAL_UNREF(next_chain_key);
    SIGNAL_UNREF(record);
    /* Wipe derived key material on every path, including early failures
     * where message_keys was never filled (writing to it is harmless). */
    signal_explicit_bzero(&message_keys, sizeof(ratchet_message_keys));
    signal_unlock(cipher->global_context);
    return result;
}
END_TEST

/*
 * Round-trip test for pre_key_signal_message serialization.
 *
 * Builds a signal_message and wraps it in a pre_key_signal_message with known
 * field values. It then serializes the message, deserializes the bytes into a
 * second message, and checks each field two ways:
 *   - against the value it was created with;
 *   - between the original and the deserialized message.
 * Checking against the creation values catches a field that was dropped or
 * corrupted when the message was created. A message-vs-message comparison
 * alone would not catch that.
 */
START_TEST(test_serialize_pre_key_signal_message)
{
    int result = 0;

    /* Field values used to build the message and asserted on afterwards. */
    const int message_version = 3;
    const uint32_t registration_id = 42;
    uint32_t pre_key_id = 56;   /* passed by address to the create call */
    const uint32_t signed_pre_key_id = 72;

    static const char ciphertext[] = "WhisperCipherText";
    ec_public_key *sender_ratchet_key = create_test_ec_public_key(global_context);
    ec_public_key *sender_identity_key = create_test_ec_public_key(global_context);
    ec_public_key *receiver_identity_key = create_test_ec_public_key(global_context);
    ec_public_key *base_key = create_test_ec_public_key(global_context);
    ec_public_key *identity_key = create_test_ec_public_key(global_context);
    uint8_t mac_key[RATCHET_MAC_KEY_LENGTH];
    memset(mac_key, 1, sizeof(mac_key));

    signal_message *message = 0;
    pre_key_signal_message *pre_key_message = 0;
    pre_key_signal_message *result_pre_key_message = 0;

    result = signal_message_create(&message, message_version,
                                   mac_key, sizeof(mac_key),
                                   sender_ratchet_key,
                                   2, /* counter */
                                   1, /* previous counter */
                                   (uint8_t *)ciphertext, sizeof(ciphertext) - 1,
                                   sender_identity_key, receiver_identity_key,
                                   global_context);
    ck_assert_int_eq(result, 0);

    result = pre_key_signal_message_create(&pre_key_message,
                                           message_version,
                                           registration_id,
                                           &pre_key_id,
                                           signed_pre_key_id,
                                           base_key, identity_key,
                                           message,
                                           global_context);
    ck_assert_int_eq(result, 0);

    /* Not freed below: the original test never releases it either, which
     * suggests pre_key_message owns it (TODO confirm). */
    signal_buffer *serialized = ciphertext_message_get_serialized((ciphertext_message *)pre_key_message);
    ck_assert_ptr_ne(serialized, 0);

    result = pre_key_signal_message_deserialize(&result_pre_key_message,
             signal_buffer_data(serialized),
             signal_buffer_len(serialized),
             global_context);
    ck_assert_int_eq(result, 0);

    int version1 = pre_key_signal_message_get_message_version(pre_key_message);
    int version2 = pre_key_signal_message_get_message_version(result_pre_key_message);
    ck_assert_int_eq(version1, message_version);
    ck_assert_int_eq(version1, version2);

    ec_public_key *identity_key1 = pre_key_signal_message_get_identity_key(pre_key_message);
    ec_public_key *identity_key2 = pre_key_signal_message_get_identity_key(result_pre_key_message);
    ck_assert_int_eq(ec_public_key_compare(identity_key1, identity_key),  0);
    ck_assert_int_eq(ec_public_key_compare(identity_key1, identity_key2), 0);

    int registration_id1 = pre_key_signal_message_get_registration_id(pre_key_message);
    int registration_id2 = pre_key_signal_message_get_registration_id(result_pre_key_message);
    ck_assert_int_eq(registration_id1, registration_id);
    ck_assert_int_eq(registration_id1, registration_id2);

    /* A pre key ID was supplied at creation, so its presence is asserted
     * unconditionally. A dropped ID must fail the test rather than silently
     * skip the ID comparison. */
    int has_pre_key_id1 = pre_key_signal_message_has_pre_key_id(pre_key_message);
    int has_pre_key_id2 = pre_key_signal_message_has_pre_key_id(result_pre_key_message);
    ck_assert(has_pre_key_id1);
    ck_assert_int_eq(has_pre_key_id1, has_pre_key_id2);

    int pre_key_id1 = pre_key_signal_message_get_pre_key_id(pre_key_message);
    int pre_key_id2 = pre_key_signal_message_get_pre_key_id(result_pre_key_message);
    ck_assert_int_eq(pre_key_id1, pre_key_id);
    ck_assert_int_eq(pre_key_id1, pre_key_id2);

    int signed_pre_key_id1 = pre_key_signal_message_get_signed_pre_key_id(pre_key_message);
    int signed_pre_key_id2 = pre_key_signal_message_get_signed_pre_key_id(result_pre_key_message);
    ck_assert_int_eq(signed_pre_key_id1, signed_pre_key_id);
    ck_assert_int_eq(signed_pre_key_id1, signed_pre_key_id2);

    ec_public_key *base_key1 = pre_key_signal_message_get_base_key(pre_key_message);
    ec_public_key *base_key2 = pre_key_signal_message_get_base_key(result_pre_key_message);
    ck_assert_int_eq(ec_public_key_compare(base_key1, base_key),  0);
    ck_assert_int_eq(ec_public_key_compare(base_key1, base_key2), 0);

    /* The embedded signal_message must also survive the round trip. */
    signal_message *message1 = pre_key_signal_message_get_signal_message(pre_key_message);
    signal_message *message2 = pre_key_signal_message_get_signal_message(result_pre_key_message);
    compare_signal_messages(message1, message2);

    /* Cleanup */
    SIGNAL_UNREF(message);
    SIGNAL_UNREF(result_pre_key_message);
    SIGNAL_UNREF(pre_key_message);
    SIGNAL_UNREF(sender_ratchet_key);
    SIGNAL_UNREF(sender_identity_key);
    SIGNAL_UNREF(receiver_identity_key);
    SIGNAL_UNREF(base_key);
    SIGNAL_UNREF(identity_key);
}