static size_t atomicio(struct socket *s, void *buf, size_t n, int do_read) { char *b = buf; size_t pos = 0; ssize_t res; ssh_pollfd_t pfd; socket_t fd = ssh_socket_get_fd(s); pfd.fd = fd; pfd.events = do_read ? POLLIN : POLLOUT; while (n > pos) { if (do_read) { res = read(fd, b + pos, n - pos); } else { res = write(fd, b + pos, n - pos); } switch (res) { case -1: if (errno == EINTR) { continue; } #ifdef EWOULDBLOCK if (errno == EAGAIN || errno == EWOULDBLOCK) { #else if (errno == EAGAIN) { #endif (void) ssh_poll(&pfd, 1, -1); continue; } return 0; case 0: errno = EPIPE; return pos; default: pos += (size_t) res; } } return pos; } ssh_agent agent_new(struct ssh_session_struct *session) { ssh_agent agent = NULL; agent = malloc(sizeof(struct ssh_agent_struct)); if (agent == NULL) { return NULL; } ZERO_STRUCTP(agent); agent->count = 0; agent->sock = ssh_socket_new(session); if (agent->sock == NULL) { SAFE_FREE(agent); return NULL; } return agent; }
static size_t atomicio(struct ssh_agent_struct *agent, void *buf, size_t n, int do_read) { char *b = buf; size_t pos = 0; ssize_t res; ssh_pollfd_t pfd; ssh_channel channel = agent->channel; socket_t fd; /* Using a socket ? */ if (channel == NULL) { fd = ssh_socket_get_fd(agent->sock); pfd.fd = fd; pfd.events = do_read ? POLLIN : POLLOUT; while (n > pos) { if (do_read) { res = read(fd, b + pos, n - pos); } else { res = write(fd, b + pos, n - pos); } switch (res) { case -1: if (errno == EINTR) { continue; } #ifdef EWOULDBLOCK if (errno == EAGAIN || errno == EWOULDBLOCK) { #else if (errno == EAGAIN) { #endif (void) ssh_poll(&pfd, 1, -1); continue; } return 0; case 0: /* read returns 0 on end-of-file */ errno = do_read ? 0 : EPIPE; return pos; default: pos += (size_t) res; } } return pos; } else { /* using an SSH channel */ while (n > pos){
/** * @brief SSH poll callback. This callback will be used when an event * caught on the socket. * * @param p Poll object this callback belongs to. * @param fd The raw socket. * @param revents The current poll events on the socket. * @param userdata Userdata to be passed to the callback function, * in this case the socket object. * * @return 0 on success, < 0 when the poll object has been removed * from its poll context. */ int ssh_socket_pollcallback(struct ssh_poll_handle_struct *p, socket_t fd, int revents, void *v_s) { ssh_socket s = (ssh_socket)v_s; char buffer[MAX_BUF_SIZE]; ssize_t nread; int rc; int err = 0; socklen_t errlen = sizeof(err); /* Do not do anything if this socket was already closed */ if (!ssh_socket_is_open(s)) { return -1; } SSH_LOG(SSH_LOG_TRACE, "Poll callback on socket %d (%s%s%s), out buffer %d",fd, (revents & POLLIN) ? "POLLIN ":"", (revents & POLLOUT) ? "POLLOUT ":"", (revents & POLLERR) ? "POLLERR":"", ssh_buffer_get_len(s->out_buffer)); if ((revents & POLLERR) || (revents & POLLHUP)) { /* Check if we are in a connecting state */ if (s->state == SSH_SOCKET_CONNECTING) { s->state = SSH_SOCKET_ERROR; rc = getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *)&err, &errlen); if (rc < 0) { err = errno; } s->last_errno = err; ssh_socket_close(s); if (s->callbacks != NULL && s->callbacks->connected != NULL) { s->callbacks->connected(SSH_SOCKET_CONNECTED_ERROR, err, s->callbacks->userdata); } return -1; } /* Then we are in a more standard kind of error */ /* force a read to get an explanation */ revents |= POLLIN; } if ((revents & POLLIN) && s->state == SSH_SOCKET_CONNECTED) { s->read_wontblock = 1; nread = ssh_socket_unbuffered_read(s, buffer, sizeof(buffer)); if (nread < 0) { if (p != NULL) { ssh_poll_remove_events(p, POLLIN); } if (s->callbacks != NULL && s->callbacks->exception != NULL) { s->callbacks->exception(SSH_SOCKET_EXCEPTION_ERROR, s->last_errno, s->callbacks->userdata); } return -2; } if (nread == 0) { if (p != NULL) { ssh_poll_remove_events(p, POLLIN); } if (s->callbacks != NULL && s->callbacks->exception != NULL) { s->callbacks->exception(SSH_SOCKET_EXCEPTION_EOF, 0, s->callbacks->userdata); } return -2; } if (s->session->socket_counter != NULL) { s->session->socket_counter->in_bytes += nread; } /* Bufferize the data and then call the callback */ rc = ssh_buffer_add_data(s->in_buffer, buffer, nread); if (rc < 0) { return -1; } if (s->callbacks != NULL && s->callbacks->data != NULL) { do { nread = s->callbacks->data(ssh_buffer_get(s->in_buffer), ssh_buffer_get_len(s->in_buffer), s->callbacks->userdata); ssh_buffer_pass_bytes(s->in_buffer, nread); } while ((nread > 0) && (s->state == SSH_SOCKET_CONNECTED)); /* p may have been freed, so don't use it * anymore in this function */ p = NULL; } } #ifdef _WIN32 if (revents & POLLOUT || revents & POLLWRNORM) { #else if (revents & POLLOUT) { #endif uint32_t len; /* First, POLLOUT is a sign we may be connected */ if (s->state == SSH_SOCKET_CONNECTING) { SSH_LOG(SSH_LOG_PACKET, "Received POLLOUT in connecting state"); s->state = SSH_SOCKET_CONNECTED; if (p != NULL) { ssh_poll_set_events(p, POLLOUT | POLLIN); } rc = ssh_socket_set_blocking(ssh_socket_get_fd(s)); if (rc < 0) { return -1; } if (s->callbacks != NULL && s->callbacks->connected != NULL) { s->callbacks->connected(SSH_SOCKET_CONNECTED_OK, 0, s->callbacks->userdata); } return 0; } /* So, we can write data */ s->write_wontblock = 1; if (p != NULL) { ssh_poll_remove_events(p, POLLOUT); } /* If buffered data is pending, write it */ len = ssh_buffer_get_len(s->out_buffer); if (len > 0) { ssh_socket_nonblocking_flush(s); } else if (s->callbacks != NULL && s->callbacks->controlflow != NULL) { /* Otherwise advertise the upper level that write can be done */ SSH_LOG(SSH_LOG_TRACE,"sending control flow event"); s->callbacks->controlflow(SSH_SOCKET_FLOW_WRITEWONTBLOCK, s->callbacks->userdata); } /* TODO: Find a way to put back POLLOUT when buffering occurs */ } /* Return -1 if the poll handler disappeared */ if (s->poll_handle == NULL) { return -1; } return 0; } /** @internal * @brief returns the poll handle corresponding to the socket, * creates it if it does not exist. * @returns allocated and initialized ssh_poll_handle object */ ssh_poll_handle ssh_socket_get_poll_handle(ssh_socket s) { if (s->poll_handle) { return s->poll_handle; } s->poll_handle = ssh_poll_new(s->fd,0,ssh_socket_pollcallback,s); return s->poll_handle; } /** \internal * \brief Deletes a socket object */ void ssh_socket_free(ssh_socket s){ if (s == NULL) { return; } ssh_socket_close(s); ssh_buffer_free(s->in_buffer); ssh_buffer_free(s->out_buffer); SAFE_FREE(s); } #ifndef _WIN32 int ssh_socket_unix(ssh_socket s, const char *path) { struct sockaddr_un sunaddr; socket_t fd; sunaddr.sun_family = AF_UNIX; snprintf(sunaddr.sun_path, sizeof(sunaddr.sun_path), "%s", path); fd = socket(AF_UNIX, SOCK_STREAM, 0); if (fd == SSH_INVALID_SOCKET) { ssh_set_error(s->session, SSH_FATAL, "Error from socket(AF_UNIX, SOCK_STREAM, 0): %s", strerror(errno)); return -1; } if (fcntl(fd, F_SETFD, 1) == -1) { ssh_set_error(s->session, SSH_FATAL, "Error from fcntl(fd, F_SETFD, 1): %s", strerror(errno)); close(fd); return -1; } if (connect(fd, (struct sockaddr *) &sunaddr, sizeof(sunaddr)) < 0) { ssh_set_error(s->session, SSH_FATAL, "Error from connect(): %s", strerror(errno)); close(fd); return -1; } ssh_socket_set_fd(s,fd); return 0; } #endif /** \internal * \brief closes a socket */ void ssh_socket_close(ssh_socket s){ if (ssh_socket_is_open(s)) { #ifdef _WIN32 CLOSE_SOCKET(s->fd); s->last_errno = WSAGetLastError(); #else CLOSE_SOCKET(s->fd); s->last_errno = errno; #endif } if(s->poll_handle != NULL){ ssh_poll_free(s->poll_handle); s->poll_handle=NULL; } s->state = SSH_SOCKET_CLOSED; } /** * @internal * @brief sets the file descriptor of the socket. * @param[out] s ssh_socket to update * @param[in] fd file descriptor to set * @warning this function updates boths the input and output * file descriptors */ void ssh_socket_set_fd(ssh_socket s, socket_t fd) { s->fd = fd; if (s->poll_handle) { ssh_poll_set_fd(s->poll_handle,fd); } else { s->state = SSH_SOCKET_CONNECTING; /* POLLOUT is the event to wait for in a nonblocking connect */ ssh_poll_set_events(ssh_socket_get_poll_handle(s), POLLOUT); #ifdef _WIN32 ssh_poll_add_events(ssh_socket_get_poll_handle(s), POLLWRNORM); #endif } }