# A complete classical simulation of Shor's factoring algorithm (the quantum period-finding step is faked).

#! /usr/bin/ruby

# Complex
require 'complex'
# log, sin and PI
include Math


# This version of the Shor factoring algorithm is complete. It uses a QFT whose size is
# a power of 2, takes the continued fraction expansion of the index of one of the non-trivial
# elements of the transform, and then searches the convergent denominators of that expansion for
# the largest one not exceeding N. It tries as many indices as necessary until r, the order of
# f(x) = a^x (mod N), is found, and gives up if it exhausts the output indices without having
# found a value for r, i.e., a solution.

# Euclid's algorithm for the greatest common divisor
# Input x_in -- one of the two integers whose greatest common divisor is wanted
# Input y_in -- the other integer
# Return -- gcd(x_in, y_in), which is non-negative; nil when both inputs are 0
def gcd(x_in, y_in)
    u, v = x_in.abs, y_in.abs
    return nil if u == 0 && v == 0
    # Replace (u, v) by (v, u mod v) until the remainder vanishes; u is then the gcd.
    u, v = v, u % v until v == 0
    u
end

# Fast and efficient algorithm for computing b^e (mod m) by binary (square-and-multiply)
# exponentiation
# Input b -- the base, or number to be exponentiated
# Input e -- the exponent to raise the base, b, to; expected to be >= 0 (a negative exponent
#            skips the loop and yields 1 % m)
# Input m -- the modulus of the exponentiation problem, b^e (mod m); expected to be > 0
# Return -- b^e (mod m), in the range 0...m
def pow_mod(b, e, m)
    # Start from 1 % m rather than 1 so that b^0 (mod 1) is 0, like every other residue mod 1.
    result = 1 % m
    # Reduce the base up front so the first products stay small.
    b %= m
    while e > 0
        result = (result * b) % m if e & 1 == 1
        e >>= 1
        b = (b * b) % m
    end
    result
end

# Miller's (strong probable-prime) test of n with respect to a single base
# Input n -- the number to test for primality
# Input base -- the base for the test
# Return -- true if n is a strong probable prime to base (always so for an odd prime n that
#           does not divide base), false otherwise; false for n < 2
def millers_test(n, base)
    # For n = 1, n - 1 is 0 and the halving loop below would never terminate; for n <= 0 the
    # exponent would be non-positive and pow_mod would return 1, wrongly reporting "prime".
    return false if n < 2
    n_minus_1 = n - 1
    # Write n - 1 = d * 2^j with d odd.
    d, j = n_minus_1, 0
    while d.even?
        d /= 2
        j += 1
    end
    x = pow_mod(base, d, n)
    return true if x == 1 || x == n_minus_1
    # Square up to j - 1 more times looking for -1 (mod n).
    (j - 1).times do
        x = pow_mod(x, 2, n)
        return true if x == n_minus_1
        # Once x is 1, every further square stays 1 and can never reach n - 1.
        return false if x == 1
    end
    false
end

# Rabin-Miller probabilistic test for primality
# Input n -- the number to test for primality
# Input uncertainty -- the acceptable chance that a composite n is reported prime; each round
#                      errs with probability at most 1/4, so enough rounds are run to push
#                      0.25^rounds below this value
# Return -- true if believed to be prime, false otherwise
def rabin_miller(n, uncertainty)
    return false if n < 2
    # 2 and 3 are prime and have no usable bases in 2..n-2.
    return true if n < 4
    return false if n.even?
    count = (log(uncertainty) / log(0.25) + 1).to_int
    # Use bases 2, 3, ... as before, but never beyond n - 2: a base that is a multiple of n
    # (e.g. 5 when n = 5) makes Miller's test fail even though n is prime.
    last_base = [count + 1, n - 2].min
    (2..last_base).all? { |base| millers_test(n, base) }
end

# Given a fraction's numerator and denominator, find its continued fraction coefficients
# Input num -- the numerator of the fraction
# Input den -- the denominator of the fraction (must be non-zero)
# Return -- the array of continued fraction coefficients [a0; a1, a2, ...], e.g.
#           fraction_to_cf(3, 8) == [0, 2, 1, 2] since 3/8 = 0 + 1/(2 + 1/(1 + 1/2))
def fraction_to_cf(num, den)
    a = [num / den]
    # Keep only the remainder after the integer part; without this a fraction >= 1
    # (e.g. 8/3) would gain a spurious 0 coefficient and expand to the wrong value.
    num -= a[0] * den
    until num == 0
        num, den = den, num
        a << num / den
        num -= a.last * den
    end
    a
end

# One step of an in-place bit-reversal permutation over interleaved complex data.
# Input n       -- 0-based index of a complex element
# Input count   -- number of index bits to reverse
# Input s       -- array of interleaved (real, imaginary) pairs, 1-indexed: element k occupies
#                  s[2k + 1] and s[2k + 2]
# Input swapped -- Hash of indices already exchanged with a partner; updated here so that
#                  each pair is swapped only once
def reverse_and_swap(n, count, s, swapped)
    # Build the reversed index by shifting in bits of n from the least significant end.
    reversed = (0...count).reduce(0) { |acc, bit| (acc << 1) | n[bit] }
    return if reversed == n || swapped.has_key?(n)
    swapped[reversed] = 't'
    here, there = 2 * n + 1, 2 * reversed + 1
    s[here], s[there] = s[there], s[here]
    s[here + 1], s[there + 1] = s[there + 1], s[here + 1]
end

# Fast Fourier transform of a set of real values, after Numerical Recipes' realft: the n real
# values are treated as n/2 packed complex values, transformed with jns_four1, and the two
# interleaved half-transforms are then separated (forward) or combined (inverse).
# Input samples -- array transformed in place. Element 0 is a placeholder (the caller prepends
#                  a dummy value); elements 1..n hold the n real values.
#                  NOTE(review): n presumably must be a power of 2, since jns_four1 bit-reverses
#                  the n/2 complex pairs — confirm.
# Input i_sign  -- 1 for the forward transform; any other value takes the inverse path
#                  (presumably -1). NOTE(review): as in Numerical Recipes the inverse is likely
#                  unnormalized — confirm before relying on its scale.
# Return -- callers should use samples, not the return value. After the forward transform,
#           samples[1] and samples[2] hold the two purely real end terms (the second being
#           the Nyquist value), followed by interleaved (real, imaginary) pairs.
def realft(samples, i_sign)
    # Number of real values; element 0 is not data.
    n = samples.length - 1
    c1 = 0.5
    # Angular step of the n/2-point complex transform.
    theta = 2 * PI  / n
    if i_sign == 1
        c2 = -0.5
        # Forward: transform the packed complex data first, then unscramble below.
        jns_four1(samples, 1)
    else
        c2 = 0.5
        theta = -theta
    end
    # Twiddle-factor recurrence: wpr = cos(theta) - 1 and wpi = sin(theta); (wr, wi) starts at
    # (cos(theta), sin(theta)) and is rotated by theta at the end of every pass.
    hold = sin(0.5 * theta)
    wpr, wpi = -2.0 * hold * hold, sin(theta)
    wr, wi = 1.0 + wpr, wpi
    np3 = n + 3
    # Each pass handles a mirrored pair of complex bins: (i1, i2) counted from the front and
    # (i3, i4) from the back. h1 and h2 are the two half-transforms recovered from them.
    2.upto(n / 4) do |i|
        i1 = i + i - 1
        i2 = i1 + 1
        i3 = np3 - i2
        i4 = i3 + 1
        h1r, h1i =  c1 * (samples[i1] + samples[i3]), c1 * (samples[i2] - samples[i4])
        h2r, h2i = -c2 * (samples[i2] + samples[i4]), c2 * (samples[i1] - samples[i3])
        samples[i1] =  h1r + wr * h2r - wi * h2i
        samples[i2] =  h1i + wr * h2i + wi * h2r
        samples[i3] =  h1r - wr * h2r + wi * h2i
        samples[i4] = -h1i + wr * h2i + wi * h2r
        # Advance (wr, wi) by one step of theta; hold keeps the old wr.
        hold = wr
        wr += wr * wpr -   wi * wpi
        wi += wi * wpr + hold * wpi
    end
    # The first and last terms are purely real and share samples[1] and samples[2].
    h1r = samples[1]
    if i_sign == 1
        samples[1] = h1r + samples[2]
        samples[2] = h1r - samples[2]
    else
        samples[1] = c1 * (h1r + samples[2])
        samples[2] = c1 * (h1r - samples[2])
        # Inverse: recombine first, then run the inverse complex transform.
        jns_four1(samples, -1)
    end
end

# In-place complex FFT after Numerical Recipes' four1: a bit-reversal permutation followed by
# Danielson-Lanczos butterfly passes.
# Input samples -- array whose element 0 is a placeholder; elements 1..len-1 hold interleaved
#                  (real, imaginary) pairs. NOTE(review): the number of pairs presumably must be
#                  a power of 2, since the bit count below comes from repeatedly halving the length.
# Input i_sign  -- sign of the exponent in the transform (realft passes 1 or -1)
# Return -- callers should use samples, not the return value
def jns_four1(samples, i_sign)
    len = len2 = samples.length
    # Count halvings of the length down to 2. For len = 2^k + 1 this gives k - 1, the base-2
    # logarithm of the number of complex pairs.
    bits = 0
    while len2 > 2
        len2 >>= 1
        bits += 1
    end
    # Bit-reversal permutation over the complex pairs 0..len/2 - 1. The hash records indices
    # already exchanged so that each pair is swapped only once.
    swapped = Hash.new
    0.upto(len / 2 - 1) do |n|
        reverse_and_swap(n, bits, samples, swapped)
    end
    swapped = nil
    # n is the number of real slots (twice the number of complex pairs).
    n = len - 1
    # m_max is the span, in array slots, of the sub-transforms being merged; it doubles each pass.
    m_max = 2
    while n > m_max
        i_step = 2 * m_max
        theta = 2 * PI  * i_sign / m_max
        # Twiddle-factor recurrence: wpr = cos(theta) - 1, wpi = sin(theta), starting at (1, 0).
        hold = sin(0.5 * theta)
        wpr, wpi = -2.0 * hold * hold, sin(theta)
        wr, wi = 1.0, 0.0
        1.step(m_max, 2) do |m|
            m.step(n, i_step) do |i|
                # Butterfly: combine element i with its partner j = i + m_max.
                j = i + m_max
                tempr = wr * samples[j]     - wi * samples[j + 1]
                tempi = wr * samples[j + 1] + wi * samples[j]
                samples[j]     = samples[i]     - tempr
                samples[j + 1] = samples[i + 1] - tempi
                samples[i]     += tempr
                samples[i + 1] += tempi
            end
            # Rotate (wr, wi) by theta; hold keeps the old wr.
            hold = wr
            wr += wr * wpr - wi   * wpi
            wi += wi * wpr + hold * wpi
        end
        m_max = i_step
    end
end

# Routine period_finder_1 fakes "quantum entanglement" in O(r) time,
# where r is the repeat period.
# Input a -- the base for f(x) = a^x (mod N); must be coprime to N
# Input n -- N in f(x) = a^x (mod N)
# Return -- r, the order of a (mod N), where f(x + r) = f(x) (mod N); also a^r = 1 (mod N)
# Raises ArgumentError -- if a has no multiplicative order modulo n (gcd(a, n) != 1)
def period_finder_1(a, n)
    # 1 reduced mod n, so that n = 1 (where every residue is 0) terminates with order 1.
    one = 1 % n
    f = 1
    1.upto(n) do |x|
        f = (f * a) % n
        return x if f == one
    end
    # The order of a unit divides phi(n) < n, so getting here means a is not a unit mod n;
    # the original unbounded loop would have spun forever.
    raise ArgumentError, "#{a} has no multiplicative order modulo #{n}"
end

# Find the smallest power of 2 >= n * n
# Input n -- the number to be factored
# Return -- Q, a power of 2 with n * n <= Q < 2 * n * n
def find_q(n)
    # Exact integer doubling: the previous float log2 estimate could be off by one near powers
    # of 2 and returned 2 * n * n when n * n was itself a power of 2.
    target = n * n
    q = 1
    q <<= 1 while q < target
    q
end

# realft produces only the first half of the transform, with the Nyquist value stored in the
# second position; rebuild the full spectrum from it.
# Input inp -- the realft output without its placeholder: inp[0] and inp[1] are the first term
#              and the Nyquist term, followed by interleaved (real, imaginary) pairs.
#              inp[1] is overwritten with 0.0.
# Return -- the full transform as an array of Complex values; after the Nyquist entry, the
#           lower bins (except the first) are repeated as complex conjugates in mirror order
def build_outp(inp)
    nyquist = inp[1]
    inp[1] = 0.0
    lower = inp.each_slice(2).map { |re, im| Complex(re, im) }
    upper = lower[1..-1].reverse.map { |z| z.conjugate }
    lower + [Complex(nyquist, 0.0)] + upper
end

# Collect the (non-zero) qubit indices after the Fourier transform
# Input outp -- the Fourier transform, an array of Complex (or real) values
# Input max_factor -- fraction of the largest magnitude that an element must exceed to count
# Return -- the indices (excluding 0) of elements whose magnitude exceeds
#           max_factor * (largest magnitude); [] for an empty transform
def build_output_indices_array(outp, max_factor)
    return [] if outp.empty?
    # Compare magnitudes explicitly: Complex values are not ordered (Complex#<=> is nil for
    # non-real values), so outp.max fails on a genuine complex spectrum.
    threshold = max_factor * outp.map { |z| z.abs }.max
    # Index 0 is skipped: a measurement of 0 gives the fraction 0/q, which says nothing about r.
    (1...outp.length).select { |i| outp[i].abs > threshold }
end

# From the continued fraction coefficients generate the denominators of the convergents with
# k_i = a_i * k_(i-1) + k_(i-2), starting from k_(-1) = 0 and k_(-2) = 1.
# Input a -- array of continued fraction coefficients
# Input n -- the number to factor
# Return -- the last convergent denominator not exceeding n, i.e. the minimum candidate for
#           the period r (0 if even the first one exceeds n); nil when a is empty
def find_r_min(a, n)
    prev, before_prev = 0, 1
    latest = nil
    a.each do |coef|
        latest = coef * prev + before_prev
        return prev if latest > n
        before_prev, prev = prev, latest
    end
    latest
end

# Try integer multiples of r_min for r until a^r = 1 (mod n)
# Input r_min -- the smallest possible value of the period, r
# Input a -- the base
# Input n -- the modulus
# Return -- an array whose first element is true only if an r < n could be found, and whose
#           second element is that r: the smallest multiple of r_min with a^r = 1 (mod n).
#           This is the order of a mod n when r_min divides the order; otherwise it is a
#           multiple of it.
def find_r(r_min, a, n)
    # A missing, zero or negative r_min would make a^0 = 1 (or pow_mod's empty loop) look like
    # a hit, and a starting value >= n cannot be the order of a modulo n.
    return false, r_min if r_min.nil? || r_min <= 0 || r_min >= n
    rv = r_min
    until pow_mod(a, rv, n) == 1
        rv += r_min
        return false, rv if rv >= n
    end
    return true, rv
end

Uncertainty, max_factor = 0.01, 0.9
all_done = false

# Prompt for and read one line from standard input, converted to an integer.
# Raises RuntimeError at end of input: gets would otherwise keep returning nil (nil.to_i == 0)
# and the prompt loops below would print "0 is out of range" forever.
def read_int(prompt)
    STDOUT.write prompt
    line = gets
    raise RuntimeError, 'Unexpected end of input' if line.nil?
    line.to_i
end

# Prompt until the user enters a number in 2..49 that is believed to be prime.
# 2, 3, 5 and 7 are accepted outright; other values must pass the Rabin-Miller test.
def read_prime(prompt, uncertainty)
    loop do
        value = read_int(prompt)
        if value < 2 || value >= 50
            puts "#{value} is out of range"
        elsif [2, 3, 5, 7].include?(value) || rabin_miller(value, uncertainty)
            return value
        else
            puts "#{value} is not believed to be prime"
        end
    end
end

begin
    p = read_prime("Enter first prime (< 50): ", Uncertainty)
    q = read_prime("Enter second prime (< 50): ", Uncertainty)
    n = p * q
    # Each pass tries one base; `next` asks for another base when this one cannot work.
    loop do
        a = 0
        loop do
            a = read_int("Enter a base number < #{n}: ")
            break if a > 1 && a < n
            puts "#{a} is out of range"
        end

        g = gcd(a, n)
        if g > 1
            puts "N has the non trivial factor #{g}"
            break
        end

        r = period_finder_1(a, n)
        puts "Order of #{a} (mod #{n}) = #{r} (using fake \"quantum entanglement\")"
        if r % 2 == 1
            puts 'Index is odd: need to try another base.'
            next
        end
        if pow_mod(a, r / 2, n) == n - 1
            puts 'Trivial square root: need to try another base'
            next
        end

        q_size = find_q(n)
        puts "q = #{q_size}"
        puts "There are #{r} possible values for the measurement of f(x)"
        rand_ind = rand(r)
        puts "The program randomly picked f(#{rand_ind}) = #{a}^#{rand_ind} (mod #{n}) = #{pow_mod(a, rand_ind, n)}"
        # Build the input array as it would look after measuring f(x) = a^x (mod N) and entanglement.
        inp = Array.new(q_size) { |i| i % r == rand_ind ? 1.0 : 0.0 }
        # realft indexes its data from 1, so a placeholder goes in front and is removed afterwards.
        inp.unshift(-1.0)
        realft(inp, 1)
        inp.shift
        outp = build_outp(inp)
        output_indices = build_output_indices_array(outp, max_factor)
        puts "There are #{output_indices.length} non-zero elements in the Fourier transform."
        until output_indices.empty?
            c = output_indices.delete_at(rand(output_indices.length))
            puts "The program randomly picked #{c}."
            cf_array = fraction_to_cf(c, q_size)
            r_min = find_r_min(cf_array, n)
            puts "From the continued fraction expansion the minimum r is #{r_min}."
            found, period = find_r(r_min, a, n)
            if found
                puts "After testing integer multiples of #{r_min}, we find r = #{period}."
                # Only an even period whose half power is neither 1 nor -1 (mod n) splits n;
                # a multiple of the true order can fail either condition.
                sr = period.even? ? pow_mod(a, period / 2, n) : nil
                if sr && sr != 1 && sr != n - 1
                    printf "%d is a non-trivial square root of 1 (mod %d).\n", sr, n
                    factor1, factor2 = gcd(sr + 1, n), gcd(sr - 1, n)
                    printf "The factors of %d are %d and %d.\n", n, factor1, factor2
                    all_done = true
                    break
                end
                puts "r = #{period} does not give a non-trivial square root of 1; trying another index."
            else
                puts 'Could not find the period from the chosen random index; trying another one.'
            end
        end
        puts "Exhausted all elements in the QFT. Run the program again, please." unless all_done
        break
    end
rescue RuntimeError => message
    puts message
end