import java.util.Random;

/**

  This program illustrates how to work in the
  finite fields GF(2<sup>n</sup>), and can be used to
  find irreducible polynomials over GF(2) of degree n
  (each of which defines the field GF(2<sup>n</sup>))
  for n less than 64. 
  <p>
  The techniques used here are described in
  "Probabilistic Algorithms in Finite Fields"
  by Michael O. Rabin, <i>SIAM Journal on Computing</i>,
  May 1980.


  @author David Chase chase@naturalbridge.com

  */

public class TwoPolynomials {

    /**
      Lazily filled cache for primeFactorsOfTwoToTheNMinusOne,
      indexed by n (0..63).  Not synchronized: concurrent callers
      may compute the same entry more than once.
      */
    static long [][] primeFactorsOfTwoToTheNMinusOneCache = new long[64][];

    /**
      Returns the prime factors of 2<sup>x</sup>-1 as computed by
      <code>PF2.factors2ttxm1(x)</code>, caching the result per x.
      x must be in 0..63, since it indexes a 64-entry cache.
      NOTE(review): presumably the factors are distinct primes;
      isPrimitive relies on that -- confirm against PF2.
      */
    static long[] primeFactorsOfTwoToTheNMinusOne(long x) {
        int ix = (int) x;
        if (primeFactorsOfTwoToTheNMinusOneCache[ix] == null)
            primeFactorsOfTwoToTheNMinusOneCache[ix] = PF2.factors2ttxm1(x);
        return primeFactorsOfTwoToTheNMinusOneCache[ix];
    }

    /**
       Returns the number of set (1) bits in x.
     */
    public static int populationCount(long x) {
        return Long.bitCount(x);
    }

    /**
       Returns the number of set (1) bits in x.
     */
    public static int populationCount(int x) {
        return Integer.bitCount(x);
    }

    /**
      <code>countHighZeros(x)</code> returns the number of
      contiguous zeros found at the high-order (unsigned)
      end of x.
      <p>
      Thus,<pre><blockquote>
         countHighZeros(-1) = 0;
         countHighZeros(0) = 32;
         countHighZeros(1) = 31;
         countHighZeros(65535) = 16;
      </blockquote></pre>
     */
    public static int countHighZeros(int x) {
        return Integer.numberOfLeadingZeros(x);
    }

    /**
      64-bit version of countHighZeros; returns 64 for zero.
      */
    public static int countHighZeros(long x) {
        return Long.numberOfLeadingZeros(x);
    }

    /**
      Returns the degree of polynomial p (the index of its highest
      set bit), or -1 if p is the zero polynomial.  All 64 bits are
      honored, so a polynomial of degree 63 (a negative long) works.
      */
    private static int degree(long p) {
        return 63 - countHighZeros(p);
    }

    /** 
      <code>highBit(a)</code> returns 0 if a is zero,
      otherwise it returns a rounded down to the largest
      power of two less than or equal to a (treating a as unsigned,
      so a negative a yields <code>Long.MIN_VALUE</code>).
      */
    static long highBit(long a) {
        return Long.highestOneBit(a);
    }

    /** True if the constant term (bit 0) of a is zero. */
    static boolean isEven(long a) {
        return (a & 1L) == 0;
    }

    /** True if the constant term (bit 0) of a is one. */
    static boolean isOdd(long a) {
        return (a & 1L) != 0;
    }

    /**
      <code>polyMulMod(long a, long b, long m)</code>
      multiplies together two polynomials a and b,
      modulo a third polynomial m.  a and b must
      have smaller degree than m does.
      <p>
      Here, the polynomials have coefficients in {0,1},
      where addition is XOR and multiplication is AND.
      The set bits in a long integer correspond to the
      non-zero coefficients of the polynomial, where
      the coefficient of x<sup>n</sup> corresponds to the bit
      for 2<sup>n</sup> in the integer.  Thus,
      x<sup>3</sup> + x + 1
      has non-zero coefficients at exponents 3, 1 and 0, and value
      8 + 2 + 1 = 11.
      <p>
      Degrees are compared on the bit pattern, so m may use all
      64 bits (degree up to 63).

      @throws IllegalArgumentException if the operand of higher
              degree does not have smaller degree than m
      */
    public static long polyMulMod(long a, long b, long m) {
        // Let a be the operand of lower (or equal) degree: it drives
        // the loop, and checking b alone then checks both operands.
        if (degree(a) > degree(b)) {
            long t = a;
            a = b;
            b = t;
        }
        assertSmallerDegree(b, m);
        if (a == 0)
            return 0;
        if (a == 1)
            return b;
        long hm = highBit(m);
        long r = 0;
        while (a != 0) {
            if (isOdd(a))
                r ^= b; /* Add b to r */
            b <<= 1; /* Raise the degree of b */
            if ((b & hm) != 0)
                b ^= m; /* Subtract out m if equal degree */
            a >>>= 1;
        }
        return r;
    }

    /** 
      <code>pToTheNMod(long p, long n, long m)</code>
      returns the result of raising polynomial p to the
      n power, modulo polynomial m.  p must have lower
      degree than m's degree; this is checked even for
      n == 0 and n == 1.
      <p>
      Again, these are polynomials with coefficients
      in {0,1}, + = XOR, * = AND.

      @throws IllegalArgumentException if p's degree is not
              smaller than m's
     */
    public static long pToTheNMod(long p, long n, long m) {
        assertSmallerDegree(p, m);
        if (n == 0)
            return 1;
        if (n == 1)
            return p;
        // p^n = p^(n mod 2) * (p^2)^(n/2)
        return polyMulMod(
            isEven(n) ? 1 : p,
            pToTheNMod(polyMulMod(p, p, m), n / 2, m),
            m);
    }

    /**
      Throws IllegalArgumentException unless degree(p) &lt; degree(m).
      Comparing degrees (rather than signed values) keeps this correct
      when m has degree 63 and is therefore a negative long.
      */
    private static void assertSmallerDegree(long p, long m) {
        if (degree(p) >= degree(m))
            throw new IllegalArgumentException(
                "p (0x"
                    + Long.toHexString(p)
                    + ") >= m(0x"
                    + Long.toHexString(m)
                    + ")");
    }

    /**
      <code>pToTheTwoToTheNMod(long p, int n, long m)</code>
      returns the result of raising polynomial p to the 2<sup>n</sup>
      power, modulo polynomial m, by squaring n times.  p must have
      lower degree than m; n &lt;= 0 returns p itself.

      @throws IllegalArgumentException if p's degree is not
              smaller than m's
      */
    public static long pToTheTwoToTheNMod(long p, int n, long m) {
        assertSmallerDegree(p, m);
        long r = p;
        for (int i = 0; i < n; i++)
            r = polyMulMod(r, r, m);
        return r;
    }

    /**
      gcd(long p, long q) returns the gcd of the two
      polynomials p and q (Euclid's algorithm in GF(2)[x]).
      gcd(0, q) == q and gcd(p, 0) == p.  All 64 bits may be used.
      */
    public static long gcd(long p, long q) {
        while (p != 0) {
            // q = q modulo p: cancel q's terms from the top down.
            int dp = degree(p);
            for (int shift = degree(q) - dp; shift >= 0; shift--) {
                if (((q >>> (dp + shift)) & 1) != 0)
                    q ^= p << shift;
            }
            // Now degree(q) < degree(p); continue with (q, p).
            long t = p;
            p = q;
            q = t;
        }
        return q;
    }

    /**
      Returns true if and only if <code>gOfX</code> is irreducible
      over GF(2).
      <p>
      This is Lemma 1 in Rabin's article:
      <p>
      Let l<sub>1</sub>, ..., l<sub>k</sub> be all the
      prime divisors of n and denote n/l<sub>i</sub> = m<sub>i</sub>.<br>
      A polynomial g(x) of degree n in Z<sub>p</sub>[x] is irreducible
      in Z<sub>p</sub>[x] if and only if
      <ol>
      <li> g(x) divides (x<sup>p<sup>n</sup></sup> - x),
      <li> For all i, 1 &lt;= i &lt;= k,
           gcd(g(x), x<sup>p<sup>m<sub>i</sub></sup></sup> - x) = 1.
      </ol>
      <p>
      For our purposes, recall that "p" is 2.  Also recall that
      our representation of the polynomial "x" is 2, since x == x<sup>1</sup>
      and 2 == 2<sup>1</sup>.
      <p>
      Note that using the prime factors of 2<sup>n</sup>-1 (instead of
      those of n) in item 2 yields a test for primitivity, not
      irreducibility; see {@link #isPrimitive}.  Such a test rejects
      irreducible polynomials like x<sup>4</sup>+x<sup>3</sup>+x<sup>2</sup>+x+1.
      <p>
      The constants 0 and 1 are not irreducible; both degree-1
      polynomials, x and x+1, are.
      */
    public static boolean isIrreducible(long gOfX) {
        int n = degree(gOfX);
        if (n < 1)
            return false; // 0 and 1 are not irreducible
        if (n == 1)
            return true; // x and x + 1
        if (isEven(gOfX))
            return false; // zero constant term: x divides g(x)

        // Lemma 1, item 1, translated:
        // 
        // g(x) divides (x-to-the-2-to-the-n minus x) exactly when
        // xtt2ttn modulo g(x) equals x, which in our representation
        // is the number 2.
        if (pToTheTwoToTheNMod(2, n, gOfX) != 2)
            return false;

        // Item 2 translated:
        // 
        // In this system, for polynomials m with degree larger
        // than 1, (p mod m) - x == (p - x) mod m.  Therefore
        // xtt2tt-m-sub-i is computed modulo g(x) first, then x is
        // subtracted (XOR with 2), then the gcd is taken; this avoids
        // gigantic intermediate results.
        int rest = n;
        for (int l = 2; rest > 1; l++) {
            if (rest % l != 0)
                continue;
            // All smaller primes are already divided out, so l is prime.
            do {
                rest /= l;
            } while (rest % l == 0);
            long t = pToTheTwoToTheNMod(2, n / l, gOfX) ^ 2;
            if (gcd(t, gOfX) != 1)
                return false;
        }
        return true;
    }

    /**
      Returns true if <code>gOfX</code> is a primitive polynomial over
      GF(2): x has multiplicative order exactly 2<sup>n</sup>-1 modulo
      g(x), where n is the degree of g.  Every primitive polynomial is
      irreducible, but not conversely.
      <p>
      The test checks that x<sup>2<sup>n</sup></sup> mod g(x) == x, and
      that x<sup>(2<sup>n</sup>-1)/q</sup> mod g(x) != 1 for every prime
      factor q of 2<sup>n</sup>-1 (supplied by PF2 via
      primeFactorsOfTwoToTheNMinusOne).
      */
    public static boolean isPrimitive(long gOfX) {
        int n = degree(gOfX);
        if (n < 1)
            return false;
        if (isEven(gOfX))
            return false; // x divides g(x), so x is not invertible mod g(x)
        if (n == 1)
            return true; // x + 1: the unit group of GF(2) is trivial
        if (pToTheTwoToTheNMod(2, n, gOfX) != 2)
            return false;

        long tttnm1 = (1L << n) - 1;
        long[] pf = primeFactorsOfTwoToTheNMinusOne(n);
        for (int i = 0; i < pf.length; i++) {
            if (pToTheNMod(2, tttnm1 / pf[i], gOfX) == 1)
                return false;
        }
        return true;
    }

    /**
      Treating <code>p</code> as a polynomial of x
      with {0,1} coefficients, this returns (in
      polynomial arithmetic) p*x modulo irred_poly.
      <p>
      This assumes that irred_poly has degree 32,
      and that the representation of irred_poly
      has an implicit bit 32 (since if it were
      explicit, the representation of irred_poly
      would require 33 bits).
      */
    public static int step(int p, int irred_poly) {
        if (p < 0) {
            // The x^31 term shifts into the implicit x^32: reduce.
            return (p + p) ^ irred_poly;
        } else {
            return p + p;
        }
    }

    /**
      This constructs a lookup table that can be
      used to accelerate calculation of
      p*x<sup>8</sup> modulo irred_poly: entry i is
      (i*x<sup>24</sup>)*x<sup>8</sup> modulo irred_poly.
      <p>
      This assumes that irred_poly has degree 32,
      and that the representation of irred_poly
      has an implicit bit 32 (since if it were
      explicit, the representation of irred_poly
      would require 33 bits).
      <p>
      This table can be used in byte-string
      hashing operations; to calculate the
      remainder of a given string of bytes
      modulo some poly, it is sufficient to
      use a table returned by this routine,
      and use the iteration (note that the new byte is
      added with XOR, since this is polynomial arithmetic):
      <pre>
      int remainder = 0;
      for (int i = 0; i &lt; bytes.length; i++) {
          remainder =
            (table[remainder &gt;&gt;&gt; 24] ^ (remainder &lt;&lt; 8)) ^
              ((int) bytes[i] &amp; 0xFF); 
      }
      </pre>
      This is not a spectacular hash function for
      short strings, but it apparently has interesting
      properties for manipulating long streams of bytes.
      If a hash function is desired, merely performing
      an additional four iterations:
      <pre>
      for (int i = 0; i &lt; 4; i++) {
          remainder =
            (table[remainder &gt;&gt;&gt; 24] ^ (remainder &lt;&lt; 8));
      }
      </pre>
      yields a very good one, though it is not the fastest
      possible.
      */
    public static int[] makeLookupTable(int irred_poly) {
        int[] result = new int[256];
        for (int i = 0; i < 256; i++) {
            int trial = i << 24;
            for (int j = 0; j < 8; j++) {
                trial = step(trial, irred_poly);
            }
            result[i] = trial;
        }
        return result;
    }

    /**
      Prints <code>table</code> to standard output as the body of a
      C/Java array initializer (" = {", hex entries four per line,
      terminated by "};").
      */
    public static void printLookupTable(int[] table) {
        System.out.println(" = {");
        if (table.length == 0) {
            System.out.println("};");
            return;
        }
        int last = table.length - 1;
        for (int i = 0; i < table.length; i++) {
            String sep = (i == last) ? "};" : ", ";
            System.out.print("0x" + Integer.toHexString(table[i]) + sep);
            if (((i + 1) & 3) == 0)
                System.out.println();
        }
    }

    /** Sanity check: prints whether x is irreducible and returns the verdict. */
    static boolean test(long x) {
        System.err.print(Long.toHexString(x));
        boolean b = isIrreducible(x);
        System.err.println((b ? " is irreducible" : " is reducible"));
        return b;
    }

    /** Number of "#define h..." lines printed so far by exercise. */
    static int pcount = 0;

    /**
      One step of the random search driven by main.  If the low 32 bits
      of l have exactly 16 set bits and l is a primitive (hence
      irreducible) polynomial, computes Invert.invert of the low 32 bits
      against 0x7fffffff, 0x80000000 and 0x100000000, and prints a
      "#define h<i>" line when exactly two of those four numbers
      (the low 32 bits and its three inverses) have divisor() == 0.
      <p>
      isPrimitive is used (rather than isIrreducible) because it is the
      test this search has always applied, which produced the h-constants
      below.

      @return l if the population-count and primitivity tests passed, else 0
      */
    public static long exercise(long l) {
        int pc = populationCount((int) l);

        if (pc == 16 && isPrimitive(l)) {
            long low = l & 0xffffffffL;
            long i1 = Invert.invert(low, 0x7fffffffL);
            long i2 = Invert.invert(low, 0x80000000L);
            long i3 = Invert.invert(low, 0x100000000L);

            int dcount = 0;
            if (divisor(low) == 0)
                dcount++;
            if (divisor(i1) == 0)
                dcount++;
            if (divisor(i2) == 0)
                dcount++;
            if (divisor(i3) == 0)
                dcount++;

            if (dcount == 2) {
                System.err.println(
                    "#define h"
                        + (pcount++)
                        + " 0x"
                        + Long.toHexString(0xffffffffL & l));
            }
            return l;
        }
        return 0;
    }

    static final int h0 = 0x2099ebb3;
    static final int h1 = 0x28d7c663;
    static final int h2 = 0xbd52982d;
    static final int h3 = 0x153d88db;
    static final int h4 = 0xae88dd83;
    static final int h5 = 0xad657069;
    static final int h6 = 0x1c62cd6b;
    static final int h7 = 0x5c2225f7;
    static final int h8 = 0xe9525c8d;
    static final int h9 = 0x2aa9b2ab;
    static final int h10 = 0x90fa44ed;
    static final int h11 = 0x068b8fd9;
    static final int h12 = 0xa334536b;
    static final int h13 = 0x7de40e29;
    static final int h14 = 0x9394965b;
    static final int h15 = 0x71cc581f;
    static final int h16 = 0xa7429b27;
    static final int h17 = 0x9992627d;
    static final int h18 = 0x33e485cb;
    static final int h19 = 0xc448ccfd;
    static final int h20 = 0xb0d91f0b;
    static final int h21 = 0x73396433;
    static final int h22 = 0x46b0b747;
    static final int h23 = 0x20b4bee3;
    static final int h24 = 0xc0d562ed;
    static final int h25 = 0xcd2f28d1;
    static final int h26 = 0x46384f75;
    static final int h27 = 0xad88ae87;
    static final int h28 = 0xb23b8d43; // negative
    static final int h29 = 0xbca541d3; // negative
    static final int h30 = 0x05592ff1;
    static final int h31 = 0xed1a0ec9;
    static final int h32 = 0x6d1418ef;
    static final int h33 = 0xb4ce06b3;
    static final int h34 = 0x0fe4b171;
    static final int h35 = 0x89b6ccc9;
    static final int h36 = 0x8e4154df;
    static final int h37 = 0x04e0ef6b;
    static final int h38 = 0xd90ea917;
    static final int h39 = 0x5c0ead87;
    static final int h40 = 0x478a157d;
    static final int h41 = 0x1b4ad713;
    static final int h42 = 0xa256d569;
    static final int h43 = 0x5b94469b;
    static final int h44 = 0x52e90be5;
    static final int h45 = 0x7607db81;
    static final int h46 = 0xcb1a8753;
    static final int h47 = 0x1f62b945;

    // For convenient array access.
    static final int[] H = {
        h0, h1, h2, h3, h4, h5, h6, h7,
        h8, h9, h10, h11, h12, h13, h14, h15,
        h16, h17, h18, h19, h20, h21, h22, h23,
        h24, h25, h26, h27, h28, h29, h30, h31,
        h32, h33, h34, h35, h36, h37, h38, h39,
        h40, h41, h42, h43, h44, h45, h46, h47 };

    /** divisor(long) applied to x read as an unsigned 32-bit value. */
    public static int divisor(int x) {
        long l = (long) x & 0xffffffffL;
        return divisor(l);
    }

    /**
      Returns the smallest divisor of l greater than 1 found by trial
      division, or 0 if there is none (l is prime).  Even l other than
      2 returns 2.  Values below 2 are not meaningful inputs (0 returns
      2, 1 and negative odd values return 0).  The bound i &lt;= l / i is
      exact, unlike a floating-point square root; a divisor that does
      not fit in an int (only possible for l above 2<sup>62</sup>) would
      be truncated by the int return type.
      */
    public static int divisor(long l) {
        if (l == 2)
            return 0;
        if (isEven(l))
            return 2;
        for (long i = 3; i <= l / i; i += 2) {
            if (l % i == 0)
                return (int) i;
        }
        return 0;
    }

    /** Prints, for each constant in H, whether it is prime or its smallest divisor. */
    static public void checkHPrime() {
        for (int i = 0; i < H.length; i++) {
            int h = H[i];
            int d = divisor(h);
            String s = "0x" + Long.toHexString((long) h & 0xffffffffL);
            if (d == 0)
                System.out.println(s + " is prime");
            else
                System.out.println(s + " is divided by " + d);
        }
    }

    public static void main(String[] args) {

        // checkHPrime();

        test(0x118000003L); // Known irred poly.
        test(0x101800003L); // Known reducible poly.
        test(0x19L);
        test(0x13L);
        test(0x11dL);
        test(0x1c3L);
        test(0x1100bL);
        test(0x100400007L);

        // Remove this early return to run the random search below.
        if (true)
            return;

        Random r = new Random(0x123456789L);

        // Find a few primitive (hence irreducible) polynomials with 16
        // set bits in their implicit-leading-term representation,
        // and print them.

        for (int j = 0; j < 25; j++) {

            long start = System.currentTimeMillis();

            if (pcount > 50)
                break;
            for (int i = 0; i < 500000; i++) {
                long l = 0x100000001L | ((long) r.nextInt() & 0xffffffffL);
                if (pcount > 50)
                    break;
                exercise(l);
            }

            long stop = System.currentTimeMillis();

            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                // Preserve the interrupt status and stop the timing runs.
                Thread.currentThread().interrupt();
                return;
            }

            System.err.println(
                "Timing run "
                    + j
                    + " took "
                    + (stop - start)
                    + " milliseconds");
        }
    }
}
