# Here is a simplified program for Shor's factoring algorithm.
# Consider it useful for instructional purposes only.
# It is very inefficient, I know. The Fourier transform is brute
# force, using neither the FFT nor any optimizations. The "quantum
# entanglement" is, of course, carried out classically.

#! /usr/bin/ruby

require 'complex'
include Math

# This is Shor's factoring algorithm as explained in the class on Quantum Mechanics and
# Quantum Computation. It does not use the QFT, but simply the Fourier transform matrix
# without any optimizations. Then the program looks for the greatest common divisor among
# the indices of all the non-zero elements in the output. This works because the length M
# of the transform is chosen to be an integer multiple of the period r of the function
# f(x) = a^x (mod N), so the non-zero elements sit exactly at the multiples of M / r;
# the principal disadvantage is that the transform is slow.

# Routine period_finder_1 fakes "quantum entanglement" classically by stepping
# through the powers of a until they return to 1 (mod N); this takes O(r)
# multiplications, where r is the repeat period.
# Input a -- the base for f(x) = a^x (mod N); must be coprime to N
# Input n -- N in f(x) = a^x (mod N); must be >= 1
# Return -- r, the order of a (mod N): the smallest r > 0 with a^r = 1 (mod N),
#           so that f(x + r) = f(x) (mod N)
# Raises ArgumentError if n < 1, or if a is not coprime to n (its powers then
# never return to 1, and an unbounded search would loop forever).
def period_finder_1(a, n)
    raise ArgumentError, "modulus must be positive, got #{n}" if n < 1
    # 1 reduced mod N: equals 0 when N == 1, where every power is congruent to 1.
    target = 1 % n
    power = target
    # The order of a unit mod N is at most N - 1, so N steps are always enough.
    1.upto(n) do |x|
        power = (power * a) % n
        return x if power == target
    end
    raise ArgumentError, "#{a} has no multiplicative order mod #{n} (not coprime)"
end

# Build the register as it would look after f(x) = a^x (mod N) has been measured:
# only the arguments x with x = rand_ind (mod r) survive the measurement, and each
# of those positions is set to 1.0. Other positions are left untouched (the main
# script pre-fills the array with 0.0).
# Input inp -- the array to fill in place (blank, or pre-sized to m)
# Input rand_ind -- the measured offset, expected to satisfy 0 <= rand_ind < r
# Input r -- the order of a (mod N)
# Input m -- the number of register positions to examine
# Return -- m (the caller ignores it and reads inp instead)
def build_sft_input(inp, rand_ind, r, m)
    survivors = m.times.select { |x| x % r == rand_ind }
    survivors.each { |x| inp[x] = 1.0 }
    m
end

# The slow Fourier transform: a direct O(M^2) evaluation of
#   outp[k] = sum over j of inp[j] * omega^(j * k),  omega = e^(2*pi*i / M),
# where M = inp.length (positive-exponent convention, no normalization).
# Prints "M = <M>" before transforming.
# Input inp -- transform input (real or complex numbers)
# Input outp -- an array the M Complex coefficients are appended to
# Return -- none
def sft(inp, outp)
    m = inp.length
    puts "M = #{m}"
    # Nothing to transform; returning also avoids dividing by M == 0 below.
    return if m.zero?
    # Math.exp rejects Complex arguments in modern Ruby (require 'complex' no longer
    # swaps in a complex-aware Math), so build the M-th roots of unity in polar form.
    # Indexing them with (j * k) % M avoids raising omega to huge powers, which
    # accumulates floating-point error as M grows.
    roots = Array.new(m) { |t| Complex.polar(1.0, 2 * Math::PI * t / m) }
    m.times do |k|
        sum = Complex(0.0, 0.0)
        m.times { |j| sum += roots[(j * k) % m] * inp[j] }
        outp << sum
    end
end
# Modular exponentiation by repeated squaring: computes b^e (mod m) with
# O(log e) multiplications, reducing after every step so intermediate values
# stay below m^2.
# Input b -- the base, or number to be exponentiated (may be negative)
# Input e -- the exponent to raise the base, b, to; must be >= 0
# Input m -- the modulus of the exponentiation problem, b^e (mod m); must be >= 1
# Return -- b^e (mod m), always in the range 0...m
# Raises ArgumentError for a negative exponent, which has no integer result here.
def pow_mod(b, e, m)
    raise ArgumentError, "negative exponent #{e}" if e < 0
    # b^0 = 1, reduced so that pow_mod(x, 0, 1) == 0 like every other result mod 1.
    result = 1 % m
    b %= m
    while e > 0
        result = (result * b) % m if e & 1 == 1
        e >>= 1
        b = (b * b) % m
    end
    result
end

# Euclid's algorithm for the greatest common divisor
# Input x_in -- one of the integers whose greatest common divisor is wanted
# Input y_in -- the other integer
# Return -- gcd(x_in, y_in), always >= 0; gcd(0, 0) is 0 by convention
def gcd(x_in, y_in)
    x, y = x_in.abs, y_in.abs
    x, y = y, x % y while y > 0
    x
end

# Miller's (strong probable-prime) test of n with respect to a single base.
# Writes n - 1 = d * 2^j with d odd and accepts n when base^d = 1 (mod n) or
# base^(d * 2^i) = n - 1 (mod n) for some 0 <= i < j.
# Input n -- the number to test for primality
# Input base -- the base for the test; it is reduced mod n first
# Return -- true if n is a strong probable prime to base, false otherwise.
#           n < 2 and even n > 2 are never prime. A base that is a multiple of n
#           carries no information, so it yields true rather than failing a prime.
def millers_test(n, base)
    # Without this guard n == 1 gives d == 0, which never becomes odd below.
    return false if n < 2
    return n == 2 if n.even?
    base %= n
    return true if base == 0
    n_minus_1 = n - 1
    d, j = n_minus_1, 0
    while d.even?
        d /= 2
        j += 1
    end
    x = pow_mod(base, d, n)
    return true if x == 1 || x == n_minus_1
    # Check base^(d * 2^i) for i = 1 .. j - 1; i == j would be base^(n - 1).
    (j - 1).times do
        x = (x * x) % n
        return true if x == n_minus_1
    end
    false
end

# Rabin-Miller probabilistic test for primality: runs Miller's test for the
# consecutive bases 2, 3, 4, ... and stops at the first base that rejects n.
# Input n -- the number to test for primality
# Input uncertainty -- the acceptable chance of calling a composite prime; the
#                      round count floor(log(uncertainty) / log(1/4)) + 1 assumes each
#                      round lets a composite through with probability at most 1/4
#                      (Rabin's bound for random bases; the bases here are fixed)
# Return -- true if n is believed to be prime, false otherwise
def rabin_miller(n, uncertainty)
    return false if n < 2
    return true if n == 2
    return false if n.even?
    rounds = (Math.log(uncertainty) / Math.log(0.25) + 1).to_i
    rounds = 1 if rounds < 1
    # Bases of n - 1 or more are congruent to -1, 0 or 1 (mod n) and prove nothing;
    # a base equal to n would even make small primes such as 3 and 5 fail.
    last_base = [n - 2, rounds + 1].min
    (2..last_base).all? { |base| millers_test(n, base) }
end

# Gather the positions of the non-zero coefficients of the Fourier transform.
# A coefficient counts as non-zero when its magnitude exceeds the file-level
# tolerance Epsilon, so round-off residue is not mistaken for a peak.
# Input outp -- the Fourier transform coefficients
# Input output_indices -- array the qualifying positions are appended to, in ascending order
# Return -- outp (callers ignore it)
def build_output_indices_array(outp, output_indices)
    outp.each_with_index do |coeff, pos|
        next unless coeff.abs > Epsilon
        output_indices << pos
    end
end

Uncertainty, Epsilon, inp, outp, output_indices, m_over_r = 0.01, 0.0001, [], [], [], nil

# Print +prompt+ and read one integer from standard input.
# Raises RuntimeError (reported by the rescue clause at the end of the script) when
# the input is exhausted; otherwise nil.to_i == 0 would make every prompt loop below
# repeat forever.
def read_int(prompt)
    STDOUT.write prompt
    line = gets
    raise RuntimeError, 'Unexpected end of input' if line.nil?
    line.to_i
end

# Prompt until the user enters a prime in 2..47. The primes 2, 3, 5 and 7 are
# accepted directly rather than relying on the probabilistic test for such small n.
# Return -- the accepted prime
def read_prime(prompt)
    loop do
        candidate = read_int(prompt)
        if candidate < 2 || candidate > 47
            puts "#{candidate} is out of range"
        elsif [2, 3, 5, 7].include?(candidate) || rabin_miller(candidate, Uncertainty)
            return candidate
        else
            puts "#{candidate} is not believed to be prime"
        end
    end
end

begin
    # Named first_prime/second_prime rather than p/q so Kernel#p stays usable below.
    # Both are read once; retrying with another base does not ask for them again.
    first_prime = read_prime("Enter first prime (<= 47): ")
    second_prime = read_prime("Enter second prime (<= 47): ")
    n = first_prime * second_prime
    loop do
        a = 0
        loop do
            a = read_int("Enter a base number < #{n}: ")
            break if a > 1 && a < n
            puts "#{a} is out of range"
        end

        # A base sharing a factor with N already reveals that factor.
        g = gcd(a, n)
        if g > 1
            puts "N has the non trivial factor #{g}"
            break
        end

        r = period_finder_1(a, n)
        puts "Order of #{a} (mod #{n}) = #{r} (using fake \"quantum entanglement\")"
        if r % 2 == 1
            puts 'Index is odd: need to try another base.'
            next
        end
        if pow_mod(a, r / 2, n) == n - 1
            puts 'Trivial square root: need to try another base'
            next
        end

        loop do
            m_over_r = read_int("Enter M / r: ")
            break if m_over_r > 0
            puts "#{m_over_r} is invalid"
        end
        m = m_over_r * r
        inp = Array.new(m, 0.0)
        puts "There are #{r} possible values for the measurement of f(x)"
        rand_ind = rand(r)
        puts "The program randomly picked f(#{rand_ind}) = #{a}^#{rand_ind} (mod #{n}) = #{pow_mod(a, rand_ind, n)}"
        build_sft_input(inp, rand_ind, r, m)
        sft(inp, outp)
        # outp.each { |el| printf "|(%8.4f, %9.4f)| = %8.4f\n", el.real, el.imaginary, el.abs }
        build_output_indices_array(outp, output_indices)
        p output_indices
        puts "There are #{output_indices.length} elements in the Fourier transform."

        # Index 0 is always present and carries no period information, so it is removed
        # from the pool up front. Drawing it would waste a pick and could leave the gcd
        # at 0, making m / running_gcd divide by zero.
        pool = output_indices.reject { |idx| idx == 0 }
        raise RuntimeError, 'The Fourier transform has no non-zero index to use' if pool.empty?
        gcd_count = nil
        loop do
            gcd_count = read_int("Enter the number of the #{pool.length} non-zero indices you would like to use for the gcd calculation: ")
            break if gcd_count >= 1 && gcd_count <= pool.length
            puts "#{gcd_count} is invalid"
        end
        # Draw gcd_count distinct indices at random (without replacement).
        running_gcd = 0
        gcd_count.times do
            running_gcd = gcd(running_gcd, pool.delete_at(rand(pool.length)))
        end
        r = m / running_gcd
        puts "The gcd result is #{running_gcd}, resulting in r = #{r}."
        # pow_mod reduces after every multiplication instead of forming a**(r / 2) in full
        # (Integer#pow(exp, mod) does the same on Ruby >= 2.5).
        sr = pow_mod(a, r / 2, n)
        # Too few indices can give a gcd that is a multiple of M / r, i.e. r above is a
        # proper divisor of the true period; then sr need not be a non-trivial root of 1.
        if r.odd? || sr == 1 || sr == n - 1 || (sr * sr) % n != 1
            puts "#{sr} is not a non-trivial square root of 1 (mod #{n}); try again using more elements."
        else
            printf "%d is a non-trivial square root of 1 (mod %d).\n", sr, n
            factor_1, factor_2 = gcd(sr + 1, n), gcd(sr - 1, n)
            printf "The factors of %d are %d and %d.\n", n, factor_1, factor_2
        end
        break
    end
rescue RuntimeError => message
    puts message
end