/**
 * Lua binding: serialize an auth-challenge description into its wire form.
 *
 * Lua argument 1 must be a table (enforced by luaL_checktype). The
 * LUA_IMPORT_* macros are handed the freshly created challenge object and
 * a field name each; presumably they copy the matching table entries into
 * the struct -- confirm against the macro definitions. The table key
 * "challenge" is mapped onto the auth_plugin_data member.
 *
 * @return 1 -- a single Lua string holding the encoded packet is pushed
 */
static int lua_proto_append_challenge_packet (lua_State *L) {
	network_mysqld_auth_challenge *auth_challenge;
	GString *encoded;

	luaL_checktype(L, 1, LUA_TTABLE);

	/* the struct is created only after the argument check has passed */
	auth_challenge = network_mysqld_auth_challenge_new();

	/* integer fields */
	LUA_IMPORT_INT(auth_challenge, protocol_version);
	LUA_IMPORT_INT(auth_challenge, server_version);
	LUA_IMPORT_INT(auth_challenge, thread_id);
	LUA_IMPORT_INT(auth_challenge, capabilities);
	LUA_IMPORT_INT(auth_challenge, charset);
	LUA_IMPORT_INT(auth_challenge, server_status);

	/* string field, stored under a different name on the Lua side */
	LUA_IMPORT_STR_FROM(auth_challenge, auth_plugin_data, "challenge");

	/* encode into a scratch buffer; the struct is not needed afterwards */
	encoded = g_string_new(NULL);
	network_mysqld_proto_append_auth_challenge(encoded, auth_challenge);
	network_mysqld_auth_challenge_free(auth_challenge);

	/* the bytes are handed to lua_pushlstring before the scratch buffer is freed */
	lua_pushlstring(L, S(encoded));
	g_string_free(encoded, TRUE);

	return 1;
}
/**
 * Round-trip test for the server handshake (auth challenge) packet.
 *
 * A captured 78-byte handshake is decoded with
 * network_mysqld_proto_get_auth_challenge(), the decoded fields are
 * checked, and the structure is then re-encoded with
 * network_mysqld_proto_append_auth_challenge() and compared byte for byte
 * with the captured packet.
 */
void test_mysqld_handshake(void) {
	/* captured packet; the per-line notes below are derived from the assertions further down */
	const char raw_packet[] = "J\0\0\0" /* network header; 'J' is 74, the number of bytes that follow */
		"\n" /* value 10 -- presumably the protocol version byte */
		"5.0.45-Debian_1ubuntu3.3-log\0" /* version string; decoded server_version is 50045 */
		"w\0\0\0" /* 'w' is 119, the expected thread_id */
		"\"L;!3|8@" /* first 8 bytes of the expected challenge */
		"\0"
		",\242" /* 0x2c 0xa2 */
		"\10" /* octal 10 == 8, the expected charset */
		"\2\0" /* presumably server_status (expected: SERVER_STATUS_AUTOCOMMIT) */
		"\0\0\0\0\0\0\0\0\0\0\0\0\0"
		"vV,s#PLjSA+Q" /* remaining 12 bytes of the expected challenge */
		"\0";
	network_mysqld_auth_challenge *shake;
	network_packet packet;

	shake = network_mysqld_auth_challenge_new();
	
	/* load the raw bytes into a packet positioned at offset 0 */
	packet.data = g_string_new(NULL);
	packet.offset = 0;
	g_string_append_len(packet.data, C(raw_packet));

	g_assert_cmpint(packet.data->len, ==, 78);

	/* decode: skip the network header, then parse the challenge; both must return 0 */
	g_assert_cmpint(0, ==, network_mysqld_proto_skip_network_header(&packet));
	g_assert_cmpint(0, ==, network_mysqld_proto_get_auth_challenge(&packet, shake));

	g_assert(shake->server_version == 50045);
	g_assert(shake->thread_id == 119);
	g_assert(shake->server_status == 
			SERVER_STATUS_AUTOCOMMIT);
	g_assert(shake->charset == 8);
	g_assert(shake->capabilities ==
			(CLIENT_CONNECT_WITH_DB |
			CLIENT_LONG_FLAG |

			CLIENT_COMPRESS |

			CLIENT_PROTOCOL_41 |

			CLIENT_TRANSACTIONS |
			CLIENT_SECURE_CONNECTION));

	/* the two challenge fragments of the packet are expected back as one 20-byte string */
	g_assert(shake->challenge->len == 20);
	g_assert(0 == memcmp(shake->challenge->str, "\"L;!3|8@vV,s#PLjSA+Q", shake->challenge->len));

	/* ... and back */
	/* the header is re-added by hand before the challenge is appended to the buffer */
	g_string_truncate(packet.data, 0);
	g_string_append_len(packet.data, C("J\0\0\0"));
	network_mysqld_proto_append_auth_challenge(packet.data, shake);

	/* sizeof() counts the literal's implicit trailing NUL, hence the - 1 */
	g_assert_cmpint(packet.data->len, ==, sizeof(raw_packet) - 1);

	g_assert(0 == memcmp(packet.data->str, raw_packet, packet.data->len));

	network_mysqld_auth_challenge_free(shake);
	g_string_free(packet.data, TRUE);
}
예제 #3
0
static PyObject *
python_proto_append_challenge_packet (PyObject *self, PyObject *args) {
	GString *packet;
	network_mysqld_auth_challenge *auth_challenge;

	PyObject *dict;
	if(!PyArg_ParseTuple(args, "O!", &PyDict_Type, &dict))
		return NULL;

	auth_challenge = network_mysqld_auth_challenge_new();

#define PYTHON_IMPORT_INT(x, y) \
	PyObject *v_ ## y = PyDict_GetItemString(dict, #y);\
	if(v_ ## y){\
		if(PyInt_Check(v_ ## y))\
			x->y = PyInt_AsLong(v_ ## y);\
		else{\
			PyErr_SetString(PyExc_ValueError, #x "." #y "must be an int");\
			Py_DECREF(v_ ## y);\
			return NULL;\
		}\
	}

#define PYTHON_IMPORT_STR(x, y) \
	PyObject *v_ ## y = PyDict_GetItemString(dict, #y);\
	if(v_ ## y){\
		if(PyString_Check(v_ ## y)){\
			int len;\
			char *str;\
			PyString_AsStringAndSize(v_ ## y, &str, &len);\
			g_string_assign_len(x->y, str, len);\
		} else{\
			PyErr_SetString(PyExc_ValueError, #x "." #y "must be an string");\
			Py_DECREF(v_ ## y);\
			return NULL;\
		}\
	}

	PYTHON_IMPORT_INT(auth_challenge, protocol_version);
	PYTHON_IMPORT_INT(auth_challenge, server_version);
	PYTHON_IMPORT_INT(auth_challenge, thread_id);
	PYTHON_IMPORT_INT(auth_challenge, capabilities);
	PYTHON_IMPORT_INT(auth_challenge, charset);
	PYTHON_IMPORT_INT(auth_challenge, server_status);

	PYTHON_IMPORT_STR(auth_challenge, challenge);

	packet = g_string_new(NULL);
	network_mysqld_proto_append_auth_challenge(packet, auth_challenge);

	network_mysqld_auth_challenge_free(auth_challenge);
	PyObject *res = PyString_FromStringAndSize(S(packet));
	if(!res)
		return NULL;
	g_string_free(packet, TRUE);
	return res;
}
/**
 * Lua binding: decode a raw auth-challenge packet into a Lua table.
 *
 * Lua argument 1 must be a string holding the packet bytes. On a decode
 * failure a Lua error is raised. On success a new table is pushed and the
 * LUA_EXPORT_* macros are invoked for each field (presumably they store the
 * struct members into that table -- confirm against the macro definitions);
 * auth_plugin_data is exported under the key "challenge".
 *
 * @return 1 -- the table is the single result
 */
static int lua_proto_get_challenge_packet (lua_State *L) {
	size_t raw_len;
	const char *raw = luaL_checklstring(L, 1, &raw_len);
	network_mysqld_auth_challenge *auth_challenge;
	network_packet packet;
	GString view;

	/* stack GString that merely points at the Lua-owned bytes; nothing is copied */
	view.len = raw_len;
	view.str = (char *)raw;

	packet.offset = 0;
	packet.data = &view;

	auth_challenge = network_mysqld_auth_challenge_new();

	if (0 != network_mysqld_proto_get_auth_challenge(&packet, auth_challenge)) {
		/* release the struct before raising the Lua error */
		network_mysqld_auth_challenge_free(auth_challenge);

		luaL_error(L, "%s: network_mysqld_proto_get_auth_challenge() failed", G_STRLOC);
		return 0;
	}

	lua_newtable(L);

	/* integer fields */
	LUA_EXPORT_INT(auth_challenge, protocol_version);
	LUA_EXPORT_INT(auth_challenge, server_version);
	LUA_EXPORT_INT(auth_challenge, thread_id);
	LUA_EXPORT_INT(auth_challenge, capabilities);
	LUA_EXPORT_INT(auth_challenge, charset);
	LUA_EXPORT_INT(auth_challenge, server_status);

	/* string fields */
	LUA_EXPORT_STR_TO(auth_challenge, auth_plugin_data, "challenge");
	LUA_EXPORT_STR(auth_challenge, auth_plugin_name);

	network_mysqld_auth_challenge_free(auth_challenge);

	return 1;
}
예제 #5
0
network_socket *self_connect(network_mysqld_con *con, network_backend_t *backend, GHashTable *pwd_table) {

    /*make sure that the max conn for the backend is no more than the config number
     *when max_conn_for_a_backend is no more than 0, there is no limitation for max connection for a backend;
     * */
    if (con->srv->max_conn_for_a_backend > 0 && backend->connected_clients >= con->srv->max_conn_for_a_backend) {
        g_critical("%s.%d: self_connect:%08x's connected_clients is %d, which are too many!",__FILE__, __LINE__, backend,  backend->connected_clients);
        return NULL;
    }
    
    //1. connect DB
	network_socket *sock = network_socket_new();
	network_address_copy(sock->dst, backend->addr);
	if (-1 == (sock->fd = socket(sock->dst->addr.common.sa_family, sock->socket_type, 0))) {
		g_critical("%s.%d: socket(%s) failed: %s (%d)", __FILE__, __LINE__, sock->dst->name->str, g_strerror(errno), errno);
		network_socket_free(sock);
		return NULL;
	}
	if (-1 == (connect(sock->fd, &sock->dst->addr.common, sock->dst->len))) {
		g_message("%s.%d: connecting to backend (%s) failed, marking it as down for ...", __FILE__, __LINE__, sock->dst->name->str);
		network_socket_free(sock);
		if (backend->state != BACKEND_STATE_OFFLINE) backend->state = BACKEND_STATE_DOWN;
		return NULL;
	}

	//2. read handshake,重点是获取20个字节的随机串
	off_t to_read = NET_HEADER_SIZE;
	guint offset = 0;
	guchar header[NET_HEADER_SIZE];
	while (to_read > 0) {
		gssize len = recv(sock->fd, header + offset, to_read, 0);
		if (len == -1 || len == 0) {
			network_socket_free(sock);
			return NULL;
		}
		offset += len;
		to_read -= len;
	}

	to_read = header[0] + (header[1] << 8) + (header[2] << 16);
	offset = 0;
	GString *data = g_string_sized_new(to_read);
	while (to_read > 0) {
		gssize len = recv(sock->fd, data->str + offset, to_read, 0);
		if (len == -1 || len == 0) {
			network_socket_free(sock);
			g_string_free(data, TRUE);
			return NULL;
		}
		offset += len;
		to_read -= len;
	}
	data->len = offset;

	network_packet packet;
	packet.data = data;
	packet.offset = 0;
	network_mysqld_auth_challenge *challenge = network_mysqld_auth_challenge_new();
	network_mysqld_proto_get_auth_challenge(&packet, challenge);

	//3. 生成response
	GString *response = g_string_sized_new(20);
	GString *hashed_password = g_hash_table_lookup(pwd_table, con->client->response->username->str);
	if (hashed_password) {
		network_mysqld_proto_password_scramble(response, S(challenge->challenge), S(hashed_password));
	} else {
		network_socket_free(sock);
		g_string_free(data, TRUE);
		network_mysqld_auth_challenge_free(challenge);
		g_string_free(response, TRUE);
		return NULL;
	}

	//4. send auth
	off_t to_write = 58 + con->client->response->username->len;
	offset = 0;
	g_string_truncate(data, 0);
	char tmp[] = {to_write - 4, 0, 0, 1, 0x85, 0xa6, 3, 0, 0, 0, 0, 1, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
	g_string_append_len(data, tmp, 36);
	g_string_append_len(data, con->client->response->username->str, con->client->response->username->len);
	g_string_append_len(data, "\0\x14", 2);
	g_string_append_len(data, response->str, 20);
	g_string_free(response, TRUE);
	while (to_write > 0) {
		gssize len = send(sock->fd, data->str + offset, to_write, 0);
		if (len == -1) {
			network_socket_free(sock);
			g_string_free(data, TRUE);
			network_mysqld_auth_challenge_free(challenge);
			return NULL;
		}
		offset += len;
		to_write -= len;
	}

	//5. read auth result
	to_read = NET_HEADER_SIZE;
	offset = 0;
	while (to_read > 0) {
		gssize len = recv(sock->fd, header + offset, to_read, 0);
		if (len == -1 || len == 0) {
			network_socket_free(sock);
			g_string_free(data, TRUE);
			network_mysqld_auth_challenge_free(challenge);
			return NULL;
		}
		offset += len;
		to_read -= len;
	}

	to_read = header[0] + (header[1] << 8) + (header[2] << 16);
	offset = 0;
	g_string_truncate(data, 0);
	g_string_set_size(data, to_read);
	while (to_read > 0) {
		gssize len = recv(sock->fd, data->str + offset, to_read, 0);
		if (len == -1 || len == 0) {
			network_socket_free(sock);
			g_string_free(data, TRUE);
			network_mysqld_auth_challenge_free(challenge);
			return NULL;
		}
		offset += len;
		to_read -= len;
	}
	data->len = offset;

	if (data->str[0] != MYSQLD_PACKET_OK) {
		network_socket_free(sock);
		g_string_free(data, TRUE);
		network_mysqld_auth_challenge_free(challenge);
		return NULL;
	}
	g_string_free(data, TRUE);

	//6. set non-block
	network_socket_set_non_blocking(sock);
	network_socket_connect_setopts(sock);	//此句是否需要?是否应该放在第1步末尾?

	sock->challenge = challenge;
	sock->response = network_mysqld_auth_response_copy(con->client->response);
    g_atomic_int_inc(&backend->connected_clients);
	return sock;
}