module NumberTheory
( eratosthenesSieve
, findFactors
, primeFactors
, primeFactorization
, primeFactorizationWithPrimes
, totient
, gcdEuclid
, inverseEuclid
, extendedEuclid
, powMod
, millerSAndT
, rabinMiller
, millersTest
, chineseRemainderTheorem
, chineseRemainderTheorem2
, integerSquareRoot
, babyStepGiantStepList
, babyStepGiantStep
, pollardFactorization
, tortoiseAndHare
, tAndHBrent
, pollardLog
, dixonAlgorithm
, shanksTonelli
) where

import Control.Arrow -- primeFactorization
import qualified Data.Bits as Bits (shift)
import Data.List (foldl', group) -- primeFactorization, totient
import qualified Data.Map as Map (Map, fromList, lookup) -- Baby Step 
import qualified Data.Maybe as Maybe (isNothing, fromJust) -- Baby Step

-- I borrowed this one from https://wiki.haskell.org/Prime_numbers

-- minus subtracts one ascending list from another (a merge-style set difference).
-- Both lists must be sorted ascending; the second one may be infinite.
minus :: Ord a => [a] -> [a] -> [a]
minus (x:xs) (y:ys) = case compare x y of
           LT -> x : minus  xs  (y:ys)
           EQ ->     minus  xs     ys
           GT ->     minus (x:xs)  ys
minus xs _ = xs

{- The classic Sieve of Eratosthenes
-- Input n -- the largest number in the sieve
-- Return -- the primes up to and including n, ascending ([] when n < 2)
   Sample: ghci> eratosthenesSieve 30
           [2,3,5,7,11,13,17,19,23,29]
-}
eratosthenesSieve :: Int -> [Int]
eratosthenesSieve n
    | n < 2 = []
    | otherwise = 2 : sieve [3, 5 .. n]
    where sieve [] = []  -- n == 2: there are no odd candidates at all
          sieve ps@(p:xs)
            | p * p > n = ps
            -- odd multiples of p below p * p were already struck out by smaller primes
            | otherwise = p : sieve (xs `minus` [p * p, p * p + 2 * p ..])

{- Trial-divide an integer by an ascending list of primes
-- Input n  -- the integer to factor
-- Input ps -- candidate prime divisors, ascending
-- Return   -- the prime factors of n found by trial division, with repetitions. Once the next
               prime squared exceeds what is left of n, or the primes run out, the remaining
               cofactor (if > 1) is appended as-is; in the "ran out" case it may be composite.
   Sample: ghci> findFactors 185 [2,3,5]
           [5,37]
-}
findFactors :: Int -> [Int] -> [Int]
findFactors n [] = [n | n > 1]
findFactors n ps@(p:ps')
    | n < p * p = [n | n > 1]
    | r == 0 = p : findFactors d ps
    | otherwise = findFactors n ps'
    where (d, r) = n `divMod` p

{- Find the prime factors of an integer, including repetitions
-- Input n -- the integer to find the factors of
-- Return -- n's factors as an ascending list ([] for n < 2)
   Sample: ghci> primeFactors 75
           [3,5,5]
-}
primeFactors :: Int -> [Int]
primeFactors n = findFactors n primes
    where primes = eratosthenesSieve n 

{- The prime factorization of an integer
-- Input n -- the integer to factor
-- Return  -- a list of (prime, exponent) pairs in ascending prime order
   Sample: ghci> primeFactorization 75
           [(3,1),(5,2)]
-}
primeFactorization :: Int -> [(Int, Int)]
primeFactorization n = [ (p, length run) | run@(p:_) <- group (primeFactors n) ]

{- The prime factorization of an integer with a supplied list of primes
-- Input n      -- the integer to factor
-- Input p_list -- ascending primes to trial-divide by
-- Return       -- a list of (factor, exponent) pairs, as produced by findFactors
   Sample: ghci> let p_list = eratosthenesSieve 100
           ghci> primeFactorizationWithPrimes 185 p_list
           [(5,1),(37,1)]
-}
primeFactorizationWithPrimes :: Int -> [Int] -> [(Int, Int)]
primeFactorizationWithPrimes n p_list =
    [ (p, length run) | run@(p:_) <- group (findFactors n p_list) ]

{- Euler's totient function
-- Input n -- the number to find the totient of
-- Return  -- totient(n), computed as n * product (1 - 1/p) over the distinct primes p dividing n.
              Each step acc - acc / p is an exact division, because acc still contains every
              prime not yet processed.
-}
totient :: Int -> Int
totient n = foldl' (\acc p -> acc - acc `div` p) n distinct_primes
    where distinct_primes = map fst (primeFactorization n)

{- Greatest common divisor via the Euclidean algorithm
-- Input a -- one of the two integers (any sign)
-- Input b -- the other integer (any sign)
-- Return  -- gcd(a, b), never negative; gcdEuclid 0 0 is 0
   Sample: ghci> gcdEuclid 24 44
           4
-}
gcdEuclid :: Integer -> Integer -> Integer
gcdEuclid a b = go (abs a) (abs b)
  where
    go x 0 = x
    go x y = go y (x `mod` y)

{- Multiplicative inverse via extended Euclid algorithm
-- Input x -- The integer to find the multiplicative inverse of; any sign, it is reduced mod m first
-- Input m -- The modulus of the inverse operation (expected to be positive)
-- Return  -- The multiplicative inverse of x (mod m), in [0, m). The result is only meaningful
              when gcd(x, m) == 1; when x is a multiple of m the result is 0.
   Sample: ghci> inverseEuclid 7 9
           4
           ghci> inverseEuclid (-2) 9
           4
-}
inverseEuclid :: Integer -> Integer -> Integer
inverseEuclid x m = go m (x `mod` m) 0 1
  where
    -- Invariant: r0 == y0 * x and r1 == y1 * x (mod m). When r1 reaches 0, r0 is
    -- gcd(x, m), and y0 is the coefficient of x for it.
    go _ 0 y0 _ = y0 `mod` m
    go r0 r1 y0 y1 = let (q, r) = r0 `divMod` r1
                     in go r1 r y1 (y0 - q * y1)

-- Multiplicative inverse of x (mod m) that reports non-coprime inputs instead of returning a
-- meaningless number. x may be negative or larger than m; it is reduced modulo m first so the
-- Euclid loop always sees a non-negative remainder.
inverseEuclid2 :: Integer -> Integer -> Either String Integer
inverseEuclid2 x m
    | gcdEuclid x m /= 1 = Left $ (show x) ++ " and " ++ (show m) ++ " are not coprime!"
    | reduced == 0 = Right 0  -- coprime yet divisible by m only happens when |m| == 1
    | otherwise = Right $ inverseEuclidLoop m reduced m 0 1 100
    where reduced = x `mod` m

-- Extended Euclid loop behind the modular inverse.
-- Arguments: the original modulus, the current remainder r1, the previous remainder r0,
-- their x-coefficients y0 and y1, and the remainder just produced (the initial value only
-- needs to be non-zero). When that remainder is 0, y0 reduced mod the modulus is returned.
inverseEuclidLoop :: Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer
inverseEuclidLoop modulus _ _ y0 _ 0 = y0 `mod` modulus
inverseEuclidLoop modulus r1 r0 y0 y1 _ =
    let (q, r) = r0 `divMod` r1
    in inverseEuclidLoop modulus r r1 y1 (y0 - q * y1) r

{- Bezout coefficients via the extended Euclidean algorithm
-- Input (a, b) -- a pair of integers
-- Return       -- (s, t) such that a * s + b * t equals the last non-zero remainder of the
                   Euclidean sequence (gcd(a, b) for non-negative inputs)
   Sample: ghci> extendedEuclid (24, 44)
           (2,-1)
-}
extendedEuclid :: (Integer, Integer) -> (Integer, Integer)
extendedEuclid pair = extendedEuclidLoop pair 0 1 1 0

-- One Euclid step per call. (r0, r1) are the two latest remainders; (s, s_prev) and
-- (t, t_prev) are the coefficients of a and b for r1 and r0 respectively.
extendedEuclidLoop :: (Integer, Integer) -> Integer -> Integer -> Integer -> Integer -> (Integer, Integer)
extendedEuclidLoop (r0, r1) s s_prev t t_prev
    | r1 == 0 = (s_prev, t_prev)
    | otherwise =
        let (q, r) = r0 `divMod` r1
        in extendedEuclidLoop (r1, r) (s_prev - q * s) s (t_prev - q * t) t

{- Fast modular exponentiation by repeated squaring: b^e (mod m)
-- Input b -- the base, or number to be exponentiated (any sign)
-- Input e -- the exponent to raise the base to; must be non-negative
-- Input m -- the modulus of the exponentiation problem, b^e (mod m)
-- Return  -- b^e (mod m). For m > 1 the result lies in [0, m); for m == 1 it is 0.
              A negative exponent is an error (it used to loop forever).
   Sample: ghci> powMod 7 5 11
           10
-}
powMod :: Integer -> Integer -> Integer -> Integer
powMod b e m
    | e < 0 = error "powMod: negative exponent"
    | m == 1 = 0  -- every integer, including b^0 = 1, is congruent to 0 modulo 1
    | otherwise = powModLoop b e m 1

-- Right-to-left binary exponentiation: fold the current power of b into the result
-- whenever the low bit of e is set, then square the base and drop that bit.
powModLoop :: Integer -> Integer -> Integer -> Integer -> Integer
powModLoop _ 0 _ result = result
powModLoop b e m result = powModLoop b_new e_new m result_new
  where result_new = if odd e then (result * b) `mod` m else result
        e_new = e `div` 2
        b_new = (b * b) `mod` m

{- Pull the factors of 2 out of a number, n - 1 (used by rabinMiller, following)
-- Input n -- the number to split into n - 1 = 2^s * t
-- Return  -- (s, t) with t odd; for n == 1 (n - 1 == 0) the result is (0, 0)
   Sample: ghci> millerSAndT 89
           (3,11)
-}
millerSAndT :: Integer -> (Integer, Integer)
millerSAndT n = go 0 (n - 1)
  where
    -- Halve t while it is even. t == 0 is excluded explicitly: it is even forever,
    -- so without the check n == 1 would never terminate.
    go s t
      | t /= 0 && even t = go (s + 1) (t `div` 2)
      | otherwise = (s, t)
   
{- Rabin-Miller probabilistic test for primality
-- Input n           -- the number to test for primality; anything below 2 is not prime
-- Input uncertainty -- acceptable chance of calling a composite prime, expected in (0, 1).
                        Each round is taken to leave at most a 1/4 chance, so
                        ceiling(log uncertainty / log 0.25) rounds are run, and never fewer than one.
-- Return            -- True if believed to be prime, False otherwise
   Sample: ghci> rabinMiller 997 1.0e-7
           True
-}
rabinMiller :: Integer -> Float -> Bool
rabinMiller n uncertainty
    | n < 2 = False  -- millerSAndT/powMod never terminate for n == 1 or negative n
    | n == 2 = True
    | n == 3215031751 = False -- strong pseudoprime to bases 2, 3, 5 and 7
    | even n = False -- No need to test even numbers for primality.
    | otherwise = rabinMillerLoop n count (s, t) base ret
    where -- uncertainty >= 1 would give zero or negative counts; a negative count never reaches 0
          count = max 1 (ceiling (log uncertainty / log 0.25))
          (s, t) = millerSAndT n
          base = 2
          ret = True

-- Run Miller's test for the bases 2, 3, 4, ... in turn. Any base that is a multiple of n is
-- skipped without using up a round. Stop when `count` bases have been tried (returning the
-- last verdict) or as soon as one base reports a composite (returning False).
rabinMillerLoop :: Integer -> Integer -> (Integer, Integer) -> Integer -> Bool -> Bool
rabinMillerLoop n count st@(_, _) base ret
    | not ret = False
    | count == 0 = ret
    | base `mod` n == 0 = rabinMillerLoop n count st (base + 1) ret
    | otherwise = rabinMillerLoop n (count - 1) st (base + 1) (millersTest n base st)

{- Miller's test for primality with respect to a given base
-- Input n      -- the number to test for primality
-- Input base   -- the base for the test
-- Input (s, t) -- where n - 1 = 2^s * t (from millerSAndT above)
-- Return       -- True if n passes the strong-probable-prime test to this base, False otherwise
   Sample: ghci> millersTest 89 7 (3,11)
           True
-}
millersTest :: Integer -> Integer -> (Integer, Integer) -> Bool
millersTest n base (s, t) = millersTestLoop s first n (n - 1) (first == 1 || first == n - 1)
  where
    -- First term of base^t, base^(2t), ..., base^(2^(s-1) t) (mod n). If it is already 1 or
    -- n - 1, every later square is 1 and n behaves like a prime to this base.
    first = powMod base t n
  
-- Repeatedly square base_to_power (mod n) until s squarings' worth of the Miller sequence has
-- been examined. Arguments: the remaining count s, the current term, n, n - 1, and the verdict
-- so far.
-- * Reaching n - 1 means n is behaving like a prime: answer True.
-- * Reaching 1 first means the previous term was a square root of 1 other than +-1, which cannot
--   happen modulo a prime: keep the (False) verdict and stop.
-- * s <= 1 stops the loop. Treating s == 0 (n even) the same way prevents the endless
--   descent the former `s == 1` pattern allowed.
millersTestLoop :: Integer -> Integer -> Integer -> Integer -> Bool -> Bool
millersTestLoop _ _ _ _ True = True
millersTestLoop s base_to_power n n_minus_1 rv
    | s <= 1 = rv
    | squared == n_minus_1 = True
    | squared == 1 = rv
    | otherwise = millersTestLoop (s - 1) squared n n_minus_1 rv
    where squared = (base_to_power * base_to_power) `mod` n
 
{- Pairwise prime routine
-- Input a -- list of integers to test for pairwise primality
-- Return  -- True if every two entries at different positions are coprime, False otherwise.
              A repeated value other than +-1 (e.g. [5,5]) is therefore NOT pairwise prime;
              the previous value-based `x < y` comparison silently accepted such duplicates.
   Sample: ghci> pairwisePrime [2,7,15,19]
           True
           ghci> pairwisePrime [5,3,5]
           False
-}
pairwisePrime :: [Integer] -> Bool
pairwisePrime [] = True
pairwisePrime (x:rest) = all (\y -> gcdEuclid x y == 1) rest && pairwisePrime rest

{- Chinese Remainder Theorem for a single-variable system of congruences
-- Input a_of_t -- list of pairs (x_i, m_i), each meaning x = x_i (mod m_i)
-- Return       -- Left "Moduli are not pairwise prime!" for bad data, or Right x reduced modulo
                   the product of all m_i
   Sample: ghci> chineseRemainderTheorem [(1,5), (2,6), (3,7)]
           Right 206
-}
chineseRemainderTheorem :: [(Integer, Integer)] -> Either String Integer
chineseRemainderTheorem a_of_t
    | pairwisePrime moduli = Right (sum (map term a_of_t) `mod` big_m)
    | otherwise = Left "Moduli are not pairwise prime!"
    where moduli = map snd a_of_t
          big_m = product moduli
          -- x_i * M_i * (M_i^-1 mod m_i), with M_i = M / m_i
          term (x, p) = let partial = big_m `div` p
                        in x * partial * inverseEuclid partial p

{- Solve a 2x2 linear system of congruences with a common modulus (Cramer's rule mod m)
-- Inputs -- a, b, e, c, d, f, m from
      a * x + b * y = e (mod m) and c * x + d * y = f (mod m)
-- Return --  Left "No solution!" when the determinant a * d - b * c is not invertible mod m,
              otherwise Right (x, y) with both values in [0, m). Modulo 1 every pair works and
              (0, 0) is returned.
   Sample: ghci> chineseRemainderTheorem2 3 4 5 2 5 7 13
           Right (7,9)
-}
chineseRemainderTheorem2 :: Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Either String (Integer, Integer)
chineseRemainderTheorem2 a b e c d f m
    | m == 1 = Right (0, 0)
    | gcdEuclid delta m /= 1 = Left "No solution!"
    | otherwise = Right (x, y)
    where -- reduced to [0, m): the determinant may be negative, and the inverse needs a
          -- non-negative residue
          delta = (a * d - b * c) `mod` m
          delta_i = inverseEuclid delta m
          x = (delta_i * (d * e - b * f)) `mod` m
          y = (delta_i * (a * f - c * e)) `mod` m

{- Second order Newton method (Halley's method) for integer square roots of large numbers
-- Input n -- the number to take the root of
-- Return value is a 3-tuple (x, msg, iters) with:
     x     equal to the best approximation to the square root
     msg   one of "Exact", "Low", or "No convergence"
     iters the number of iterations
     A return of "Exact" means that n is a perfect square and that the root is the true root
     A return of "Low" means x is the floor of the square root (x^2 < n < (x + 1)^2)
     A return of "No convergence" means the iteration didn't settle within 1000 steps
     n == 0 returns (0, "Exact", 0) directly (the iteration would divide by zero), and a
     negative n, which has no real root, returns (0, "No convergence", 0).
   Samples: ghci> integerSquareRoot 994009
            (997,"Exact",2)
            ghci> integerSquareRoot 994010
            (997,"Low",2)
-}
integerSquareRoot :: Integer -> (Integer, String, Int)
integerSquareRoot n
    | n < 0 = (0, "No convergence", 0)
    | n == 0 = (0, "Exact", 0)
    | otherwise = integerSquareRootLoop n (3 * n) x0 0 ""
    where -- initial guess: drop about 1.5 bits per decimal digit of n
          shift_by = length (show n) * 3 `div` 2
          x0 = n `Bits.shift` (-shift_by)

-- One Halley step per call. The step is rounded down, then the candidate and its two
-- neighbours are inspected to detect an exact root or a bracketing of sqrt(n).
integerSquareRootLoop :: Integer -> Integer -> Integer -> Int -> String -> (Integer, String, Int)
integerSquareRootLoop _ _ x iters "Exact" = (x, "Exact", iters)
integerSquareRootLoop _ _ x iters "Low" = (x, "Low", iters)
integerSquareRootLoop _ _ x 1000 _ = (x, "No convergence", 1000)
integerSquareRootLoop n n_3 x iters _ = integerSquareRootLoop n n_3 x_next (iters + 1) msg_new
    where x_hold = (x * (x * x + n_3)) `div` (3 * x * x + n)
          x_p = x_hold + 1
          x_m = x_hold - 1
          err v = v * v - n
          x_next
             | err x_p == 0 = x_p
             | err x_m == 0 = x_m
             | err x_m * err x_hold < 0 = x_m  -- x_m^2 < n < x_hold^2: floor is x_m
             | otherwise = x_hold
          msg_new
             | err x_hold == 0 || err x_p == 0 || err x_m == 0 = "Exact"
             | err x_p * err x_hold < 0 || err x_m * err x_hold < 0 = "Low"
             | otherwise = "None"

{- Halley's method for integer square roots, stopping when the iterate no longer changes
-- Input n -- the number to take the root of
-- Return value is a 3-tuple (x, msg, iters) with:
     x     equal to the best approximation to the square root
     msg   one of "Exact", "Low", or "No convergence"
     iters the number of iterations
     A return of "Exact" means that n is a perfect square and that the root is the true root
     A return of "Low" means the iteration settled on a value whose square is not n
     A return of "No convergence" means the iterate still changed after 1000 steps
     n == 0 returns (0, "Exact", 0) directly, and a negative n returns (0, "No convergence", 0).
   Samples: ghci> integerSquareRoot2 994009
            (997,"Exact",3)
            ghci> integerSquareRoot2 994010
            (997,"Low",3)
-}
integerSquareRoot2 :: Integer -> (Integer, String, Int)
integerSquareRoot2 n
    | n < 0 = (0, "No convergence", 0)
    | n == 0 = (0, "Exact", 0)
    | otherwise = integerSquareRoot2Loop n (3 * n) x0 0 ""
    where shift_by = length (show n) * 3 `div` 2
          -- 0 is a fixed point of the iteration, so never start there (n == 1 used to
          -- report (0, "Low", 1))
          x0 = max 1 (n `Bits.shift` (-shift_by))

integerSquareRoot2Loop :: Integer -> Integer -> Integer -> Int -> String -> (Integer, String, Int)
integerSquareRoot2Loop _ _ x iters "Exact" = (x, "Exact", iters)
integerSquareRoot2Loop _ _ x iters "Low" = (x, "Low", iters)
integerSquareRoot2Loop _ _ x 1000 _ = (x, "No convergence", 1000)
integerSquareRoot2Loop n n_3 x iters _ = integerSquareRoot2Loop n n_3 x_next (iters + 1) msg_new
    where x_next = (x * (x * x + n_3)) `div` (3 * x * x + n)
          msg_new
             | x_next /= x = "None"
             | x_next * x_next == n = "Exact"
             | otherwise = "Low"

{- The Baby Step Giant Step method for the discrete logarithm using a list for storage. Every giant
   step scans the list linearly, so the search runs in O(n) time.
-- Input x -- the number to take the discrete logarithm of (reduced modulo modulus first)
-- Input a -- the base of the logarithm; must be coprime to modulus
-- Input modulus -- the modulus of the logarithm, at most 1000000
-- Return y where a^y = x (mod modulus), or Left with the reason. When several baby steps share a
   value (a base of small multiplicative order), the smallest j is used.
   Sample and check: ghci> babyStepGiantStepList 7 17 11
                     Right 3
                     ghci> 17^3 `mod` 11
                     7
-}
babyStepGiantStepList :: Integer -> Integer -> Integer -> Either String Integer
babyStepGiantStepList x a modulus
    | modulus > 1000000 = Left "Modulus is limited to 1000000"
    | gcdEuclid a modulus /= 1 = Left "Base and modulus must be relatively prime"
    | modulus == 1 = Right 0  -- every integer is congruent modulo 1, so a^0 already matches
    | otherwise = babyStepGiantStepListLoop False False the_table m (x `mod` modulus) a_to_minus_m modulus 0 0
    where (root, msg, _) = integerSquareRoot modulus
          -- m = ceiling(sqrt modulus)
          m = if msg == "Exact" then root else root + 1
          -- baby steps: (j, a^j mod modulus) for j = 0 .. m
          the_table = [(j, powMod a j modulus) | j <- [0 .. m]]
          -- a^(-m) mod modulus; powMod keeps the argument non-negative and small
          a_to_minus_m = inverseEuclid (powMod a m modulus) modulus

-- Giant step i compares gamma = x * a^(-i m) against the baby steps. The first (smallest j)
-- match gives the answer i * m + j. After m giant steps with no match there is no solution.
babyStepGiantStepListLoop :: Bool -> Bool ->[(Integer, Integer)] -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Either String Integer
babyStepGiantStepListLoop True _ _ _ _ _ _ _ rv = Right rv
babyStepGiantStepListLoop _ True _ m _ _ _ i _
    | m == i = Left "No solution"
    | otherwise = Left "Programming Error!"
babyStepGiantStepListLoop _ _ the_table m gamma a_to_minus_m modulus i _ =
    babyStepGiantStepListLoop new_done new_no_solution the_table m new_gamma a_to_minus_m modulus new_i new_rv
    where
      -- previously a match was only accepted when it was unique, so duplicate baby steps
      -- could never succeed
      (new_done, new_rv) = case [j | (j, k) <- the_table, k == gamma] of
          (j:_) -> (True, i * m + j)
          [] -> (False, 0)
      new_gamma = (gamma * a_to_minus_m) `mod` modulus
      new_i = i + 1
      new_no_solution = not new_done && new_i == m

{- The Baby Step Giant Step method for the discrete logarithm using a Data.Map for storage. Runs in
   O(sqrt(n) log n) time: O(sqrt(n)) giant steps, each a logarithmic Map lookup.
-- Input x -- the number to take the discrete logarithm of (reduced modulo modulus first)
-- Input a -- the base of the logarithm; must be coprime to modulus
-- Input modulus -- the modulus of the logarithm, at most 1000000
-- Return -- y where a^y = x (mod modulus), or Left with the reason. When a^j repeats, the
   smallest j is recorded.
   Sample and check: ghci> babyStepGiantStep 7 17 11
                     Right 3
                     ghci> 17^3 `mod` 11
                     7 
-}
babyStepGiantStep :: Integer -> Integer -> Integer -> Either String Integer
babyStepGiantStep x a modulus
    | modulus > 1000000 = Left "Modulus is limited to 1000000"
    | gcdEuclid a modulus /= 1 = Left "Base and modulus must be relatively prime"
    | modulus == 1 = Right 0  -- every integer is congruent modulo 1, so a^0 already matches
    | otherwise = babyStepGiantStepLoop False False the_map m (x `mod` modulus) a_to_minus_m modulus 0 0
    where (root, msg, _) = integerSquareRoot modulus
          -- m = ceiling(sqrt modulus)
          m = if msg == "Exact" then root else root + 1
          -- Map.fromList keeps the last value for a repeated key; inserting j from m down
          -- to 0 therefore keeps the smallest exponent for each residue a^j
          the_map = Map.fromList [(powMod a j modulus, j) | j <- [m, m - 1 .. 0]]
          a_to_minus_m = inverseEuclid (powMod a m modulus) modulus

-- Giant step i looks gamma = x * a^(-i m) up in the baby-step map; a hit j gives i * m + j.
-- After m giant steps with no hit there is no solution.
babyStepGiantStepLoop :: Bool -> Bool -> Map.Map Integer Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Either String Integer
babyStepGiantStepLoop True _ _ _ _ _ _ _ rv = Right rv
babyStepGiantStepLoop _ True _ m _ _ _ i _
    | m == i = Left "No solution"
    | otherwise = Left "Programming Error!"
babyStepGiantStepLoop _ _ the_map m gamma a_to_minus_m modulus i _ =
    babyStepGiantStepLoop new_done new_no_solution the_map m new_gamma a_to_minus_m modulus new_i new_rv
    where
      (new_done, new_rv) = case Map.lookup gamma the_map of
          Just j -> (True, i * m + j)
          Nothing -> (False, 0)
      new_gamma = (gamma * a_to_minus_m) `mod` modulus
      new_i = i + 1
      new_no_solution = not new_done && new_i == m

{- Pollard's Rho factorization method, using the map x -> x^2 + 1 (mod n)
-- Input n -- the number to attempt to factor
-- Input starting_value -- the starting value for both the tortoise x and the hare y
-- Return -- Left "Failure" when the walk only finds the trivial divisor n (try another start), or
             when n < 2 (which previously looped forever); otherwise Right d with 1 < d < n and
             d dividing n (d is not necessarily prime)
   Sample: ghci> pollardFactorization 999 3
           Right 111
-}
pollardFactorization :: Integer -> Integer -> Either String Integer
pollardFactorization n starting_value
    | n < 2 = Left "Failure"
    | otherwise = pollardFactorizationLoop n starting_value starting_value 1

-- Advance the tortoise one step and the hare two, until gcd(|x - y|, n) leaves 1.
pollardFactorizationLoop :: Integer -> Integer -> Integer -> Integer -> Either String Integer
pollardFactorizationLoop n x y d
    | d == 1 = pollardFactorizationLoop n x' y' (gcdEuclid (x' - y') n)
    | d == n = Left "Failure"
    | otherwise = Right d
    where g v = (v * v + 1) `mod` n
          x' = g x
          y' = g (g y)

{- Floyd's cycle-finding algorithm
-- Input f -- a function of two variables that cycles at some point
-- Input x0 -- a starting value for f
-- Input n -- a constant second parameter of f
-- Return -- the tuple (lambda, mu) where mu is the starting index of f's cycle and lambda is the period
   Sample: ghci> tortoiseAndHare f 2 1011
           (11,8)
-}
tortoiseAndHare :: (Integer -> Integer -> Integer) -> Integer -> Integer -> (Integer, Integer)
tortoiseAndHare f x0 n = tortoiseAndHareLoop1 f x0 n (step x0) (step (step x0))
    where step v = f v n

{- How the three loops work.

If we knew the index mu at which the sequence x0, f(x0), f(f(x0)), ... enters its cycle, finding the period
lambda would be easy: start at x_mu and count steps until the value comes round again. Loop three does exactly that.

Loop one finds a multiple of lambda. The tortoise moves one step at a time and the hare two, starting from x_1 and
x_2. When they first show equal values the tortoise is at some index nu and the hare at index 2 * nu, so
x_nu = x_(2 nu), which means nu >= mu and nu is a multiple of lambda.

Loop two finds mu. The tortoise restarts at x0 while the hare stays at index 2 * nu, and both now move one step at a
time, so they stay 2 * nu apart. Because 2 * nu is a multiple of lambda, the first time their values agree is when
the tortoise reaches the start of the cycle: the hare is then a whole number of laps (2 * nu / lambda) ahead on the
same point. The number of steps taken is mu.

With mu in hand, loop three finds lambda.
-}

tortoiseAndHareLoop1 :: (Integer -> Integer -> Integer) -> Integer -> Integer -> Integer -> Integer -> (Integer, Integer)
tortoiseAndHareLoop1 f x0 n tortoise hare
    | tortoise /= hare = tortoiseAndHareLoop1 f x0 n (step tortoise) (step (step hare))
    | otherwise = tortoiseAndHareLoop2 f x0 n 0 hare
    where step v = f v n

tortoiseAndHareLoop2 :: (Integer -> Integer -> Integer) -> Integer -> Integer -> Integer -> Integer -> (Integer, Integer)
tortoiseAndHareLoop2 f tortoise n mu hare
    | tortoise /= hare = tortoiseAndHareLoop2 f (f tortoise n) n (mu + 1) (f hare n)
    | otherwise = tortoiseAndHareLoop3 f n 1 (f tortoise n) tortoise mu

tortoiseAndHareLoop3 :: (Integer -> Integer -> Integer) -> Integer -> Integer -> Integer -> Integer -> Integer -> (Integer, Integer)
tortoiseAndHareLoop3 f n lam hare tortoise mu
    | hare == tortoise = (lam, mu)
    | otherwise = tortoiseAndHareLoop3 f n (lam + 1) (f hare n) tortoise mu

{- Sample iteration map for the cycle-finding routines: x -> x^2 + 1 (mod n)
-- Input x -- the current value
-- Input n -- the modulus
-- Return -- (x^2 + 1) mod n, in [0, n) for positive n
-}
f :: Integer -> Integer -> Integer
f x n = mod (succ (x * x)) n

{- Brent's modification of Floyd's cycle-finding algorithm
-- Input f -- a function of two variables that cycles at some point
-- Input x0 -- a starting value for f
-- Input n -- a constant second parameter of f
-- Return -- the tuple (lambda, mu) where mu is the starting index of f's cycle and lambda is the period
   Sample: ghci> tAndHBrent f 2 1011
           (11,8)
-}
tAndHBrent :: (Integer -> Integer -> Integer) -> Integer -> Integer -> (Integer, Integer)
tAndHBrent f x0 n = tAndHBrentLoop1 f x0 n 1 1 x0 (f x0 n)

-- Find lambda: the tortoise waits at the start of a window of size pow while the hare walks.
-- Each time the window fills up (lambda == pow), the tortoise jumps to the hare and the window doubles.
tAndHBrentLoop1 :: (Integer -> Integer -> Integer) -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> (Integer, Integer)
tAndHBrentLoop1 f x0 n pow lambda tortoise hare
    | tortoise == hare = tAndHBrentLoop2 f x0 n 0 lambda x0 x0
    | pow == lambda = tAndHBrentLoop1 f x0 n (2 * pow) 1 hare next
    | otherwise = tAndHBrentLoop1 f x0 n pow (lambda + 1) tortoise next
    where next = f hare n

-- Put the tortoise at x0 and the hare lambda steps ahead of it.
tAndHBrentLoop2 :: (Integer -> Integer -> Integer) -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> (Integer, Integer)
tAndHBrentLoop2 f x0 n i lambda tortoise hare
    | i == lambda = tAndHBrentLoop3 f n 0 lambda tortoise hare
    | otherwise = tAndHBrentLoop2 f x0 n (i + 1) lambda tortoise (f hare n)

-- Walk both one step at a time, lambda apart; the number of steps until they agree is mu.
tAndHBrentLoop3 :: (Integer -> Integer -> Integer) -> Integer -> Integer -> Integer -> Integer -> Integer -> (Integer, Integer)
tAndHBrentLoop3 f n mu lambda tortoise hare
    | tortoise == hare = (lambda, mu)
    | otherwise = tAndHBrentLoop3 f n (mu + 1) lambda (f tortoise n) (f hare n)

{- Pollard's rho algorithm for the discrete logarithm
-- Input big_n -- the modulus of the cyclic group
-- Input n -- the group's size (order of alpha)
-- Input alpha -- the base of the logarithm
-- Input beta -- the number to take the logarithm of; it is reduced modulo big_n first, so e.g.
                beta = 1024 with big_n = 1019 is treated as 5
-- Return -- either a Left String describing the failure or the Right Integer logarithm
   Samples and checks: ghci> pollardLog 383 191 2 228
                       Right 110
                       ghci> 2^110 `mod` 383
                       228
-}  
pollardLog :: Integer -> Integer -> Integer -> Integer -> Either String Integer
pollardLog big_n n alpha beta = pollardLogLoop big_n n alpha beta' x0 a0 b0 x0 a0 b0 1
     where -- without the reduction a correct logarithm was rejected, because alpha^gamma
           -- (mod big_n) never equals an unreduced beta
           beta' = beta `mod` big_n
           -- tortoise (x, a, b) and hare (X, A, B) both start at alpha^0 * beta^0 = 1
           x0 = 1
           a0 = 0
           b0 = 0

-- Floyd-style walk: the tortoise (x, a, b) takes one newXab step and the hare (X, A, B) two per
-- iteration. Each position satisfies x = alpha^a * beta^b (mod big_n). When the positions collide
-- the exponents give a linear congruence for the logarithm; if the counter reaches n - 2 first,
-- give up.
pollardLogLoop :: Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Either String Integer
pollardLogLoop big_n n alpha beta x a b big_x big_a big_b i
    | x' == bigX' = pollardLogFinish (bigB' - b') (a' - bigA') n big_n alpha beta
    | i + 1 == n - 2 = Left "No solution"
    | otherwise = pollardLogLoop big_n n alpha beta x' a' b' bigX' bigA' bigB' (i + 1)
    where step (u, v, w) = newXab u v w big_n n alpha beta
          (x', a', b') = step (x, a, b)
          (bigX', bigA', bigB') = step (step (big_x, big_a, big_b))

{- Solve lhs * gamma = rhs (mod n) for the logarithm gamma and verify it against alpha^gamma = beta (mod big_n).
   With d = gcd(lhs, n), any solution is gamma0 + k * (n / d) for k = 0 .. d - 1, where
   gamma0 = (rhs / d) * (lhs / d)^-1 (mod n / d). Each candidate is checked and the first one that
   really satisfies alpha^gamma = beta is returned.
   lhs = 0 (mod n) carries no information about gamma and yields Left "No solution"; it previously
   crashed with a division by zero. If no candidate verifies, the first candidate is reported in
   the "Alternate solution" message.
-}
pollardLogFinish :: Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Either String Integer
pollardLogFinish lhs rhs n big_n alpha beta
    | coeff == 0 = Left "No solution"
    | otherwise = case [g | g <- candidates, powMod alpha g big_n == target] of
        (g:_) -> Right g
        [] -> Left ("Alternate solution to beta = " ++ (show beta_test) ++ ", gamma = " ++ (show log_try))
    where -- full reduction (not a single "+ n") so any size of exponent difference works
          coeff = lhs `mod` n
          value = rhs `mod` n
          d = gcdEuclid coeff n
          reduced_n = n `div` d
          -- coeff / d and reduced_n are coprime, so this inverse exists
          inv = inverseEuclid (coeff `div` d) reduced_n
          log_try = (inv * (value `div` d)) `mod` reduced_n
          candidates = [log_try + k * reduced_n | k <- [0 .. d - 1]]
          beta_test = powMod alpha log_try big_n
          target = beta `mod` big_n

{- newXab is the step function pollardLog uses to move the tortoise and hare. The step is chosen by x mod 3:
   0 squares x (doubling a and b), 1 multiplies x by alpha (a + 1), 2 multiplies x by beta (b + 1).
   x is kept modulo big_n and the exponents a, b modulo the group order n. Previously `a + 1 `mod` n` parsed as
   a + (1 `mod` n), so a and b were never reduced on those steps.
-}          
newXab :: Integer -> Integer -> Integer -> Integer -> Integer -> Integer -> Integer-> (Integer, Integer, Integer)
newXab x a b big_n n alpha beta =
    case x `mod` 3 of
        0 -> ((x * x) `mod` big_n, (2 * a) `mod` n, (2 * b) `mod` n)
        1 -> ((x * alpha) `mod` big_n, (a + 1) `mod` n, b)
        _ -> ((x * beta) `mod` big_n, a, (b + 1) `mod` n)

{- The Dixon algorithm for factoring integers with the factor base {2, 3, 5, 7}
-- Input -- n, The number to factor
-- Input -- starting_num, The start of the range to check for possible congruence candidates
-- Input -- count, The number of candidate numbers (x with x^2 mod n 7-smooth) to pair up
-- Return -- Either Left "No solution" or Right (f1, f2) where f1 * f2 == n and f1 is neither 1 nor n
   Sample: dixonAlgorithm 100160063 1200000 4
           Right (10009,10007)
           ghci> 10007 * 10009
           100160063
-}
dixonAlgorithm :: Int -> Int -> Int -> Either String (Int, Int)
dixonAlgorithm n starting_num count =
    -- the first pair (in pair order) whose factorisation is non-trivial; getFactors runs
    -- once per pair instead of three times
    case [fs | pair <- goods, let fs@(f1, _) = getFactors pair n p_list, f1 /= 1, f1 /= n] of
        (fs:_) -> Right fs
        [] -> Left "No solution"
    where p_list = [2, 3, 5, 7]
          possibles = take count [x | x <- [starting_num ..], isPossible (x ^ 2 `mod` n)]
          pairs = [(x, y) | x <- possibles, y <- possibles, x < y]
          goods = filter (\pair -> testPair pair n p_list) pairs

{- Given an integer, see if its only prime factors are 2, 3, 5, and 7
-- Input n -- the integer to check for prime factors
-- Return -- True if n is positive and all of its prime factors are <= 7, False otherwise.
             n <= 0 is rejected up front; 0 used to make collectFactors produce an endless list.
   Sample: ghci> isPossible 100
           True
           ghci> isPossible 34
           False
-}     
isPossible :: Int -> Bool
isPossible n = n > 0 && product (collectFactors n [2, 3, 5, 7, 11]) == n

{- Collect the prime factors 2, 3, 5, and 7, repeated as necessary, returning them in a list
-- Input n -- the integer to check for prime factors
-- Input -- the list [2, 3, 5, 7, 11]; 11 acts as an end sentinel and is never tested as a divisor
-- Return -- the small prime factors of n, with primes repeated as necessary. n == 0 gives [] (every
             prime divides 0, which used to produce an endless list), and a list that runs out without
             the sentinel simply ends the search instead of crashing.
   Sample: ghci> collectFactors 100 [2,3,5,7,11]
           [2,2,5,5]
-}
collectFactors :: Int -> [Int] -> [Int]
collectFactors 0 _ = []
collectFactors _ [] = []
collectFactors n ps@(p:ps')
    | p == 11 = take (length ps - 1) ps -- Trim the 11 off the right end. In this case it's just a sentinel.
    | r == 0 = p : collectFactors d ps
    | otherwise = collectFactors n ps'
    where (d, r) = n `divMod` p
        
{- Insert (p, 0) for each of 2, 3, 5 and 7 missing from a factorization, so that two padded
   factorizations can be zipped position by position
-- Input -- lst, ascending (p, i) pairs where p is one of 2, 3, 5, 7 and i is the prime's power
-- Return -- the same pairs with the missing primes inserted at their positions with power 0
   Sample: ghci> paddedPossible [(2,2), (3,1)]
           [(2,2),(3,1),(5,0),(7,0)]
-}
paddedPossible :: [(Int, Int)] -> [(Int, Int)]
paddedPossible lst = withSeven
    where present p = p `elem` map fst lst
          insertAt k item xs = let (front, back) = splitAt k xs in front ++ item : back
          withTwo
            | present 2 = lst
            | otherwise = (2, 0) : lst
          withThree
            | present 3 = withTwo
            | otherwise = insertAt 1 (3, 0) withTwo
          withFive
            | present 5 = withThree
            | otherwise = insertAt 2 (5, 0) withThree
          withSeven
            | present 7 = withFive
            | otherwise = withFive ++ [(7, 0)]

{- Dixon criterion: is (n1^2 mod n) * (n2^2 mod n) a perfect square over the factor base?
-- Input -- (n1, n2), the pair of integers to test
-- Input -- n, the number to factor
-- Input -- p_list, the factor base used for trial division
-- Return -- True if every combined prime exponent is even, False otherwise
-}    
testPair :: (Int, Int) -> Int -> [Int] -> Bool
testPair (n1, n2) n p_list = all even (zipWith (+) (powers n1) (powers n2))
    where powers k = map snd (paddedPossible (primeFactorizationWithPrimes (k ^ 2 `mod` n) p_list))

{- Given a pair of integers satisfying the Dixon criterion, find factors of the number n
-- Input -- (n1, n2), a pair satisfying the Dixon criterion
-- Input -- n, the number to factor
-- Input -- p_list, the factor base used for trial division
-- Return -- (f1, n `div` f1) where f1 = gcd(x + y, n), with x = n1 * n2 mod n and y the square
             root of (n1^2 mod n) * (n2^2 mod n) built from halved exponents
-}    
getFactors :: (Int, Int) -> Int -> [Int] -> (Int, Int)
getFactors (n1, n2) n p_list = (f1, n `div` f1)
    where exponents k = paddedPossible (primeFactorizationWithPrimes (k ^ 2 `mod` n) p_list)
          x = n1 * n2 `mod` n
          y = product [p ^ ((e1 + e2) `div` 2) | ((p, e1), (_, e2)) <- zip (exponents n1) (exponents n2)]
          -- the other candidate gcd (x - y) n is not needed: f2 is the cofactor
          f1 = gcd (x + y) n

{- Solve the quadratic residue equation x^2 = n (mod p) for x via the Tonelli-Shanks algorithm
-- Input -- p, the prime for the residue calculations
-- Input -- n, the quadratic residue
-- Return -- either Left "n is not a quadratic residue modulo p" or Right (s1, s2) where s1 and s2
             are the two solutions. Euler's criterion is evaluated with powMod rather than the full
             power n^((p - 1)/2). n divisible by p is reported through the Left branch; for p = 2
             this avoids a crash in the Tonelli-Shanks loop.
   Sample: ghci> shanksTonelli 1000033 6
           Right (348481,651552)
           ghci> 348481^2 `mod` 1000033
           6
           ghci> 651552^2 `mod` 1000033
           6
           ghci> shanksTonelli 1000033 5
           Left "5 is not a quadratic residue modulo 1000033"
-}
shanksTonelli :: Integer -> Integer -> Either String (Integer, Integer)
shanksTonelli p n
    | n `mod` p /= 0 && euler == 1 = Right (sTAlgorithm p n)
    | otherwise = Left $ show n ++ " is not a quadratic residue modulo " ++ show p
    where euler = powMod n ((p - 1) `div` 2) p

-- Set up the Tonelli-Shanks state for x^2 = n (mod p), where p - 1 = 2^s * q with q odd:
--   c = z^q for a quadratic non-residue z, r = n^((q + 1)/2), t = n^q, and m = s.
-- All powers use powMod, so intermediate values stay below p^2. z is only demanded if the loop
-- needs it, so p = 3 (mod 4) never searches for it.
sTAlgorithm :: Integer -> Integer -> (Integer, Integer)
sTAlgorithm p n = sTAlgorithmLoop p r t m c
    where (s, q) = millerSAndT p
          pm1 = p - 1
          -- smallest quadratic non-residue: Euler's criterion gives p - 1 for it
          z = case [x | x <- [2 .. pm1], powMod x (pm1 `div` 2) p == pm1] of
                  (x:_) -> x
                  [] -> error "sTAlgorithm: no quadratic non-residue found (is p an odd prime?)"
          c = powMod z q p
          r = powMod n ((q + 1) `div` 2) p
          t = powMod n q p
          m = s
          
-- Tonelli-Shanks main loop. Invariant: r^2 = n * t (mod p), so once t == 1 the roots are r and p - r.
-- Each step finds the least i with t^(2^i) = 1 and multiplies in b = c^(2^(m - i - 1)), lowering m to i.
sTAlgorithmLoop :: Integer -> Integer -> Integer -> Integer -> Integer -> (Integer, Integer)
sTAlgorithmLoop p r 1 _ _ = (r, p - r)
sTAlgorithmLoop p r t m c = sTAlgorithmLoop p new_r new_t new_m new_c
    where i = case [x | x <- [1 .. (m - 1)], powMod t (2 ^ x) p == 1] of
                  (x:_) -> x
                  [] -> error "sTAlgorithmLoop: no i found (is p prime and n a quadratic residue?)"
          b = powMod c (2 ^ (m - i - 1)) p
          new_r = (r * b) `mod` p
          new_t = (t * b * b) `mod` p
          new_m = i
          new_c = (b * b) `mod` p