Exemple #1
0
/**
 * Receive data from netlink socket
 * @arg sk		Netlink socket.
 * @arg nla		Destination pointer for peer's netlink address.
 * @arg buf		Destination pointer for message content.
 * @arg creds		Destination pointer for credentials.
 *
 * Receives a netlink message, allocates a buffer in \c *buf and
 * stores the message content. The peer's netlink address is stored
 * in \c *nla. The caller is responsible for freeing the buffer allocated
 * in \c *buf if a positive value is returned.  Interrupted system calls
 * are handled by repeating the read. The input buffer size is determined
 * by peeking before the actual read is done.
 *
 * A non-blocking sockets causes the function to return immediately with
 * a return value of 0 if no data is available.
 *
 * @return Number of octets read, 0 on EOF or a negative error code.
 */
int nl_recv(struct nl_sock *sk, struct sockaddr_nl *nla,
	    unsigned char **buf, struct ucred **creds)
{
	int n;
	int flags = 0;
	static int page_size = 0;
	struct iovec iov;
	struct msghdr msg = {
		.msg_name = (void *) nla,
		.msg_namelen = sizeof(struct sockaddr_nl),
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = NULL,
		.msg_controllen = 0,
		.msg_flags = 0,
	};
	struct cmsghdr *cmsg;

	memset(nla, 0, sizeof(*nla));

	if (sk->s_flags & NL_MSG_PEEK)
		flags |= MSG_PEEK;

	if (page_size == 0)
		page_size = getpagesize();

	iov.iov_len = page_size;
	iov.iov_base = *buf = malloc(iov.iov_len);

	if (sk->s_flags & NL_SOCK_PASSCRED) {
		msg.msg_controllen = CMSG_SPACE(sizeof(struct ucred));
		msg.msg_control = calloc(1, msg.msg_controllen);
	}
retry:

	n = recvmsg(sk->s_fd, &msg, flags);
	if (!n)
		goto abort;
	else if (n < 0) {
		if (errno == EINTR) {
			NL_DBG(3, "recvmsg() returned EINTR, retrying\n");
			goto retry;
		} else if (errno == EAGAIN) {
			NL_DBG(3, "recvmsg() returned EAGAIN, aborting\n");
			goto abort;
		} else {
			free(msg.msg_control);
			free(*buf);
			return -nl_syserr2nlerr(errno);
		}
	}

	if (iov.iov_len < n ||
	    msg.msg_flags & MSG_TRUNC) {
		/* Provided buffer is not long enough, enlarge it
		 * and try again. */
		iov.iov_len *= 2;
		iov.iov_base = *buf = realloc(*buf, iov.iov_len);
		goto retry;
	} else if (msg.msg_flags & MSG_CTRUNC) {
		msg.msg_controllen *= 2;
		msg.msg_control = realloc(msg.msg_control, msg.msg_controllen);
		goto retry;
	} else if (flags != 0) {
		/* Buffer is big enough, do the actual reading */
		flags = 0;
		goto retry;
	}

	if (msg.msg_namelen != sizeof(struct sockaddr_nl)) {
		free(msg.msg_control);
		free(*buf);
		return -NLE_NOADDR;
	}

	for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
		if (cmsg->cmsg_level == SOL_SOCKET &&
		    cmsg->cmsg_type == SCM_CREDENTIALS) {
			if (creds) {
				*creds = calloc(1, sizeof(struct ucred));
				memcpy(*creds, CMSG_DATA(cmsg), sizeof(struct ucred));
			}
			break;
		}
	}

	free(msg.msg_control);
	return n;

abort:
	free(msg.msg_control);
	free(*buf);
	return 0;
}

/*
 * Run callback `type` of callback set `cb` on `msg` and act on the verdict.
 *
 * The result is stored in the enclosing function's `err` variable:
 *   NL_OK   - err is reset to 0 and execution continues after the macro,
 *   NL_SKIP - jump to the enclosing function's `skip` label,
 *   NL_STOP - jump to the enclosing function's `stop` label,
 *   other   - jump to the enclosing function's `out` label with err intact.
 *
 * The enclosing function must therefore provide `err`, `skip`, `stop`
 * and `out`.
 */
#define NL_CB_CALL(cb, type, msg) \
do { \
	err = nl_cb_call(cb, type, msg); \
	if (err == NL_OK) \
		err = 0; \
	else if (err == NL_SKIP) \
		goto skip; \
	else if (err == NL_STOP) \
		goto stop; \
	else \
		goto out; \
} while (0)

/**
 * Receive netlink messages from a socket and run them through a callback set
 * @arg sk		Netlink socket to read from.
 * @arg cb		Callback set consulted for every received message.
 *
 * Reads one buffer at a time, using \c cb->cb_recv_ow when it is set and
 * nl_recv() otherwise, then walks every message in the buffer that passes
 * nlmsg_ok(). Each message is converted to a struct nl_msg, tagged with
 * the socket protocol, the sender address and (if present) the received
 * credentials, and handed to the callbacks that are set in \c cb.
 *
 * A message carrying NLM_F_MULTI marks the stream as multipart; NLMSG_DONE
 * clears that mark. If a buffer has been fully processed while the mark is
 * still set, another buffer is read.
 *
 * Callbacks are invoked through NL_CB_CALL, which jumps to the \c skip,
 * \c stop or \c out labels below depending on the callback's return value.
 *
 * @return The receive function's return value if it is <= 0. Otherwise
 *         -NLE_DUMP_INTR if a message flagged NLM_F_DUMP_INTR was seen and
 *         no NL_CB_DUMP_INTR callback is set; otherwise 0 on normal
 *         completion or NL_STOP, or the value left in \c err by the check
 *         or callback that aborted processing.
 */
static int recvmsgs(struct nl_sock *sk, struct nl_cb *cb)
{
	int n, err = 0, multipart = 0, interrupted = 0;
	unsigned char *buf = NULL;
	struct nlmsghdr *hdr;
	struct sockaddr_nl nla = {0};
	struct nl_msg *msg = NULL;
	struct ucred *creds = NULL;

continue_reading:
	NL_DBG(3, "Attempting to read from %p\n", sk);
	/* A user supplied receive function takes precedence over nl_recv() */
	if (cb->cb_recv_ow)
		n = cb->cb_recv_ow(sk, &nla, &buf, &creds);
	else
		n = nl_recv(sk, &nla, &buf, &creds);

	/* Nothing read or read failed: return the value as is. nl_recv()
	 * releases its buffer itself on these paths; an override is
	 * presumably expected to do the same (verify against its contract).
	 * Note that the `interrupted` flag is not evaluated on this path. */
	if (n <= 0)
		return n;

	NL_DBG(3, "recvmsgs(%p): Read %d bytes\n", sk, n);

	hdr = (struct nlmsghdr *) buf;
	while (nlmsg_ok(hdr, n)) {
		NL_DBG(3, "recvmsgs(%p): Processing valid message...\n", sk);

		/* Release the message object of the previous iteration
		 * before building the one for the current header. */
		nlmsg_free(msg);
		msg = nlmsg_convert(hdr);
		if (!msg) {
			err = -NLE_NOMEM;
			goto out;
		}

		nlmsg_set_proto(msg, sk->s_proto);
		nlmsg_set_src(msg, &nla);
		if (creds)
			nlmsg_set_creds(msg, creds);

		/* The raw callback runs first; it gives the user the most
		 * control and allows fully custom parsing. */
		if (cb->cb_set[NL_CB_MSG_IN])
			NL_CB_CALL(cb, NL_CB_MSG_IN, msg);

		/* Sequence number checking. The check may be done by
		 * the user, otherwise a very simple check is applied
		 * enforcing strict ordering */
		if (cb->cb_set[NL_CB_SEQ_CHECK]) {
			NL_CB_CALL(cb, NL_CB_SEQ_CHECK, msg);

		/* Only do sequence checking if auto-ack mode is enabled */
		} else if (!(sk->s_flags & NL_NO_AUTO_ACK)) {
			if (hdr->nlmsg_seq != sk->s_seq_expect) {
				if (cb->cb_set[NL_CB_INVALID])
					NL_CB_CALL(cb, NL_CB_INVALID, msg);
				else {
					err = -NLE_SEQ_MISMATCH;
					goto out;
				}
			}
		}

		/* These message types advance the expected sequence number */
		if (hdr->nlmsg_type == NLMSG_DONE ||
		    hdr->nlmsg_type == NLMSG_ERROR ||
		    hdr->nlmsg_type == NLMSG_NOOP ||
		    hdr->nlmsg_type == NLMSG_OVERRUN) {
			/* We can't check for !NLM_F_MULTI since some netlink
			 * users in the kernel are broken. */
			sk->s_seq_expect++;
			NL_DBG(3, "recvmsgs(%p): Increased expected " \
			       "sequence number to %d\n",
			       sk, sk->s_seq_expect);
		}

		/* Remember that a multipart stream is in progress; the flag
		 * is cleared again when NLMSG_DONE is seen below. */
		if (hdr->nlmsg_flags & NLM_F_MULTI)
			multipart = 1;

		if (hdr->nlmsg_flags & NLM_F_DUMP_INTR) {
			if (cb->cb_set[NL_CB_DUMP_INTR])
				NL_CB_CALL(cb, NL_CB_DUMP_INTR, msg);
			else {
				/*
				 * We have to continue reading to clear
				 * all messages until a NLMSG_DONE is
				 * received and report the inconsistency.
				 */
				interrupted = 1;
			}
		}
	
		/* Other side wishes to see an ack for this message */
		if (hdr->nlmsg_flags & NLM_F_ACK) {
			if (cb->cb_set[NL_CB_SEND_ACK])
				NL_CB_CALL(cb, NL_CB_SEND_ACK, msg);
			else {
				/* FIXME: implement */
			}
		}

		/* This message terminates a multipart message. It is
		 * usually the end of the stream, so the multipart mark is
		 * cleared and no further buffer will be read by default.
		 * The user may overrule this action by skipping this
		 * packet. */
		if (hdr->nlmsg_type == NLMSG_DONE) {
			multipart = 0;
			if (cb->cb_set[NL_CB_FINISH])
				NL_CB_CALL(cb, NL_CB_FINISH, msg);
		}

		/* Message to be ignored, the default action is to
		 * skip this message if no callback is specified. The
		 * user may overrule this action by returning
		 * NL_PROCEED. */
		else if (hdr->nlmsg_type == NLMSG_NOOP) {
			if (cb->cb_set[NL_CB_SKIPPED])
				NL_CB_CALL(cb, NL_CB_SKIPPED, msg);
			else
				goto skip;
		}

		/* Data got lost, report back to user. The default action is to
		 * quit parsing. The user may overrule this action by returning
		 * NL_SKIP or NL_PROCEED (dangerous) */
		else if (hdr->nlmsg_type == NLMSG_OVERRUN) {
			if (cb->cb_set[NL_CB_OVERRUN])
				NL_CB_CALL(cb, NL_CB_OVERRUN, msg);
			else {
				err = -NLE_MSG_OVERFLOW;
				goto out;
			}
		}

		/* Message carries a nlmsgerr */
		else if (hdr->nlmsg_type == NLMSG_ERROR) {
			/* Only the pointer is taken here; the payload is not
			 * dereferenced until the length check has passed. */
			struct nlmsgerr *e = nlmsg_data(hdr);

			if (hdr->nlmsg_len < nlmsg_size(sizeof(*e))) {
				/* Truncated error message, the default action
				 * is to stop parsing. The user may overrule
				 * this action by returning NL_SKIP or
				 * NL_PROCEED (dangerous) */
				if (cb->cb_set[NL_CB_INVALID])
					NL_CB_CALL(cb, NL_CB_INVALID, msg);
				else {
					err = -NLE_MSG_TRUNC;
					goto out;
				}
			} else if (e->error) {
				/* Error message reported back from kernel. */
				if (cb->cb_err) {
					/* The error callback decides: a negative
					 * value aborts with that value, NL_SKIP
					 * skips the message, NL_STOP aborts with
					 * the translated error code. Any other
					 * value continues with the next message. */
					err = cb->cb_err(&nla, e,
							   cb->cb_err_arg);
					if (err < 0)
						goto out;
					else if (err == NL_SKIP)
						goto skip;
					else if (err == NL_STOP) {
						err = -nl_syserr2nlerr(e->error);
						goto out;
					}
				} else {
					err = -nl_syserr2nlerr(e->error);
					goto out;
				}
			} else if (cb->cb_set[NL_CB_ACK])
				/* error == 0: the message is an acknowledgement */
				NL_CB_CALL(cb, NL_CB_ACK, msg);
		} else {
			/* Valid message (not checking for MULTIPART bit to
			 * get along with broken kernels. NL_SKIP has no
			 * effect on this.  */
			if (cb->cb_set[NL_CB_VALID])
				NL_CB_CALL(cb, NL_CB_VALID, msg);
		}
skip:
		/* Advance to the next message in the buffer; nlmsg_next()
		 * is handed &n so it can adjust the remaining length. */
		err = 0;
		hdr = nlmsg_next(hdr, &n);
	}
	
	/* Buffer fully processed: release everything and reset the
	 * pointers so the cleanup at `out` does not free them twice. */
	nlmsg_free(msg);
	free(buf);
	free(creds);
	buf = NULL;
	msg = NULL;
	creds = NULL;

	if (multipart) {
		/* Multipart message not yet complete, continue reading */
		goto continue_reading;
	}
stop:
	/* Reached on normal completion or when a callback returned NL_STOP */
	err = 0;
out:
	/* Common cleanup; err holds the value set at the jump site */
	nlmsg_free(msg);
	free(buf);
	free(creds);

	/* An interrupted dump overrides whatever result was pending */
	if (interrupted)
		err = -NLE_DUMP_INTR;

	return err;
}
static int rtnl_raw_parse_cb(struct nl_msg *msg, void *arg)
{
	struct nl_lookup_arg *lookup_arg = (struct nl_lookup_arg *)arg;
	struct usnic_rtnl_sk *unlsk = lookup_arg->unlsk;
	struct nlmsghdr	*nlm_hdr = nlmsg_hdr(msg);
	struct rtmsg *rtm;
	struct nlattr *tb[RTA_MAX + 1];
	int found = 0;
	int err;

#if WANT_DEBUG_MSGS
	nl_msg_dump(msg, stderr);
#endif /* WANT_DEBUG_MSGS */

	lookup_arg->nh_addr	= 0;
	lookup_arg->found 	= 0;
	lookup_arg->replied	= 0;

	if (nlm_hdr->nlmsg_pid != nl_socket_get_local_port(unlsk->sock)
		|| nlm_hdr->nlmsg_seq != unlsk->seq) {
		usnic_err("Not an expected reply msg pid: %u local pid: %u "
				"msg seq: %u expected seq: %u\n",
				nlm_hdr->nlmsg_pid, nl_socket_get_local_port(unlsk->sock),
				nlm_hdr->nlmsg_seq, unlsk->seq);
		return NL_SKIP;
	}
	lookup_arg->replied = 1;

	if (nlm_hdr->nlmsg_type == NLMSG_ERROR) {
		struct nlmsgerr *e = (struct nlmsgerr *)nlmsg_data(nlm_hdr);
		if (nlm_hdr->nlmsg_len >= (__u32)nlmsg_size(sizeof(*e))) {
			usnic_err("Received a netlink error message %d\n",
					e->error);
		}
		else {
			usnic_err("Received a truncated netlink error message\n");
		}
		return NL_STOP;
	}

	if (nlm_hdr->nlmsg_type != RTM_NEWROUTE) {
		usnic_err("Received an invalid route request reply message\n");
		return NL_STOP;
	}

	rtm = nlmsg_data(nlm_hdr);
	if (rtm->rtm_family != AF_INET) {
		usnic_err("RTM message contains invalid AF family\n");
		return NL_STOP;
	}

        init_route_policy(route_policy);
	err = nlmsg_parse(nlm_hdr, sizeof(struct rtmsg), tb, RTA_MAX,
			  route_policy);
	if (err < 0) {
		usnic_err("nlmsg parse error %d\n", err);
		return NL_STOP;
	}

	if (tb[RTA_OIF]) {
		if (nla_get_u32(tb[RTA_OIF]) == (uint32_t)lookup_arg->oif)
			found = 1;
		else
			usnic_err("Retrieved route has a different outgoing interface %d (expected %d)\n",
					nla_get_u32(tb[RTA_OIF]),
					lookup_arg->oif);
	}

        if (found && tb[RTA_METRICS]) {
            lookup_arg->metric = (int)nla_get_u32(tb[RTA_METRICS]);
        }

	if (found && tb[RTA_GATEWAY])
		lookup_arg->nh_addr = nla_get_u32(tb[RTA_GATEWAY]);

	lookup_arg->found = found;
	return NL_STOP;
}