Beispiel #1
0
/*
 * Check a cleartext password against a stored SCRAM verifier.
 *
 * Used for plaintext password authentication when the user's entry in
 * pg_authid holds a SCRAM verifier.  The Server Key is recomputed from the
 * supplied password using the salt and iteration count found in the
 * verifier, and compared with the Server Key stored in the verifier.
 *
 * username - only used in the log message emitted for a bad verifier
 * password - cleartext password supplied by the client
 * verifier - stored SCRAM verifier string
 *
 * Returns true if the password matches.  Returns false if it does not, or
 * if the verifier cannot be parsed or its salt cannot be decoded (in which
 * case a LOG message is emitted).
 */
bool
scram_verify_plain_password(const char *username, const char *password,
							const char *verifier)
{
	uint8		verifier_stored_key[SCRAM_KEY_LEN];
	uint8		verifier_server_key[SCRAM_KEY_LEN];
	uint8		salted_password[SCRAM_KEY_LEN];
	uint8		derived_server_key[SCRAM_KEY_LEN];
	const char *effective_password = password;
	char	   *normalized = NULL;
	char	   *salt_b64;
	char	   *raw_salt = NULL;
	int			raw_salt_len = 0;
	int			iteration_count;
	bool		verifier_ok;

	/*
	 * Split the verifier into its fields, then turn the Base64 salt back
	 * into raw bytes.  Either step failing means the stored value is not a
	 * usable SCRAM verifier even though it looked like one.
	 */
	verifier_ok = parse_scram_verifier(verifier, &iteration_count, &salt_b64,
									   verifier_stored_key,
									   verifier_server_key);
	if (verifier_ok)
	{
		raw_salt = palloc(pg_b64_dec_len(strlen(salt_b64)));
		raw_salt_len = pg_b64_decode(salt_b64, strlen(salt_b64), raw_salt);
		verifier_ok = (raw_salt_len != -1);
	}

	if (!verifier_ok)
	{
		ereport(LOG,
				(errmsg("invalid SCRAM verifier for user \"%s\"", username)));
		return false;
	}

	/*
	 * Normalize the password with SASLprep.  If normalization does not
	 * succeed, the password is used exactly as supplied.
	 */
	if (pg_saslprep(password, &normalized) == SASLPREP_SUCCESS)
		effective_password = normalized;

	/* Derive the Server Key from the user-supplied password */
	scram_SaltedPassword(effective_password, raw_salt, raw_salt_len,
						 iteration_count, salted_password);
	scram_ServerKey(salted_password, derived_server_key);

	if (normalized != NULL)
		pfree(normalized);

	/* The password is correct iff both Server Keys are identical */
	if (memcmp(derived_server_key, verifier_server_key, SCRAM_KEY_LEN) != 0)
		return false;

	return true;
}
Beispiel #2
0
/*
 * Read and parse the final message received from client.
 *
 * On return, 'state' holds the nonce sent by the client, the decoded
 * ClientProof, and a copy of the part of the message that precedes the
 * proof attribute (client-final-message-without-proof).  Any syntax or
 * channel-binding violation is reported with ereport(ERROR).
 */
static void
read_client_final_message(scram_state *state, char *input)
{
	char	   *msg_copy;
	char	   *cursor;
	char	   *cbind_value;
	char	   *attr_value;
	char	   *proof_start;
	char	   *decoded_proof;
	char		attr_name;
	int			decoded_proof_len;
	size_t		prefix_len;
	bool		cbind_is_n;
	bool		cbind_is_y;

	/*
	 * Parse a private copy of the message; 'input' itself is left untouched
	 * and is used further down as the source for the saved prefix.
	 */
	msg_copy = pstrdup(input);
	cursor = msg_copy;

	/*------
	 * The syntax for the client-final-message is: (RFC 5802)
	 *
	 * gs2-header	   = gs2-cbind-flag "," [ authzid ] ","
	 *					 ;; GS2 header for SCRAM
	 *					 ;; (the actual GS2 header includes an optional
	 *					 ;; flag to indicate that the GSS mechanism is not
	 *					 ;; "standard", but since SCRAM is "standard", we
	 *					 ;; don't include that flag).
	 *
	 * cbind-input	 = gs2-header [ cbind-data ]
	 *					 ;; cbind-data MUST be present for
	 *					 ;; gs2-cbind-flag of "p" and MUST be absent
	 *					 ;; for "y" or "n".
	 *
	 * channel-binding = "c=" base64
	 *					 ;; base64 encoding of cbind-input.
	 *
	 * proof		   = "p=" base64
	 *
	 * client-final-message-without-proof =
	 *					 channel-binding "," nonce [","
	 *					 extensions]
	 *
	 * client-final-message =
	 *					 client-final-message-without-proof "," proof
	 *------
	 */

	/*
	 * Channel binding is not supported here, so the attribute can only be
	 * "biws" (Base64 of "n,,") or "eSws" (Base64 of "y,,"), and it has to
	 * agree with the flag the client sent originally.
	 */
	cbind_value = read_attr_value(&cursor, 'c');
	cbind_is_n = (strcmp(cbind_value, "biws") == 0 && state->cbind_flag == 'n');
	cbind_is_y = (strcmp(cbind_value, "eSws") == 0 && state->cbind_flag == 'y');
	if (!(cbind_is_n || cbind_is_y))
		ereport(ERROR,
				(errcode(ERRCODE_PROTOCOL_VIOLATION),
				 (errmsg("unexpected SCRAM channel-binding attribute in client-final-message"))));

	state->client_final_nonce = read_attr_value(&cursor, 'r');

	/*
	 * Skip over any optional extensions until the proof attribute is read.
	 * Before each attribute, remember the position one byte ahead of it, so
	 * that when the loop ends proof_start marks where the "without proof"
	 * prefix stops.
	 */
	for (;;)
	{
		proof_start = cursor - 1;
		attr_value = read_any_attr(&cursor, &attr_name);
		if (attr_name == 'p')
			break;
	}

	/* The proof must decode to exactly SCRAM_KEY_LEN bytes */
	decoded_proof = palloc(pg_b64_dec_len(strlen(attr_value)));
	decoded_proof_len = pg_b64_decode(attr_value, strlen(attr_value),
									  decoded_proof);
	if (decoded_proof_len != SCRAM_KEY_LEN)
		ereport(ERROR,
				(errcode(ERRCODE_PROTOCOL_VIOLATION),
				 errmsg("malformed SCRAM message"),
				 errdetail("Malformed proof in client-final-message.")));
	memcpy(state->ClientProof, decoded_proof, SCRAM_KEY_LEN);
	pfree(decoded_proof);

	/* Nothing may follow the proof */
	if (*cursor != '\0')
		ereport(ERROR,
				(errcode(ERRCODE_PROTOCOL_VIOLATION),
				 errmsg("malformed SCRAM message"),
				 errdetail("Garbage found at the end of client-final-message.")));

	/*
	 * Save the prefix.  It is copied from the pristine 'input' rather than
	 * from the working copy that the attribute readers were given.
	 */
	prefix_len = (size_t) (proof_start - msg_copy);
	state->client_final_message_without_proof = palloc(prefix_len + 1);
	memcpy(state->client_final_message_without_proof, input, prefix_len);
	state->client_final_message_without_proof[prefix_len] = '\0';
}
Beispiel #3
0
/*
 * Decode one Base64-encoded SCRAM key field into 'key', which must have room
 * for SCRAM_KEY_LEN bytes.  Returns true only if the field decodes to exactly
 * SCRAM_KEY_LEN bytes; 'key' is left untouched otherwise.
 */
static bool
decode_scram_key(const char *encoded, uint8 *key)
{
	int			encoded_len = strlen(encoded);
	char	   *scratch;
	bool		ok;

	/*
	 * Decode into a scratch buffer sized for the input instead of straight
	 * into the caller's fixed-size array: pg_b64_decode() writes as many
	 * bytes as the input decodes to, so an over-long field would otherwise
	 * overrun 'key'.  pg_b64_dec_len() is used purely to size the buffer;
	 * the actual length is the return value of pg_b64_decode().
	 */
	scratch = palloc(pg_b64_dec_len(encoded_len));
	ok = (pg_b64_decode(encoded, encoded_len, scratch) == SCRAM_KEY_LEN);
	if (ok)
		memcpy(key, scratch, SCRAM_KEY_LEN);
	pfree(scratch);

	return ok;
}

/*
 * Parse and validate format of given SCRAM verifier.
 *
 * verifier   - verifier string to parse (not modified)
 * iterations - receives the iteration count
 * salt       - receives a palloc'd copy of the salt, still Base64-encoded;
 *              set to NULL on failure
 * stored_key - receives the decoded StoredKey (SCRAM_KEY_LEN bytes)
 * server_key - receives the decoded ServerKey (SCRAM_KEY_LEN bytes)
 *
 * Returns true if the SCRAM verifier has been parsed, and false otherwise.
 * On failure the contents of the output parameters other than *salt are
 * unspecified.
 */
static bool
parse_scram_verifier(const char *verifier, int *iterations, char **salt,
					 uint8 *stored_key, uint8 *server_key)
{
	char	   *v;
	char	   *p;
	char	   *scheme_str;
	char	   *salt_str;
	char	   *iterations_str;
	char	   *storedkey_str;
	char	   *serverkey_str;
	int			decoded_len;
	char	   *decoded_salt_buf;

	/*
	 * The verifier is of form:
	 *
	 * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
	 *
	 * strtok() modifies its argument, so split a private copy.
	 */
	v = pstrdup(verifier);
	if ((scheme_str = strtok(v, "$")) == NULL)
		goto invalid_verifier;
	if ((iterations_str = strtok(NULL, ":")) == NULL)
		goto invalid_verifier;
	if ((salt_str = strtok(NULL, "$")) == NULL)
		goto invalid_verifier;
	if ((storedkey_str = strtok(NULL, ":")) == NULL)
		goto invalid_verifier;
	if ((serverkey_str = strtok(NULL, "")) == NULL)
		goto invalid_verifier;

	/* Parse the fields */
	if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
		goto invalid_verifier;

	errno = 0;
	*iterations = strtol(iterations_str, &p, 10);
	if (*p || errno != 0)
		goto invalid_verifier;

	/*
	 * Verify that the salt is in Base64-encoded format, by decoding it,
	 * although we return the encoded version to the caller.  The decoded
	 * bytes themselves are not needed here.
	 */
	decoded_salt_buf = palloc(pg_b64_dec_len(strlen(salt_str)));
	decoded_len = pg_b64_decode(salt_str, strlen(salt_str), decoded_salt_buf);
	pfree(decoded_salt_buf);
	if (decoded_len < 0)
		goto invalid_verifier;

	/*
	 * Decode StoredKey and ServerKey; each must be exactly SCRAM_KEY_LEN
	 * bytes long once decoded.
	 */
	if (!decode_scram_key(storedkey_str, stored_key))
		goto invalid_verifier;
	if (!decode_scram_key(serverkey_str, server_key))
		goto invalid_verifier;

	/*
	 * Everything checked out.  Hand the encoded salt to the caller only now,
	 * so that a late failure does not leave a stray copy behind, and release
	 * the working copy (salt_str points into it, hence the pstrdup first).
	 */
	*salt = pstrdup(salt_str);
	pfree(v);
	return true;

invalid_verifier:
	pfree(v);
	*salt = NULL;
	return false;
}
Beispiel #4
0
/*
 * Read and parse the final message received from client.
 *
 * Verifies the channel-binding attribute (against the server certificate
 * hash when channel binding is in use, otherwise against the fixed "n"/"y"
 * encodings), then stores in 'state' the client's nonce, the decoded
 * ClientProof, and a copy of the part of the message that precedes the
 * proof attribute (client-final-message-without-proof).  Any violation is
 * reported with ereport(ERROR) or elog(ERROR).
 */
static void
read_client_final_message(scram_state *state, char *input)
{
	char	   *msg_copy;
	char	   *cursor;
	char	   *cbind_value;
	char	   *attr_value;
	char	   *proof_start;
	char	   *decoded_proof;
	char		attr_name;
	int			decoded_proof_len;
	size_t		prefix_len;

	/*
	 * Parse a private copy of the message; 'input' itself is left untouched
	 * and is used further down as the source for the saved prefix.
	 */
	msg_copy = pstrdup(input);
	cursor = msg_copy;

	/*------
	 * The syntax for the client-final-message is: (RFC 5802)
	 *
	 * gs2-header	   = gs2-cbind-flag "," [ authzid ] ","
	 *					 ;; GS2 header for SCRAM
	 *					 ;; (the actual GS2 header includes an optional
	 *					 ;; flag to indicate that the GSS mechanism is not
	 *					 ;; "standard", but since SCRAM is "standard", we
	 *					 ;; don't include that flag).
	 *
	 * cbind-input	 = gs2-header [ cbind-data ]
	 *					 ;; cbind-data MUST be present for
	 *					 ;; gs2-cbind-flag of "p" and MUST be absent
	 *					 ;; for "y" or "n".
	 *
	 * channel-binding = "c=" base64
	 *					 ;; base64 encoding of cbind-input.
	 *
	 * proof		   = "p=" base64
	 *
	 * client-final-message-without-proof =
	 *					 channel-binding "," nonce [","
	 *					 extensions]
	 *
	 * client-final-message =
	 *					 client-final-message-without-proof "," proof
	 *------
	 */

	/*
	 * Read channel binding.  This repeats the channel-binding flags and is
	 * then followed by the actual binding data depending on the type.
	 */
	cbind_value = read_attr_value(&cursor, 'c');
	if (!state->channel_binding_in_use)
	{
		/*
		 * Without channel binding the attribute can only be "biws" (Base64
		 * of "n,,") or "eSws" (Base64 of "y,,"), and it has to agree with
		 * the flag the client sent originally.
		 */
		bool		cbind_is_n;
		bool		cbind_is_y;

		cbind_is_n = (strcmp(cbind_value, "biws") == 0 &&
					  state->cbind_flag == 'n');
		cbind_is_y = (strcmp(cbind_value, "eSws") == 0 &&
					  state->cbind_flag == 'y');
		if (!(cbind_is_n || cbind_is_y))
			ereport(ERROR,
					(errcode(ERRCODE_PROTOCOL_VIOLATION),
					 (errmsg("unexpected SCRAM channel-binding attribute in client-final-message"))));
	}
	else
	{
#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
		const char *cert_hash = NULL;
		size_t		cert_hash_len = 0;
		size_t		header_len;
		char	   *expected_input;
		size_t		expected_input_len;
		char	   *expected_b64;
		int			expected_b64_len;

		Assert(state->cbind_flag == 'p');

		/* Fetch hash data of server's SSL certificate */
		cert_hash = be_tls_get_certificate_hash(state->port,
												&cert_hash_len);

		/* should not happen */
		if (cert_hash == NULL || cert_hash_len == 0)
			elog(ERROR, "could not get server certificate hash");

		/*
		 * Build the cbind-input the client is expected to have sent: the
		 * GS2 header ("p=type,,") immediately followed by the raw hash.
		 * snprintf() also writes a terminating NUL right after the header;
		 * that byte is then overwritten by the first byte of the hash.
		 */
		header_len = strlen("p=tls-server-end-point,,");
		expected_input_len = header_len + cert_hash_len;
		expected_input = palloc(expected_input_len);
		snprintf(expected_input, expected_input_len, "p=tls-server-end-point,,");
		memcpy(expected_input + header_len, cert_hash, cert_hash_len);

		/* Base64-encode it; one extra byte is reserved for the terminator */
		expected_b64 = palloc(pg_b64_enc_len(expected_input_len) + 1);
		expected_b64_len = pg_b64_encode(expected_input, expected_input_len,
										 expected_b64);
		expected_b64[expected_b64_len] = '\0';

		/*
		 * Compare the value sent by the client with the value expected by the
		 * server.
		 */
		if (strcmp(cbind_value, expected_b64) != 0)
			ereport(ERROR,
					(errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
					 (errmsg("SCRAM channel binding check failed"))));
#else
		/* shouldn't happen, because we checked this earlier already */
		elog(ERROR, "channel binding not supported by this build");
#endif
	}

	state->client_final_nonce = read_attr_value(&cursor, 'r');

	/*
	 * Skip over any optional extensions until the proof attribute is read.
	 * Before each attribute, remember the position one byte ahead of it, so
	 * that when the loop ends proof_start marks where the "without proof"
	 * prefix stops.
	 */
	for (;;)
	{
		proof_start = cursor - 1;
		attr_value = read_any_attr(&cursor, &attr_name);
		if (attr_name == 'p')
			break;
	}

	/* The proof must decode to exactly SCRAM_KEY_LEN bytes */
	decoded_proof = palloc(pg_b64_dec_len(strlen(attr_value)));
	decoded_proof_len = pg_b64_decode(attr_value, strlen(attr_value),
									  decoded_proof);
	if (decoded_proof_len != SCRAM_KEY_LEN)
		ereport(ERROR,
				(errcode(ERRCODE_PROTOCOL_VIOLATION),
				 errmsg("malformed SCRAM message"),
				 errdetail("Malformed proof in client-final-message.")));
	memcpy(state->ClientProof, decoded_proof, SCRAM_KEY_LEN);
	pfree(decoded_proof);

	/* Nothing may follow the proof */
	if (*cursor != '\0')
		ereport(ERROR,
				(errcode(ERRCODE_PROTOCOL_VIOLATION),
				 errmsg("malformed SCRAM message"),
				 errdetail("Garbage found at the end of client-final-message.")));

	/*
	 * Save the prefix.  It is copied from the pristine 'input' rather than
	 * from the working copy that the attribute readers were given.
	 */
	prefix_len = (size_t) (proof_start - msg_copy);
	state->client_final_message_without_proof = palloc(prefix_len + 1);
	memcpy(state->client_final_message_without_proof, input, prefix_len);
	state->client_final_message_without_proof[prefix_len] = '\0';
}
Beispiel #5
0
int
main(int argc, char *argv[])
{
#define PRINT_USAGE(exit_code)	print_usage(argv[0], exit_code)

	char		conf_file[POOLMAXPATHLEN + 1];
	char		enc_key[MAX_POOL_KEY_LEN + 1];
	char		pg_pass[MAX_PGPASS_LEN + 1];
	char		username[MAX_USER_NAME_LEN + 1];
	char		key_file_path[POOLMAXPATHLEN + sizeof(POOLKEYFILE) + 1];
	int			opt;
	int			optindex;
	bool		updatepasswd = false;
	bool		prompt = false;
	bool		prompt_for_key = false;
	char	   *pool_key = NULL;

	static struct option long_options[] = {
		{"help", no_argument, NULL, 'h'},
		{"prompt", no_argument, NULL, 'p'},
		{"prompt-for-key", no_argument, NULL, 'P'},
		{"update-pass", no_argument, NULL, 'm'},
		{"username", required_argument, NULL, 'u'},
		{"enc-key", required_argument, NULL, 'K'},
		{"key-file", required_argument, NULL, 'k'},
		{"config-file", required_argument, NULL, 'f'},
		{NULL, 0, NULL, 0}
	};

	snprintf(conf_file, sizeof(conf_file), "%s/%s", DEFAULT_CONFIGDIR, POOL_CONF_FILE_NAME);

	/*
	 * initialize username buffer with zeros so that we can use strlen on it
	 * later to check if a username was given on the command line
	 */
	memset(username, 0, sizeof(username));
	memset(enc_key, 0, sizeof(enc_key));
	memset(key_file_path, 0, sizeof(key_file_path));

	while ((opt = getopt_long(argc, argv, "hPpmf:u:k:K:", long_options, &optindex)) != -1)
	{
		switch (opt)
		{
			case 'p':			/* prompt for postgres password */
				prompt = true;
				break;

			case 'P':			/* prompt for encryption key */
				prompt_for_key = true;
				break;

			case 'm':			/* update password file */
				updatepasswd = true;
				break;

			case 'f':			/* specify configuration file */
				if (!optarg)
				{
					PRINT_USAGE(EXIT_SUCCESS);
				}
				strlcpy(conf_file, optarg, sizeof(conf_file));
				break;

			case 'k':			/* specify key file for encrypting
								 * pool_password entries */
				if (!optarg)
				{
					PRINT_USAGE(EXIT_SUCCESS);
				}
				strlcpy(key_file_path, optarg, sizeof(key_file_path));
				break;

			case 'K':			/* specify configuration file */
				if (!optarg)
				{
					PRINT_USAGE(EXIT_SUCCESS);
				}
				strlcpy(enc_key, optarg, sizeof(enc_key));
				break;

			case 'u':
				if (!optarg)
				{
					PRINT_USAGE(EXIT_SUCCESS);
				}
				/* check the input limit early */
				if (strlen(optarg) > sizeof(username))
				{
					fprintf(stderr, "Error: input exceeds maximum username length!\n\n");
					exit(EXIT_FAILURE);
				}
				strlcpy(username, optarg, sizeof(username));
				break;

			default:
				PRINT_USAGE(EXIT_SUCCESS);
				break;
		}
	}

	/* Prompt for password. */
	if (prompt || optind >= argc)
	{
		char		buf[MAX_PGPASS_LEN];
		int			len;

		set_tio_attr(1);
		printf("db password: "******"Couldn't read input from stdin. (fgets(): %s)",
					strerror(eno));

			exit(EXIT_FAILURE);
		}
		printf("\n");
		set_tio_attr(0);

		/* Remove LF at the end of line, if there is any. */
		len = strlen(buf);
		if (len > 0 && buf[len - 1] == '\n')
		{
			buf[len - 1] = '\0';
			len--;
		}
		stpncpy(pg_pass, buf, sizeof(pg_pass));
	}

	/* Read password from argv. */
	else
	{
		int			len;

		len = strlen(argv[optind]);

		if (len > MAX_PGPASS_LEN)
		{
			fprintf(stderr, "Error: Input exceeds maximum password length given:%d max allowed:%d!\n\n", len, MAX_PGPASS_LEN);
			PRINT_USAGE(EXIT_FAILURE);
		}

		stpncpy(pg_pass, argv[optind], sizeof(pg_pass));
	}
	/* prompt for key, overrides all key related arguments */
	if (prompt_for_key)
	{
		char		buf[MAX_POOL_KEY_LEN];
		int			len;

		/* we need to read the encryption key from stdin */
		set_tio_attr(1);
		printf("encryption key: ");
		if (!fgets(buf, sizeof(buf), stdin))
		{
			int			eno = errno;

			fprintf(stderr, "Couldn't read input from stdin. (fgets(): %s)",
					strerror(eno));

			exit(EXIT_FAILURE);
		}
		printf("\n");
		set_tio_attr(0);
		/* Remove LF at the end of line, if there is any. */
		len = strlen(buf);
		if (len > 0 && buf[len - 1] == '\n')
		{
			buf[len - 1] = '\0';
			len--;
		}
		if (len == 0)
		{
			fprintf(stderr, "encryption key not provided\n");
			exit(EXIT_FAILURE);
		}
		stpncpy(enc_key, buf, sizeof(enc_key));
	}
	else
	{
		/* check if we already have not got the key from command line argument */
		if (strlen(enc_key) == 0)
		{
			/* read from file */
			if (strlen(key_file_path) == 0)
			{
				get_pool_key_filename(key_file_path);
			}

			fprintf(stdout, "trying to read key from file %s\n", key_file_path);

			pool_key = read_pool_key(key_file_path);
		}
		else
		{
			pool_key = enc_key;
		}
	}

	if (pool_key == NULL)
	{
		fprintf(stderr, "encryption key not provided\n");
		exit(EXIT_FAILURE);
	}

	if (updatepasswd)
	{
		update_pool_passwd(conf_file, username, pg_pass, pool_key);
	}
	else
	{
		unsigned char ciphertext[MAX_ENCODED_PASSWD_LEN];
		unsigned char b64_enc[MAX_ENCODED_PASSWD_LEN];
		int			len;
		int			cypher_len;

		cypher_len = aes_encrypt_with_password((unsigned char *) pg_pass,
											   strlen(pg_pass), pool_key, ciphertext);

		/* generate the hash for the given username */
		len = pg_b64_encode((const char *) ciphertext, cypher_len, (char *) b64_enc);
		b64_enc[len] = 0;
		fprintf(stdout, "\n%s\n", b64_enc);
		fprintf(stdout, "pool_passwd string: AES%s\n", b64_enc);

#ifdef DEBUG_ENCODING
		unsigned char b64_dec[MAX_ENCODED_PASSWD_LEN];
		unsigned char plaintext[MAX_PGPASS_LEN];

		len = pg_b64_decode(b64_enc, len, b64_dec);
		len = aes_decrypt_with_password(b64_dec, len,
										pool_key, plaintext);
		plaintext[len] = 0;
#endif
	}

	if (pool_key != enc_key)
		free(pool_key);

	return EXIT_SUCCESS;
}