Bladeren bron

Re-implementing timeouts using select().

Fuck you, baidu.
Michele Orrù 11 jaren geleden
bovenliggende
commit
17893b2fc2
1 gewijzigde bestanden met toevoegingen van 54 en 28 verwijderingen
  1. 54 28
      src/qa_sock.c

+ 54 - 28
src/qa_sock.c

@@ -4,9 +4,13 @@
 #include <stdlib.h>
 #include <strings.h>
 #include <unistd.h>
+#include <fcntl.h>
+#include <error.h>
 #include <errno.h>
+#include <time.h>
 
 #include <sys/types.h>
+#include <sys/select.h>
 #include <sys/socket.h>
 #include <netdb.h>
 
@@ -16,7 +20,9 @@
 #include "qa/qa.h"
 #include "qa/qa_sock.h"
 
-#define TIMEOUT 5
+#define TIMEOUT_SEC 1
+#define TIMEOUT_USEC 0
+
 #define SOCKET_PROTOCOL 0
 #define INVALID_SOCKET  (-1)
 
@@ -77,11 +83,9 @@ int host_port(char *uri, char **host, char **service)
 int init_client(const char *host, const char *port)
 {
   int s, i;
+  fd_set socket_fds;
   struct addrinfo *result, *rp;
-  struct timeval timeout = {
-    .tv_sec = TIMEOUT,
-    .tv_usec = 0
-  };
+  struct timeval timeout;
 
   if ((i=getaddrinfo(host, port, NULL, &result))) {
     BIO_printf(bio_err, "Error: %s\n", gai_strerror(i));
@@ -98,17 +102,28 @@ int init_client(const char *host, const char *port)
        i = 0;
        i = setsockopt(s, SOL_SOCKET, SO_KEEPALIVE, (char*) &i, sizeof(i));
        if (i < 0) return -1;
-       i = setsockopt(s, SOL_SOCKET, SO_RCVTIMEO,
-                      (char *) &timeout, sizeof(struct timeval));
-       if (i < 0) return -1;
     }
 
+    //Set the socket to non-blocking
+    int flags = fcntl(s, F_GETFL, 0);
+    fcntl(s, F_SETFL, flags | O_NONBLOCK);
 
-    if (connect(s, rp->ai_addr, rp->ai_addrlen) != -1) break;
-  }
+    connect(s, rp->ai_addr, rp->ai_addrlen);
+    if (errno != EINPROGRESS) {
+      close(s);
+      continue;
+    }
 
-  if (!rp) return -1;
+    FD_ZERO(&socket_fds);
+    FD_SET(s, &socket_fds);
+    timeout.tv_sec = TIMEOUT_SEC;
+    timeout.tv_usec = TIMEOUT_USEC;
+    i = select(s+1, NULL, &socket_fds, NULL, &timeout);
+    if (i > 0) break;
 
+    close(s);
+  }
+  if (!rp) return -1;
   return s;
 }
 
@@ -169,11 +184,11 @@ static int verify_callback(int ok, X509_STORE_CTX* ctx)
 struct qa_connection* qa_connection_new(char* address)
 {
   struct qa_connection* c = NULL;
-  struct timeval timeout = {
-    .tv_sec = TIMEOUT,
-    .tv_usec = 0
-  };
   char *host, *port;
+  int attempts;
+  int err;
+  fd_set socket_fds;
+  struct timeval timeout;
 
   /* parse input address */
   if (!host_port(address, &host, &port)) goto error;
@@ -182,25 +197,36 @@ struct qa_connection* qa_connection_new(char* address)
   c = calloc(1, sizeof(struct qa_connection));
   if (!c) goto error;
   /* set up context, and protocol versions */
-  c->ctx = SSL_CTX_new(SSLv23_client_method());
+  c->ctx = SSL_CTX_new(TLSv1_client_method());
   if (!c->ctx) goto error;
   /* create the ssl session, disabling certificate verification */
   SSL_CTX_set_verify(c->ctx, SSL_VERIFY_NONE, verify_callback);
   c->ssl = SSL_new(c->ctx);
+  SSL_set_connect_state(c->ssl);
+
   if (!c->ssl) goto error;
   /* open the socket over ssl */
   c->socket = init_client(host, port);
-  c->sbio = BIO_new_dgram(c->socket, BIO_NOCLOSE);
-  // BIO_ctrl(c->sbio, BIO_CTRL_DGRAM_SET_SEND_TIMEOUT, 0, &timeout);
-  BIO_ctrl(c->sbio, BIO_CTRL_DGRAM_SET_RECV_TIMEOUT, 0, &timeout);
-
-  if (c->socket == -1)
-    goto error;
-  SSL_set_bio(c->ssl, c->sbio, c->sbio);
-  if (SSL_connect(c->ssl) != 1)
-    goto error;
-  SSL_set_connect_state(c->ssl);
-  return c;
+  if (c->socket == -1)  goto error;
+  if (!SSL_set_fd(c->ssl, c->socket)) goto error;
+
+  FD_ZERO(&socket_fds);
+  FD_SET(c->socket, &socket_fds);
+  for(attempts = 10; attempts; attempts--) {
+     err = SSL_do_handshake(c->ssl);
+    // err = SSL_connect(c->ssl);
+    if (err == 1) return c;
+
+    err = SSL_get_error(c->ssl, err);
+    timeout.tv_sec = TIMEOUT_SEC;
+    timeout.tv_usec = TIMEOUT_USEC;
+    if (err == SSL_ERROR_WANT_WRITE)
+      select(c->socket+1, NULL, &socket_fds, NULL, &timeout);
+    else if (err == SSL_ERROR_WANT_READ)
+      select(c->socket+1, &socket_fds, NULL, NULL, &timeout);
+    else
+      goto error;
+  }
 
  error:
   /* XXX. add checks for errno, and the ssl error stack (ssl_get_error) */
@@ -220,7 +246,7 @@ struct qa_connection* qa_connection_new(char* address)
  */
 X509* get_remote_cert(char *address)
 {
-  X509 *crt;
+  X509 *crt = NULL;
   qa_connection_t *c;
 
   c = qa_connection_new(address);