示例#1
0
/**
 * Encrypts a message with an RSA public key.
 *
 * The padding block selected by CP_RSAPD is built first, the message is added
 * into it, and the result is raised to the public exponent modulo n.
 *
 * @param[out] out			- the buffer that receives the ciphertext.
 * @param[in,out] out_len	- on entry, the capacity of out in bytes; on
 * 							success, the number of bytes written, which is the
 * 							byte size of the modulus.
 * @param[in] in			- the message to encrypt.
 * @param[in] in_len		- the message length in bytes; must be positive and
 * 							no larger than the modulus size minus RSA_PAD_LEN.
 * @param[in] pub			- the public key.
 * @return STS_OK on success; STS_ERR on invalid arguments, padding failure,
 * insufficient output capacity or any error caught during the computation.
 */
int cp_rsa_enc(uint8_t *out, int *out_len, uint8_t *in, int in_len, rsa_t pub) {
	bn_t m, eb;
	int size, pad_len, result = STS_OK;

	/*
	 * Validate every pointer before it is used. The key in particular has to
	 * be checked before its modulus is read, otherwise the NULL test is moot.
	 */
	if (pub == NULL || out == NULL || out_len == NULL || in == NULL ||
			in_len <= 0) {
		return STS_ERR;
	}

	size = bn_size_bin(pub->n);

	/* The message must leave room for the padding inside one modulus block. */
	if (in_len > (size - RSA_PAD_LEN)) {
		return STS_ERR;
	}

	bn_null(m);
	bn_null(eb);

	TRY {
		bn_new(m);
		bn_new(eb);

		bn_zero(m);
		bn_zero(eb);

		/* Build the padding block in eb with the configured scheme. */
#if CP_RSAPD == BASIC
		if (pad_basic(eb, &pad_len, in_len, size, RSA_ENC) == STS_OK) {
#elif CP_RSAPD == PKCS1
		if (pad_pkcs1(eb, &pad_len, in_len, size, RSA_ENC) == STS_OK) {
#elif CP_RSAPD == PKCS2
		if (pad_pkcs2(eb, &pad_len, in_len, size, RSA_ENC) == STS_OK) {
#endif
			/* Place the message into the padded block. */
			bn_read_bin(m, in, in_len);
			bn_add(eb, eb, m);

#if CP_RSAPD == PKCS2
			/* PKCS2 needs a second pass once the message is in place. */
			pad_pkcs2(eb, &pad_len, in_len, size, RSA_ENC_FIN);
#endif
			/* c = eb^e mod n. */
			bn_mxp(eb, eb, pub->e, pub->n);

			if (size <= *out_len) {
				*out_len = size;
				/* Clear the buffer first so the output is left-padded with zeros. */
				memset(out, 0, *out_len);
				bn_write_bin(out, size, eb);
			} else {
				result = STS_ERR;
			}
		} else {
			result = STS_ERR;
		}
	}
	CATCH_ANY {
		result = STS_ERR;
	}
	FINALLY {
		bn_free(m);
		bn_free(eb);
	}

	return result;
}

#if CP_RSA == BASIC || !defined(STRIP)

/**
 * Decrypts an RSA ciphertext with a single exponentiation by the private
 * exponent d.
 *
 * @param[out] out			- the buffer that receives the recovered message.
 * @param[in,out] out_len	- on entry, the capacity of out in bytes; on
 * 							success, the number of message bytes written.
 * @param[in] in			- the ciphertext.
 * @param[in] in_len		- the ciphertext length in bytes; must equal the
 * 							byte size of the modulus and be at least
 * 							RSA_PAD_LEN.
 * @param[in] prv			- the private key.
 * @return STS_OK on success; STS_ERR on invalid arguments, padding failure,
 * insufficient output capacity or any error caught during the computation.
 */
int cp_rsa_dec_basic(uint8_t *out, int *out_len, uint8_t *in, int in_len, rsa_t prv) {
	bn_t eb;
	int size, pad_len, result = STS_OK;

	/*
	 * Validate every pointer before it is used. The key in particular has to
	 * be checked before its modulus is read, otherwise the NULL test is moot.
	 */
	if (prv == NULL || out == NULL || out_len == NULL || in == NULL) {
		return STS_ERR;
	}

	size = bn_size_bin(prv->n);

	if (in_len != size || in_len < RSA_PAD_LEN) {
		return STS_ERR;
	}

	bn_null(eb);

	TRY {
		bn_new(eb);

		/* eb = c^d mod n. */
		bn_read_bin(eb, in, in_len);
		bn_mxp(eb, eb, prv->d, prv->n);

		/* The recovered block must be smaller than the modulus. */
		if (bn_cmp(eb, prv->n) != CMP_LT) {
			result = STS_ERR;
		}
		/* Check and strip the padding with the configured scheme. */
#if CP_RSAPD == BASIC
		if (pad_basic(eb, &pad_len, in_len, size, RSA_DEC) == STS_OK) {
#elif CP_RSAPD == PKCS1
		if (pad_pkcs1(eb, &pad_len, in_len, size, RSA_DEC) == STS_OK) {
#elif CP_RSAPD == PKCS2
		if (pad_pkcs2(eb, &pad_len, in_len, size, RSA_DEC) == STS_OK) {
#endif
			/* What remains after the padding is the message itself. */
			size = size - pad_len;

			if (size <= *out_len) {
				memset(out, 0, size);
				bn_write_bin(out, size, eb);
				*out_len = size;
			} else {
				result = STS_ERR;
			}
		} else {
			result = STS_ERR;
		}
	}
	CATCH_ANY {
		result = STS_ERR;
	}
	FINALLY {
		bn_free(eb);
	}

	return result;
}

#endif

#if CP_RSA == QUICK || !defined(STRIP)

/**
 * Decrypts an RSA ciphertext using the CRT components of the private key
 * (p, q, dp, dq, qi) instead of one exponentiation by d.
 *
 * @param[out] out			- the buffer that receives the recovered message.
 * @param[in,out] out_len	- on entry, the capacity of out in bytes; on
 * 							success, the number of message bytes written.
 * @param[in] in			- the ciphertext.
 * @param[in] in_len		- the ciphertext length in bytes; must equal the
 * 							byte size of the modulus and be at least
 * 							RSA_PAD_LEN.
 * @param[in] prv			- the private key.
 * @return STS_OK on success; STS_ERR on invalid arguments, padding failure,
 * insufficient output capacity or any error caught during the computation.
 */
int cp_rsa_dec_quick(uint8_t *out, int *out_len, uint8_t *in, int in_len, rsa_t prv) {
	bn_t m, eb;
	int size, pad_len, result = STS_OK;

	/*
	 * Validate every pointer before it is used. The key in particular has to
	 * be checked before its modulus is read, otherwise the NULL test is moot.
	 */
	if (prv == NULL || out == NULL || out_len == NULL || in == NULL) {
		return STS_ERR;
	}

	size = bn_size_bin(prv->n);

	if (in_len != size || in_len < RSA_PAD_LEN) {
		return STS_ERR;
	}

	bn_null(m);
	bn_null(eb);

	TRY {
		bn_new(m);
		bn_new(eb);

		bn_read_bin(eb, in, in_len);

		/* Both half-size exponentiations start from the same ciphertext. */
		bn_copy(m, eb);

		/* m1 = c^dP mod p. */
		bn_mxp(eb, eb, prv->dp, prv->p);

		/* m2 = c^dQ mod q. */
		bn_mxp(m, m, prv->dq, prv->q);

		/* m1 = m1 - m2 mod p. */
		bn_sub(eb, eb, m);
		/* Lift a negative difference back into the non-negative range. */
		while (bn_sign(eb) == BN_NEG) {
			bn_add(eb, eb, prv->p);
		}
		bn_mod(eb, eb, prv->p);
		/* m1 = qInv(m1 - m2) mod p. */
		bn_mul(eb, eb, prv->qi);
		bn_mod(eb, eb, prv->p);
		/* m = m2 + m1 * q. */
		bn_mul(eb, eb, prv->q);
		bn_add(eb, eb, m);

		/* The recombined block must be smaller than the modulus. */
		if (bn_cmp(eb, prv->n) != CMP_LT) {
			result = STS_ERR;
		}
		/* Check and strip the padding with the configured scheme. */
#if CP_RSAPD == BASIC
		if (pad_basic(eb, &pad_len, in_len, size, RSA_DEC) == STS_OK) {
#elif CP_RSAPD == PKCS1
		if (pad_pkcs1(eb, &pad_len, in_len, size, RSA_DEC) == STS_OK) {
#elif CP_RSAPD == PKCS2
		if (pad_pkcs2(eb, &pad_len, in_len, size, RSA_DEC) == STS_OK) {
#endif
			/* What remains after the padding is the message itself. */
			size = size - pad_len;

			if (size <= *out_len) {
				memset(out, 0, size);
				bn_write_bin(out, size, eb);
				*out_len = size;
			} else {
				result = STS_ERR;
			}
		} else {
			result = STS_ERR;
		}
	}
	CATCH_ANY {
		result = STS_ERR;
	}
	FINALLY {
		bn_free(m);
		bn_free(eb);
	}

	return result;
}

#endif

#if CP_RSA == BASIC || !defined(STRIP)

/**
 * Produces an RSA signature with a single exponentiation by the private
 * exponent d.
 *
 * @param[out] sig			- the buffer that receives the signature.
 * @param[in,out] sig_len	- on entry, the capacity of sig in bytes; on
 * 							success, the byte size of the modulus.
 * @param[in] msg			- the data to sign.
 * @param[in] msg_len		- the length of msg in bytes; must not be negative.
 * @param[in] hash			- zero to digest msg with md_map() before signing;
 * 							non-zero to sign the bytes of msg directly.
 * @param[in] prv			- the private key.
 * @return STS_OK on success; STS_ERR on invalid arguments, padding failure or
 * insufficient output capacity.
 */
int cp_rsa_sig_basic(uint8_t *sig, int *sig_len, uint8_t *msg, int msg_len, int hash, rsa_t prv) {
	bn_t payload, block;
	int mod_len, data_len, mode;
	int status = STS_OK;
	uint8_t digest[MD_LEN];

	if (prv == NULL) {
		return STS_ERR;
	}
	if (msg_len < 0) {
		return STS_ERR;
	}

	/* The padded payload is either the caller's bytes or a fresh digest. */
	data_len = (hash ? msg_len : MD_LEN);
	mode = (hash ? RSA_SIG_HASH : RSA_SIG);

#if CP_RSAPD == PKCS2
	/* Byte length of (bits of n) - 1, rounded up. */
	mod_len = bn_bits(prv->n) - 1;
	mod_len = (mod_len / 8) + (mod_len % 8 > 0);
	if ((mod_len - 2) < data_len) {
		return STS_ERR;
	}
#else
	mod_len = bn_size_bin(prv->n);
	if ((mod_len - RSA_PAD_LEN) < data_len) {
		return STS_ERR;
	}
#endif

	bn_null(payload);
	bn_null(block);

	TRY {
		bn_new(payload);
		bn_new(block);

		bn_zero(payload);
		bn_zero(block);

		/* Build the padding block with the configured scheme. */
#if CP_RSAPD == BASIC
		if (pad_basic(block, &data_len, data_len, mod_len, mode) != STS_OK) {
#elif CP_RSAPD == PKCS1
		if (pad_pkcs1(block, &data_len, data_len, mod_len, mode) != STS_OK) {
#elif CP_RSAPD == PKCS2
		if (pad_pkcs2(block, &data_len, data_len, mod_len, mode) != STS_OK) {
#endif
			status = STS_ERR;
		} else {
			if (hash) {
				bn_read_bin(payload, msg, msg_len);
			} else {
				md_map(digest, msg, msg_len);
				bn_read_bin(payload, digest, MD_LEN);
			}
			/* Place the payload into the padded block. */
			bn_add(block, block, payload);

#if CP_RSAPD == PKCS2
			/* PKCS2 needs a second pass once the payload is in place. */
			pad_pkcs2(block, &data_len, bn_bits(prv->n), mod_len, RSA_SIG_FIN);
#endif

			/* s = block^d mod n. */
			bn_mxp(block, block, prv->d, prv->n);

			/* The signature always occupies a full modulus-sized block. */
			mod_len = bn_size_bin(prv->n);

			if (mod_len > *sig_len) {
				status = STS_ERR;
			} else {
				memset(sig, 0, mod_len);
				bn_write_bin(sig, mod_len, block);
				*sig_len = mod_len;
			}
		}
	}
	CATCH_ANY {
		/* Errors are re-raised to the caller rather than mapped to STS_ERR. */
		THROW(ERR_CAUGHT);
	}
	FINALLY {
		bn_free(payload);
		bn_free(block);
	}

	return status;
}

#endif

#if CP_RSA == QUICK || !defined(STRIP)

/**
 * Produces an RSA signature using the CRT components of the private key
 * (p, q, dp, dq, qi) instead of one exponentiation by d.
 *
 * @param[out] sig			- the buffer that receives the signature.
 * @param[in,out] sig_len	- on entry, the capacity of sig in bytes; on
 * 							success, the byte size of the modulus.
 * @param[in] msg			- the data to sign.
 * @param[in] msg_len		- the length of msg in bytes; must not be negative.
 * @param[in] hash			- zero to digest msg with md_map() before signing;
 * 							non-zero to sign the bytes of msg directly.
 * @param[in] prv			- the private key.
 * @return STS_OK on success; STS_ERR on invalid arguments, padding failure or
 * insufficient output capacity.
 */
int cp_rsa_sig_quick(uint8_t *sig, int *sig_len, uint8_t *msg, int msg_len, int hash, rsa_t prv) {
	bn_t half, block;
	int data_len, mod_len, mode;
	int status = STS_OK;
	uint8_t digest[MD_LEN];

	if (prv == NULL) {
		return STS_ERR;
	}
	if (msg_len < 0) {
		return STS_ERR;
	}

	/* The padded payload is either the caller's bytes or a fresh digest. */
	data_len = (hash ? msg_len : MD_LEN);
	mode = (hash ? RSA_SIG_HASH : RSA_SIG);

#if CP_RSAPD == PKCS2
	/* Byte length of (bits of n) - 1, rounded up. */
	mod_len = bn_bits(prv->n) - 1;
	mod_len = (mod_len / 8) + (mod_len % 8 > 0);
	if ((mod_len - 2) < data_len) {
		return STS_ERR;
	}
#else
	mod_len = bn_size_bin(prv->n);
	if ((mod_len - RSA_PAD_LEN) < data_len) {
		return STS_ERR;
	}
#endif

	bn_null(half);
	bn_null(block);

	TRY {
		bn_new(half);
		bn_new(block);

		bn_zero(half);
		bn_zero(block);

		/* Build the padding block with the configured scheme. */
#if CP_RSAPD == BASIC
		if (pad_basic(block, &data_len, data_len, mod_len, mode) != STS_OK) {
#elif CP_RSAPD == PKCS1
		if (pad_pkcs1(block, &data_len, data_len, mod_len, mode) != STS_OK) {
#elif CP_RSAPD == PKCS2
		if (pad_pkcs2(block, &data_len, data_len, mod_len, mode) != STS_OK) {
#endif
			status = STS_ERR;
		} else {
			if (hash) {
				bn_read_bin(half, msg, msg_len);
			} else {
				md_map(digest, msg, msg_len);
				bn_read_bin(half, digest, MD_LEN);
			}
			/* Place the payload into the padded block. */
			bn_add(block, block, half);

#if CP_RSAPD == PKCS2
			/* PKCS2 needs a second pass once the payload is in place. */
			pad_pkcs2(block, &data_len, bn_bits(prv->n), mod_len, RSA_SIG_FIN);
#endif

			/* Both half-size exponentiations start from the same block. */
			bn_copy(half, block);

			/* s1 = block^dP mod p. */
			bn_mxp(block, block, prv->dp, prv->p);

			/* s2 = block^dQ mod q. */
			bn_mxp(half, half, prv->dq, prv->q);

			/* s1 = s1 - s2 mod p, lifted out of the negative range first. */
			bn_sub(block, block, half);
			while (bn_sign(block) == BN_NEG) {
				bn_add(block, block, prv->p);
			}
			bn_mod(block, block, prv->p);

			/* s1 = qInv * (s1 - s2) mod p. */
			bn_mul(block, block, prv->qi);
			bn_mod(block, block, prv->p);

			/* s = s2 + s1 * q, reduced modulo n. */
			bn_mul(block, block, prv->q);
			bn_add(block, block, half);
			bn_mod(block, block, prv->n);

			/* The signature always occupies a full modulus-sized block. */
			mod_len = bn_size_bin(prv->n);

			if (mod_len > *sig_len) {
				status = STS_ERR;
			} else {
				memset(sig, 0, mod_len);
				bn_write_bin(sig, mod_len, block);
				*sig_len = mod_len;
			}
		}
	}
	CATCH_ANY {
		/* Errors are re-raised to the caller rather than mapped to STS_ERR. */
		THROW(ERR_CAUGHT);
	}
	FINALLY {
		bn_free(half);
		bn_free(block);
	}

	return status;
}

#endif

/**
 * Verifies an RSA signature against a message with a public key.
 *
 * The signature is raised to the public exponent modulo n, the padding is
 * checked with the scheme selected by CP_RSAPD, and the recovered value is
 * compared with the expected one.
 *
 * @param[in] sig			- the signature to verify.
 * @param[in] sig_len		- the signature length in bytes.
 * @param[in] msg			- the data the signature is checked against.
 * @param[in] msg_len		- the length of msg in bytes; must not be negative.
 * @param[in] hash			- zero to digest msg with md_map() before comparing;
 * 							non-zero to use the bytes of msg directly.
 * @param[in] pub			- the public key.
 * @return 1 if the signature is valid, 0 otherwise (including on invalid
 * arguments, padding failure or any error caught during the computation).
 */
int cp_rsa_ver(uint8_t *sig, int sig_len, uint8_t *msg, int msg_len, int hash, rsa_t pub) {
	/* m is allocated and released below but not otherwise used here. */
	bn_t m, eb;
	int size, pad_len, result;
	/*
	 * Scratch buffers sized from the larger of msg_len and MD_LEN. h1 has 8
	 * extra bytes, which the PKCS2 path zero-fills as a prefix before hashing.
	 */
	uint8_t h1[MAX(msg_len, MD_LEN) + 8], h2[MAX(msg_len, MD_LEN)];

	/* We suppose that the signature is invalid. */
	result = 0;

	if (pub == NULL || msg_len < 0) {
		return 0;
	}

	/* Length of the value being verified: a digest, or the caller's bytes. */
	pad_len = (!hash ? MD_LEN : msg_len);

#if CP_RSAPD == PKCS2
	/*
	 * NOTE(review): this size rule differs from the one the signing functions
	 * use for PKCS2 when (bits - 1) is a multiple of 8 -- confirm against
	 * pad_pkcs2() that the difference is intended.
	 */
	size = bn_bits(pub->n) - 1;
	if (size % 8 == 0) {
		size = size / 8 - 1;
	} else {
		size = bn_size_bin(pub->n);
	}
	if (pad_len > (size - 2)) {
		return 0;
	}
#else
	size = bn_size_bin(pub->n);
	if (pad_len > (size - RSA_PAD_LEN)) {
		return 0;
	}
#endif

	bn_null(m);
	bn_null(eb);

	TRY {
		bn_new(m);
		bn_new(eb);

		bn_read_bin(eb, sig, sig_len);

		/* eb = sig^e mod n. */
		bn_mxp(eb, eb, pub->e, pub->n);

		int operation = (!hash ? RSA_VER : RSA_VER_HASH);

		/* Check the padding with the configured scheme; pad_len is updated. */
#if CP_RSAPD == BASIC
		if (pad_basic(eb, &pad_len, MD_LEN, size, operation) == STS_OK) {
#elif CP_RSAPD == PKCS1
		if (pad_pkcs1(eb, &pad_len, MD_LEN, size, operation) == STS_OK) {
#elif CP_RSAPD == PKCS2
		if (pad_pkcs2(eb, &pad_len, bn_bits(pub->n), size, operation) == STS_OK) {
#endif

#if CP_RSAPD == PKCS2
			/* Eight zero bytes are prepended to the data that gets hashed. */
			memset(h1, 0, 8);

			if (!hash) {
				/* h2 = H(00..00 || H(msg)). */
				md_map(h1 + 8, msg, msg_len);
				md_map(h2, h1, MD_LEN + 8);

				/* h1 is reused to hold the value recovered from the signature. */
				memset(h1, 0, MD_LEN);
				/* NOTE(review): assumes size - pad_len fits in h1 -- confirm pad_pkcs2() guarantees it. */
				bn_write_bin(h1, size - pad_len, eb);
				/* Everything went ok, so signature status is changed. */
				result = util_cmp_const(h1, h2, MD_LEN);
			} else {
				memcpy(h1 + 8, msg, msg_len);
				/* NOTE(review): hashes MD_LEN + 8 bytes although msg_len bytes were copied; looks like msg_len == MD_LEN is assumed -- confirm. */
				md_map(h2, h1, MD_LEN + 8);

				/* h1 is reused to hold the value recovered from the signature. */
				memset(h1, 0, msg_len);
				/* NOTE(review): assumes size - pad_len fits in h1 -- confirm pad_pkcs2() guarantees it. */
				bn_write_bin(h1, size - pad_len, eb);

				/* Everything went ok, so signature status is changed. */
				result = util_cmp_const(h1, h2, msg_len);
			}
#else
			/* h1 receives the value recovered from the signature. */
			memset(h1, 0, MAX(msg_len, MD_LEN));
			/* NOTE(review): assumes size - pad_len fits in h1 -- confirm the padding routine guarantees it. */
			bn_write_bin(h1, size - pad_len, eb);

			if (!hash) {
				md_map(h2, msg, msg_len);
				/* Everything went ok, so signature status is changed. */
				result = util_cmp_const(h1, h2, MD_LEN);
			} else {
				/* Everything went ok, so signature status is changed. */
				result = util_cmp_const(h1, msg, msg_len);
			}
#endif
			/* Map the comparison outcome onto the 1/0 return convention. */
			result = (result == CMP_EQ ? 1 : 0);
		} else {
			result = 0;
		}
	}
	CATCH_ANY {
		result = 0;
	}
	FINALLY {
		bn_free(m);
		bn_free(eb);
	}

	return result;
}
示例#2
0
文件: tests-relic.c 项目: Cr0s/RIOT
/*
 * Walks through an elliptic-curve Diffie-Hellman key exchange between two
 * parties and asserts that both end up with the same shared key.
 */
static void tests_relic_ecdh(void)
{
    /* Nothing to exercise unless a curve configuration could be selected. */
    if (ec_param_set_any() != STS_OK) {
        return;
    }

#if (TEST_RELIC_SHOW_OUTPUT == 1)
    ec_param_print();
#endif

    /* Party A: private scalar, public point and derived key. */
    bn_t priv_a;
    ec_t pub_a;
    uint8_t key_a[MD_LEN];

    /* Party B: private scalar, public point and derived key. */
    bn_t priv_b;
    ec_t pub_b;
    uint8_t key_b[MD_LEN];

    bn_null(priv_a);
    ec_null(pub_a);

    bn_new(priv_a);
    ec_new(pub_a);

    bn_null(priv_b);
    ec_null(pub_b);

    bn_new(priv_b);
    ec_new(pub_b);

    /* Party A creates its key pair. */
    TEST_ASSERT_EQUAL_INT(STS_OK, cp_ecdh_gen(priv_a, pub_a));

#if (TEST_RELIC_SHOW_OUTPUT == 1)
    printf("User A\n");
    printf("======\n");
    printf("private key: ");
    bn_print(priv_a);
    printf("\npublic key:  ");
    ec_print(pub_a);
    printf("\n");
#endif

    /* Party B creates its key pair. */
    TEST_ASSERT_EQUAL_INT(STS_OK, cp_ecdh_gen(priv_b, pub_b));

#if (TEST_RELIC_SHOW_OUTPUT == 1)
    printf("User B\n");
    printf("======\n");
    printf("private key: ");
    bn_print(priv_b);
    printf("\npublic key:  ");
    ec_print(pub_b);
    printf("\n");
#endif

    /* A real protocol would exchange the public keys at this point. */

    /* A derives the shared key from its own secret and B's public key. */
    TEST_ASSERT_EQUAL_INT(STS_OK, cp_ecdh_key(key_a, MD_LEN, priv_a, pub_b));

#if (TEST_RELIC_SHOW_OUTPUT == 1)
    printf("\nshared key computed by user A: ");
    print_mem(key_a, MD_LEN);
#endif

    /* B derives the shared key from its own secret and A's public key. */
    TEST_ASSERT_EQUAL_INT(STS_OK, cp_ecdh_key(key_b, MD_LEN, priv_b, pub_a));

#if (TEST_RELIC_SHOW_OUTPUT == 1)
    printf("\nshared key computed by user B: ");
    print_mem(key_b, MD_LEN);
#endif

    /* Both sides must have arrived at the same key. */
    TEST_ASSERT_EQUAL_INT(CMP_EQ, util_cmp_const(key_a, key_b, MD_LEN));

    bn_free(priv_a);
    ec_free(pub_a);

    bn_free(priv_b);
    ec_free(pub_b);
#if (TEST_RELIC_SHOW_OUTPUT == 1)
    printf("\nRELIC EC-DH test successful\n");
#endif
}