include Math

# Greatest common divisor via Euclid's algorithm
# Input x_in -- one of the integers whose greatest common divisor is wanted (sign ignored)
# Input y_in -- the other integer (sign ignored)
# Return -- the non-negative gcd(x_in, y_in); gcd(0, 0) is 0 by convention
def gcd(x_in, y_in)
  x, y = x_in.abs, y_in.abs
  # Classic remainder loop; when y reaches 0, x holds the gcd. If both inputs
  # are 0 the loop never runs and 0 is returned (callers compare with > 1 / != 1).
  x, y = y, x % y until y.zero?
  x
end

# Multiplicative inverse via extended Euclid's algorithm
# Input x -- The integer to find the multiplicative inverse of (any sign)
# Input m -- The modulus of the inverse operation
# Return -- The multiplicative inverse of x (mod m), reduced with Ruby's %
#           (so it lies in [0, m) for positive m)
# Raises RuntimeError if x and m are not relatively prime
def m_i_euclid(x, m)
  # Invariant: every remainder r satisfies r == y * x (mod m) for its paired y.
  r_prev, r_cur = m, x
  y_prev, y_cur = 0, 1
  until r_cur.zero?
    q = r_prev / r_cur
    r_prev, r_cur = r_cur, r_prev - q * r_cur
    y_prev, y_cur = y_cur, y_prev - q * y_cur
  end
  # r_prev is now +-gcd(x, m). With negative inputs Ruby's floor division can
  # leave it at -1 (e.g. x = -2, m = 7), in which case y_prev * x == -1, so
  # multiply by r_prev to get the true inverse.
  raise "#{x} and #{m} are not relatively prime!" unless r_prev.abs == 1
  (y_prev * r_prev) % m
end

# From Rosen's Number Theory book
# Multiplicative inverse via extended Euclid's algorithm
# Input a -- The integer to find the multiplicative inverse of (any sign)
# Input m -- The modulus of the inverse operation
# Return -- The multiplicative inverse of a (mod m), reduced with Ruby's %
# Raises RuntimeError if a and m are not relatively prime
def m_i_euclid_2(a, m)
  a_in, m_in = a, m
  # x_back is the coefficient of the original a in the current remainder a.
  x, x_back = 0, 1
  until m.zero?
    q, r = a.divmod(m)
    a, m = m, r
    x, x_back = x_back - q * x, x
  end
  # a is now the last nonzero remainder, +-gcd(a_in, m_in). With a negative
  # modulus it can be -1, so fold its sign back in to get a true inverse.
  raise "#{a_in} and #{m_in} are not relatively prime!" unless a.abs == 1
  (x_back * a) % m_in
end

# Find s and t in a * s + b * t = gcd(a, b), with gcd(a, b) >= 0
# Input a -- First integer
# Input b -- Second integer
# Return -- Array with first element = s and second element = t
def extended_euclid(a, b)
  s_prev, s_cur = 1, 0
  t_prev, t_cur = 0, 1
  until b.zero?
    q, r = a.divmod(b)
    a, b = b, r
    s_prev, s_cur = s_cur, s_prev - q * s_cur
    t_prev, t_cur = t_cur, t_prev - q * t_cur
  end
  # a is the last nonzero remainder, which is -gcd when the inputs' signs
  # push Ruby's floor-division remainders negative (e.g. b < 0); negate the
  # coefficients so the identity holds for the non-negative gcd.
  a.negative? ? [-s_prev, -t_prev] : [s_prev, t_prev]
end

# Multiplicative inverse via Gauss's method: start from the fraction 1/x,
# add m to the numerator (which leaves its value mod m unchanged) until the
# numerator shares a factor with the denominator, cancel that factor, and
# repeat until the denominator is 1; the numerator is then the inverse.
# Input x -- The integer to find the multiplicative inverse of (any sign)
# Input m -- The modulus of the inverse operation
# Return -- The multiplicative inverse of x (mod m), reduced with Ruby's %
# Raises RuntimeError if x and m are not relatively prime
def m_i_gauss(x, m)
  raise "#{x} and #{m} are not relatively prime!" if gcd(x, m) > 1

  # Bring x into [0, |m|): the loop only runs while the denominator exceeds 1,
  # so a negative x would otherwise skip it entirely and return 1.
  den = x % m.abs
  num = 1
  while den > 1
    num += m
    g = gcd(den, num)
    if g > 1
      den /= g
      num /= g
    end
  end
  # Final reduction also maps the modulus-1 case (den == 0) to 0.
  num % m
end

# Fast and efficient algorithm for computing b^e (mod m) by square-and-multiply
# Input b -- the base, or number to be exponentiated
# Input e -- the non-negative exponent to raise the base, b, to
#            (a negative e skips the loop and yields 1 % m)
# Input m -- the modulus of the exponentiation problem, b^e (mod m)
# Return -- b^e (mod m)
def pow_mod(b, e, m)
  # Start from 1 reduced mod m so that b^0 (mod 1) is 0, not 1.
  result = 1 % m
  while e > 0
    result = (result * b) % m if e.odd?
    e >>= 1
    b = (b * b) % m
  end
  result
end

# Pull the factors of 2 out of a number, n - 1
# Input n -- the number to split into n - 1 = 2^s * t
# Return -- s and t as a two-element array [s, t]; t is odd, except that
#           n == 1 gives [0, 0]
def miller_s_and_t(n)
  s, t = 0, n - 1
  # Zero is divisible by 2 forever, so stop at t == 0 instead of looping.
  until t.zero? || t.odd?
    t /= 2
    s += 1
  end
  [s, t]
end

# Miller's (strong probable prime) test of n with respect to a single base
# Input n -- the number to test for primality
# Input base -- the base for the test
# Input s -- n - 1 = 2^s * t
# Input t -- n - 1 = 2^s * t
# Return -- true if n passes to this base (n is prime or a strong pseudoprime
#           to base), false if base proves n composite
def millers_test(n, base, s, t)
  n_minus_1 = n - 1
  power = pow_mod(base, t, n)
  # base^t == +-1 (mod n) is consistent with n being prime.
  return true if power == 1 || power == n_minus_1

  # Square up to s - 1 more times, looking for n - 1.
  (s - 1).times do
    power = pow_mod(power, 2, n)
    # Reaching 1 without first hitting n - 1 means the previous value was a
    # square root of 1 other than +-1, so n is composite.
    return false if power == 1
    return true if power == n_minus_1
  end

  # The chain never reached n - 1: composite.
  false
end

# Rabin-Miller probabilistic test for primality
# Input n -- the number to test for primality
# Input uncertainty -- the acceptable chance (0 < uncertainty < 1) of a
#                      composite n being reported prime
# Return -- true if believed to be prime, false if n is certainly composite
#           (integers below 2 are never prime)
def rabin_miller(n, uncertainty)
  # n < 2 is not prime; without this guard n == 1 never terminates in
  # miller_s_and_t and n == 0 divides by zero in the base check below.
  return false if n < 2
  return true if n == 2

  # A composite passes one Miller round for at most 1/4 of random bases, so
  # ceil(log_{1/4}(uncertainty)) rounds (at least one) meet the target.
  # NOTE: the bases below are consecutive (2, 3, 4, ...), not random, so the
  # bound is a heuristic here rather than a guarantee.
  rounds = [(log(uncertainty) / log(0.25)).ceil, 1].max
  s, t = miller_s_and_t(n)
  base = 1
  rounds.times do
    base += 1
    # A base divisible by n gives base^t == 0 (mod n) and says nothing.
    base += 1 while (base % n).zero?
    return false unless millers_test(n, base, s, t)
  end
  true
end

# Rabin-Miller deterministic test for primality of integers < 25.0e9
# Bases 2, 3, 5 and 7 decide primality for every n < 25.0e9 except the strong
# pseudoprime 3215031751, which is rejected explicitly.
# Input n -- the number to test for primality
# Return -- true if prime, false otherwise
# Raises RuntimeError if n >= 25.0e9
def rabin_miller_certain(n)
  raise "#{n} must be < 25.0e9" if n >= 25.0e9

  bases = [2, 3, 5, 7]
  strong_pseudoprime = 3215031751
  # Integers below 2 are not prime.
  return false if n < 2
  # millers_test rejects a prime n when base is a multiple of n
  # (base^t == 0 mod n), so the bases themselves are answered directly.
  return true if bases.include?(n)
  return false if n == strong_pseudoprime

  s, t = miller_s_and_t(n)
  bases.all? { |base| millers_test(n, base, s, t) }
end

# Second order Newton method (Halley's method) for square roots of large numbers.
# Each step computes x <- x * (x^2 + 3n) / (3x^2 + n) using integer division.
# Input n -- the non-negative integer to take the root of
# Input max_iters -- iteration limit (defaults to APPROX_SQRT_MAX_ITERS)
# Return value is an array with:
#    first array element the best approximation to the square root
#    second array element one of 'Exact', 'Low', or 'Not found'
#    third array element the number of iterations performed
# A return of 'Exact' means that n is a perfect square and that the root is the true root
# A return of 'Low' means the root is within 1 of the square root and is less than it, the so-called floor
# A return of 'Not found' means that the iteration didn't converge to within 1 of the
# square root in max_iters steps

APPROX_SQRT_MAX_ITERS = 1000
def approx_sqrt(n, max_iters = APPROX_SQRT_MAX_ITERS)
  # For n == 0 both the initial guess and the divisor 3x^2 + n would be 0.
  return [0, 'Exact', 0] if n.zero?

  n_3 = 3 * n
  x_next, msg, iters = nil, 'Not found', 0
  # Initial guess: shift n right by half its bit length, roughly sqrt(n).
  # (A base 10 variant would be a 1 followed by half as many zeroes as n has digits.)
  x = n >> (n.to_s(2).length / 2)
  1.upto(max_iters) do |i|
    # Set the outer variable explicitly: a block parameter named iters would
    # be block-local and leave the returned count nil.
    iters = i
    x_next = (x * (x * x + n_3)) / (3 * x * x + n)
    x_p = x_next + 1
    x_m = x_next - 1
    if x_next * x_next == n
      msg = 'Exact'
      break
    elsif x_p * x_p == n
      msg = 'Exact'
      x_next = x_p
      break
    elsif x_m * x_m == n
      msg = 'Exact'
      x_next = x_m
      break
    elsif (x_p * x_p - n) * (x_next * x_next - n) < 0
      # sqrt(n) lies strictly between x_next and x_next + 1.
      msg = 'Low'
      break
    elsif (x_m * x_m - n) * (x_next * x_next - n) < 0
      # sqrt(n) lies strictly between x_next - 1 and x_next.
      msg = 'Low'
      x_next = x_m
      break
    end
    x = x_next
  end
  [x_next, msg, iters]
end

# Halley's-method integer square root that iterates until the estimate stops
# changing.
# Input n -- the non-negative integer to take the root of
# Input max_iters -- iteration limit (defaults to APPROX_SQRT_MAX_ITERS)
# Return -- [root, msg, iterations] where msg is 'Exact' (root * root == n),
#           'Floor' (root is the floor of sqrt(n)) or 'Not found' (no
#           convergence within max_iters)
def approx_sqrt_2(n, max_iters = APPROX_SQRT_MAX_ITERS)
  # For n == 0 both the initial guess and the divisor 3x^2 + n would be 0.
  return [0, 'Exact', 0] if n.zero?

  x_next, msg, iters, n_3 = nil, 'Not found', 0, 3 * n
  # base 2 logarithm approximation:
  x = n >> (n.to_s(2).length / 2)
  1.upto(max_iters) do |i|
    # Set the outer variable explicitly; a block parameter would shadow it.
    iters = i
    x_next = (x * (x * x + n_3)) / (3 * x * x + n)
    if (x_next - x).abs < 1
      # Approaching from below, the integer iteration can stall at
      # sqrt(n) - 1 when n is a perfect square (n = 9 stalls at 2).
      x_next += 1 while (x_next + 1) * (x_next + 1) <= n
      msg = x_next * x_next == n ? 'Exact' : 'Floor'
      break
    end
    x = x_next
  end
  [x_next, msg, iters]
end

# Check for pairwise primality of a set of integers
# Input ps -- array of integers to test
# Return -- true if every pair of integers is relatively prime (vacuously true
#           for fewer than two integers), false otherwise
def pairwise_prime(ps)
  # all? stops at the first non-coprime pair; the nested loops it replaces
  # only broke out of the inner loop and kept checking further pairs.
  ps.combination(2).all? { |p1, p2| gcd(p1, p2) == 1 }
end

# Chinese Remainder Theorem for a system of single-variable congruences
# x = r_i (mod m_i)
# Input a_of_a -- array of congruences, each subarray of the form [r_i, m_i]
# Return -- the simultaneous solution, reduced modulo the product of the m_i
# Raises RuntimeError if the moduli are not pairwise relatively prime
def crt_1_variable(a_of_a)
  moduli = a_of_a.map { |congruence| congruence[1] }
  raise "Moduli are not all pairwise relatively prime!" unless pairwise_prime(moduli)

  product = moduli.inject(1) { |acc, modulus| acc * modulus }
  # M_i = M / m_i is 0 modulo every other modulus.
  cofactors = moduli.map { |modulus| product / modulus }
  total = a_of_a.each_index.inject(0) do |acc, i|
    # M_i * inverse(M_i) is 1 (mod m_i), so this term contributes exactly r_i there.
    inverse = m_i_euclid(cofactors[i], moduli[i])
    acc + a_of_a[i][0] * cofactors[i] * inverse
  end
  total % product
end

# Solve two linear congruences in two unknowns sharing one modulus
# (Cramer's rule modulo m):
#   a * x + b * y = e (mod m) and c * x + d * y = f (mod m)
# Inputs a, b, c, d, e, f, m -- coefficients, right-hand sides and modulus
# Return -- the array [x, y]
# Raises RuntimeError if the determinant a * d - b * c is not relatively prime to m
def crt_2_variables(a, b, c, d, e, f, m)
  delta = a * d - b * c
  raise "Discriminant is not relatively prime to modulus!" if gcd(delta, m) != 1
  # Modulo +-1 every integer is congruent to 0.
  return [0, 0] if m.abs == 1

  # The determinant is frequently negative; reduce it into [0, |m|) so the
  # inverse is computed from a canonical non-negative residue.
  delta_i = m_i_euclid(delta % m.abs, m)
  x = (delta_i * (d * e - b * f)) % m
  y = (delta_i * (a * f - c * e)) % m
  [x, y]
end