Michele Orrù 7 years ago
parent
commit
fd2b28b0f1
7 changed files with 53 additions and 80 deletions
  1. 5 3
      configure.ac
  2. 7 5
      src/Makefile.am
  3. 1 4
      src/hss.c
  4. 1 6
      src/hss.h
  5. 8 18
      src/rms.c
  6. 30 43
      ver1.c
  7. 1 1
      ver2.c

+ 5 - 3
configure.ac

@@ -22,6 +22,7 @@ AC_GNU_SOURCE
 # Checks for header files.
 AC_CHECK_HEADERS([limits.h stdint.h stdlib.h string.h unistd.h])
 AC_CHECK_SIZEOF(mp_limb_t, 8, [#include <gmp.h>])
+AC_CHECK_SIZEOF(uint32_t, 4, [#include <stdint.h>])
 
 # Checks for typedefs, structures, and compiler characteristics.
 AC_TYPE_SIZE_T
@@ -31,8 +32,8 @@ AC_FUNC_ERROR_AT_LINE
 AC_FUNC_MALLOC
 #AC_CHECK_FUNCS([dup2 setlocale strdup])
 
-# Add compiler/linker flags
-CFLAGS+=" -O3 --std=c99 -Wall --pedantic -march=native -DNDEBUG"
+# Clear out compiler/linker flags
+CFLAGS=" -pedantic -Wall "
 
 # Shut up automake
 #AM_SILENT_RULES([yes])
@@ -42,7 +43,8 @@ AC_SUBST([AM_MAKEFLAGS], [--no-print-directory])
 
 AC_ARG_ENABLE(debug,
    AS_HELP_STRING([--enable-debug], [enable debugging, default: no]),
-   CFLAGS+=" -UNDEBUG -O0 -ggdb")
+   CFLAGS+=" -UNDEBUG -O0 -ggdb",
+   CFLAGS+=" -DNDEBUG -O3 -march=native")
 
 AC_OUTPUT([Makefile
            src/Makefile

+ 7 - 5
src/Makefile.am

@@ -1,9 +1,11 @@
 bin_PROGRAMS = rms
+#check_programs = test_ssl1
 
-
-HSS = hss.c hss.h
-ENTROPY = entropy.c entropy.h
+DDLOG = ddlog.c ddlog.h
 ELGAMAL = elgamal.c elgamal.h
+ENTROPY = entropy.c entropy.h
+HSS = hss.c hss.h
+#TESTS = $(check_programs)
 
-rms_SOURCES = rms.c $(HSS) $(ENTROPY) $(ELGAMAL)
-rms_LDADD = -lgmp
+#test_ssl1_SOURCES = test_ssl1.c
+rms_SOURCES = rms.c $(HSS) $(ENTROPY) $(ELGAMAL) $(DDLOG)

+ 1 - 4
src/hss.c

@@ -37,7 +37,7 @@ void hss_del()
 }
 
 
-void fbprecompute(fbptable_t T, const mpz_t base)
+void fbprecompute(mpz_t T[4][256], const mpz_t base)
 {
   for (size_t j = 0; j < 4; j++) {
     for (size_t i = 0; i <= 0xFF; i++) {
@@ -72,9 +72,6 @@ void ssl1_share(ssl1_t r1, ssl1_t r2, const mpz_t v, const elgamal_key_t key)
   mpz_init_set_ui(zero, 0);
 
   elgamal_encrypt_shares(r1->w, r2->w, key, v);
-  //  s->T = (fbptable_t) calloc(sizeof(fbptable_entry_t), 256 * 4);
-  fbprecompute(r1->T, r1->w->c2);
-  fbprecompute(r2->T, r2->w->c2);
 
   for (size_t t = 0; t < 160; t++) {
     if (mpz_tstbit(key->sk, 159-t)) {

+ 1 - 6
src/hss.h

@@ -21,15 +21,9 @@ void hss_del();
  *  plus the encryption of the product of each bit.
  */
 
-//typedef __mpz_struct* fbptable_entry_t;
-//typedef fbptable_entry_t (* fbptable_t)[256];
-typedef const mpz_t (* const_fbptable_t)[256];
-typedef mpz_t (* fbptable_t)[256];
-
 typedef struct ssl1 {
   elgamal_cipher_t w;
   elgamal_cipher_t cw[160];
-  mpz_t T[4][256];
 } ssl1_t[1];
 
 
@@ -51,3 +45,4 @@ void ssl2_init(ssl2_t s);
 void ssl2_clear(ssl2_t s);
 void ssl2_share(ssl2_t s1, ssl2_t s2, const mpz_t v, const mpz_t sk);
 void ssl2_open(mpz_t rop, const ssl2_t s1, const ssl2_t s2);
+void fbprecompute(mpz_t T[4][256], const mpz_t base);

+ 8 - 18
src/rms.c

@@ -16,7 +16,7 @@
 
 
 static inline
-void fbpowm(mpz_t rop, const_fbptable_t T, const uint32_t exp)
+void fbpowm(mpz_t rop, const mpz_t T[4][256], const uint32_t exp)
 {
   const uint8_t *e = (uint8_t *) &exp;
 
@@ -34,7 +34,6 @@ uint32_t __mul_single(mpz_t op1,
                       mpz_t op2,
                       const mpz_t c1,
                       const mpz_t c2,
-                      const_fbptable_t T,
                       const uint32_t x,
                       const mpz_t cx)
 {
@@ -42,12 +41,6 @@ uint32_t __mul_single(mpz_t op1,
   mpz_powm(op1, c1, cx, p);
   mpz_invert(op1, op1, p);
 
-  //mpz_t test; mpz_init(test);
-  //mpz_powm_ui(test, c2, x, p);
-  fbpowm(op2, T, x);
-  //if (mpz_cmp(test, op2)) gmp_printf("base: %Zx\nexp: %x\npcomp: %Zx\nreal: %Zd\n", c2, x, op2, test);
-  //mpz_clear(test);
-
   mpz_powm_ui(op2, c2, x, p);
   mpz_mul(op2, op2, op1);
   mpz_mod(op2, op2, p);
@@ -58,25 +51,22 @@ uint32_t __mul_single(mpz_t op1,
 void hss_mul(ssl2_t rop, const ssl1_t sl1, const ssl2_t sl2)
 {
   mpz_t op1, op2;
-  uint32_t converted;
   mpz_inits(op1, op2, NULL);
 
-  converted = __mul_single(op1, op2,
+  rop->x = __mul_single(op1, op2,
                            sl1->w->c1,
                            sl1->w->c2,
-                           sl1->T,
                            sl2->x,
                            sl2->cx);
-  rop->x = converted;
 
   mpz_set_ui(rop->cx, 0);
   for (size_t t = 0; t < 160; t++) {
-    converted = __mul_single(op1, op2,
-                             sl1->cw[t]->c1,
-                             sl1->cw[t]->c2,
-                             sl1->T,
-                             sl2->x,
-                             sl2->cx);
+    const uint32_t converted =
+      __mul_single(op1, op2,
+                   sl1->cw[t]->c1,
+                   sl1->cw[t]->c2,
+                   sl2->x,
+                   sl2->cx);
     mpz_add_ui(rop->cx, rop->cx, converted);
     mpz_mul_2exp(rop->cx, rop->cx, 1);
   }

+ 30 - 43
ver1.c

@@ -9,6 +9,7 @@
 #include <linux/random.h>
 #include <sys/syscall.h>
 #include <sys/time.h>
+#include <immintrin.h>
 
 #include <gmp.h>
 
@@ -57,49 +58,44 @@ getrandom(void *buffer, size_t length, unsigned int flags)
 
 INIT_TIMEIT();
 static const uint32_t strip_size = 16;
+#define halfstrip_size (strip_size/2)
 static uint8_t lookup[256];
 static uint8_t offset[256];
 
 uint32_t convert(uint64_t *nn)
 {
-  uint32_t steps = 0;
-  /**
-   * Here we start making a bunch of assumptions.
-   * First, we assume the "w" here is 64 bits, which should be the size
-   * (in bits) of a mp_limb_t.
-   * Secondly, the amount of zeros to check, "d" here is 8.
-   */
-  static const uint32_t window = (32+30) /8;
-  mpz_t a;
-  mpz_init(a);
-#define distinguished(x) (((x)[23] & (~(ULLONG_MAX >> strip_size)))) == 0
-  while (!distinguished(nn)) {
-    /**
-     * Here we try to find a strip of zeros for "w2" bits.
-     * When we find one (up to w2 = 64), then we jump of w = w/2.
-     * I tried to optimize this code:
-     * - by integrating the if statement above with the for loop invariant;
-     * - by making the loop algebraic (i.e. no if-s), given that in the
-     *   generated assembly I read a lot of jumps.
-     * Unfortunately, both approaches actually lead to a slow down in the code.
-     */
+  assert(strip_size == 16);
+  uint32_t steps;
+  static const uint32_t window = 7;
+
+#define distinguished(x) ((x)[23] & ~(ULLONG_MAX >> strip_size)) == 0
+  const __m128i rotmask = _mm_set_epi8(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1);
+  for (steps = 0; !distinguished(nn); steps += window*8) {
     START_TIMEIT();
-    const uint8_t *x = (uint8_t *) &nn[22];
-    uint8_t *y = memrchr(x, '\0', window*2-1);
 
-    if (y  && y[-1] <= lookup[y[+1]]) {
-      return steps + (x + 15 - y)*8 - offset[y[+1]];
+    __m128i x = _mm_lddqu_si128((__m128i *) (nn + 22));
+    __m128i mask = _mm_set_epi8(0xFF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
+
+    for (int32_t i = 14; i > 0; --i) {
+      mask = _mm_shuffle_epi8(mask, rotmask);
+        const bool zero = _mm_testz_si128(mask, x);
+        if (zero) {
+          const uint8_t previous = _mm_extract_epi8(x, i-1);
+          const uint8_t next = _mm_extract_epi8(x, i+1);
+          if (previous <= lookup[next]) {
+            END_TIMEIT();
+            return steps + (15-i)*8 - offset[next];
+          }
+        }
     }
-
     END_TIMEIT();
+
     /**
      * We found no distinguished point.
      */
     const uint64_t a = mpn_lshift(nn, nn, 24, window) * gg;
     mpn_add_1(nn, nn, 24, a);
-    steps += window;
   }
-  mpz_clear(a);
   return steps;
 }
 
@@ -109,7 +105,7 @@ uint32_t naif_convert(mpz_t n)
   uint32_t i;
   mpz_t t;
   mpz_init_set_ui(t, 1);
-  mpz_mul_2exp(t, t, 1536-16);
+  mpz_mul_2exp(t, t, 1536-strip_size);
 
 
   for (i = 0; mpz_cmp(n, t) > -1; i++) {
@@ -139,11 +135,11 @@ int main()
     offset[i] = j;
   }
 
-  uint32_t converted;
-  for (uint64_t i=0; i < 1e4; i++) {
-    mpz_t n, n0;
-    mpz_inits(n, n0, NULL);
+  mpz_t n, n0;
+  mpz_inits(n, n0, NULL);
 
+  uint32_t converted;
+  for (uint64_t i=0; i < 1e3; i++) {
     mpz_urandomm(n0, _rstate, p);
     mpz_set(n, n0);
     converted = convert(n->_mp_d);
@@ -153,19 +149,10 @@ int main()
     if (converted != expected) printf("%d %d\n", converted, expected);
     assert(converted == expected);
 #endif
-    mpz_clears(n, n0, NULL);
   }
   printf(TIMEIT_FORMAT "\n", GET_TIMEIT());
 
-
-  /* memset(n->_mp_d, 0, 24*8); */
-  /* memset(n0->_mp_d, 0, 24*8); */
-  /* n0->_mp_d[0] = 13423523; */
-  /* n0->_mp_d[1] = 1; */
-  /* uint64_t v[64] = {0}; */
-  /* unpack(v, n->_mp_d); */
-  /* pack(n0, v); */
-
+  mpz_clears(n, n0, NULL);
   mpz_clears(p, g, NULL);
   return 0;
 

+ 1 - 1
ver2.c

@@ -150,7 +150,7 @@ int main()
 
   INIT_TIMEIT();
   uint32_t converted;
-  for (int i=0; i < 1e5; i++) {
+  for (int i=0; i < 24e3; i++) {
     mpz_urandomm(n0, _rstate, p);
     mpz_set(n, n0);
     START_TIMEIT();