// main.rs
1 use std::collections::HashSet; 2 use std::io; 3 use clap::Parser; 4 use num_bigint::BigUint; 5 use num_traits::One; 6 7 /// Computes the continued fraction expansion of sqrt(n) and returns the period 8 fn continued_fraction_sqrt(n: u64) -> (u64, Vec<u64>) { 9 let a0 = (n as f64).sqrt() as u64; 10 11 // If n is a perfect square, return early 12 if a0 * a0 == n { 13 return (a0, vec![]); 14 } 15 16 let mut m = 0; 17 let mut d = 1; 18 let mut a = a0; 19 20 let mut expansion = vec![]; 21 let mut seen_states = HashSet::new(); 22 23 loop { 24 m = d * a - m; 25 d = (n - m * m) / d; 26 a = (a0 + m) / d; 27 28 // Store the current state 29 let state = (m, d, a); 30 31 // Check if we've seen this state before (indicates periodicity) 32 if seen_states.contains(&state) { 33 // Return the integer part and the periodic part 34 return (a0, expansion); 35 } 36 37 expansion.push(a); 38 seen_states.insert(state); 39 } 40 } 41 42 /// Computes the convergents of the continued fraction expansion 43 fn compute_convergents(a0: u64, period: &[u64], count: usize) -> Vec<(BigUint, BigUint)> { 44 // Initialize with special cases for first two convergents 45 let mut h = vec![BigUint::from(a0), BigUint::from(a0) * BigUint::from(period[0]) + BigUint::one()]; 46 let mut k = vec![BigUint::one(), BigUint::from(period[0])]; 47 48 for i in 1..count { 49 let idx = i % period.len(); 50 let a = BigUint::from(period[idx]); 51 52 let next_h = &a * &h[i] + &h[i - 1]; 53 let next_k = &a * &k[i] + &k[i - 1]; 54 55 h.push(next_h); 56 k.push(next_k); 57 } 58 59 h.into_iter().zip(k.into_iter()).collect() 60 } 61 62 /// Checks if a pair (x, y) is a solution to the Pell equation x² - ny² = 1 63 fn is_pell_solution(x: &BigUint, y: &BigUint, n: u64) -> bool { 64 let x_squared = x * x; 65 let ny_squared = BigUint::from(n) * y * y; 66 67 x_squared == ny_squared + BigUint::one() 68 } 69 70 /// Finds solutions to the Pell equation x² - ny² = 1 71 fn find_pell_solutions(n: u64, count: usize) -> Vec<(BigUint, BigUint)> 
{ 72 // If n is a perfect square, there are no solutions 73 let sqrt_n = (n as f64).sqrt() as u64; 74 if sqrt_n * sqrt_n == n { 75 return vec![]; 76 } 77 78 // Get the continued fraction expansion of sqrt(n) 79 let (a0, period) = continued_fraction_sqrt(n); 80 81 // If the period is empty, return early 82 if period.is_empty() { 83 return vec![]; 84 } 85 86 // Compute convergents until we have enough solutions 87 let period_length = period.len(); 88 let mut num_convergents = period_length * 2; // Usually sufficient to find at least one solution 89 90 let max_iterations = 1000; 91 let mut iterations = 0; 92 93 loop { 94 iterations += 1; 95 if iterations > max_iterations { 96 println!("Reached maximum iterations without finding a solution."); 97 return vec![]; 98 } 99 100 let convergents = compute_convergents(a0, &period, num_convergents); 101 102 // Check which convergents are solutions to the Pell equation 103 let solutions: Vec<(BigUint, BigUint)> = convergents 104 .into_iter() 105 .filter(|(x, y)| is_pell_solution(x, y, n)) 106 .collect(); 107 108 if solutions.len() >= count { 109 return solutions.into_iter().take(count).collect(); 110 } 111 112 // If we don't have enough solutions, compute more convergents 113 num_convergents *= 2; 114 } 115 } 116 117 /// Generate more solutions using the fundamental solution 118 fn generate_more_solutions(x1: &BigUint, y1: &BigUint, n: u64, count: usize) -> Vec<(BigUint, BigUint)> { 119 let mut solutions = vec![(x1.clone(), y1.clone())]; 120 let mut x_prev = x1.clone(); 121 let mut y_prev = y1.clone(); 122 123 for _ in 1..count { 124 let x_next = x1 * &x_prev + BigUint::from(n) * y1 * &y_prev; 125 let y_next = x1 * &y_prev + y1 * &x_prev; 126 127 solutions.push((x_next.clone(), y_next.clone())); 128 x_prev = x_next; 129 y_prev = y_next; 130 } 131 132 solutions 133 } 134 135 #[derive(Parser)] 136 struct Cli { 137 /// The value of n in the Pell equation x² - n·y² = 1 138 n: Option<u64>, 139 140 /// The number of solutions to find 
141 #[clap(short = 'c', long)] // Add short form 'c' for --count 142 count: Option<usize>, 143 } 144 145 fn main() { 146 let args = Cli::parse(); 147 148 // If n is provided as a command-line argument, use it 149 if let Some(n) = args.n { 150 let count = args.count.unwrap_or(5); 151 solve_pell(n, count); 152 } else { 153 // Otherwise, run in interactive mode 154 interactive_mode(); 155 } 156 } 157 158 /// Solves the Pell equation for a given n and count 159 fn solve_pell(n: u64, count: usize) { 160 println!("Pell equation solver: x² - {}·y² = 1", n); 161 162 // Check if n is a perfect square 163 let sqrt_n = (n as f64).sqrt() as u64; 164 if sqrt_n * sqrt_n == n { 165 println!("n = {} is a perfect square. The Pell equation has no non-trivial solutions.", n); 166 return; 167 } 168 169 // Find the fundamental solution 170 println!("Finding solutions..."); 171 let solutions = find_pell_solutions(n, 1); 172 173 if solutions.is_empty() { 174 println!("Could not find solutions for n = {}", n); 175 return; 176 } 177 178 let (x1, y1) = &solutions[0]; 179 println!("\nFundamental solution (x, y) = ({}, {})", x1, y1); 180 181 // Generate more solutions 182 println!("Generating additional solutions..."); 183 let all_solutions = generate_more_solutions(x1, y1, n, count); 184 185 println!("\nSolutions to x² - {}·y² = 1:", n); 186 for (i, (x, y)) in all_solutions.iter().enumerate() { 187 println!("Solution {}: (x, y) = ({}, {})", i + 1, x, y); 188 } 189 } 190 191 /// Runs the program in interactive mode 192 fn interactive_mode() { 193 println!("Pell equation solver: x² - n·y² = 1"); 194 195 loop { 196 println!("\nEnter a positive non-square integer n (or 0 to exit):"); 197 198 let mut input = String::new(); 199 io::stdin().read_line(&mut input).expect("Failed to read input"); 200 201 let n: u64 = match input.trim().parse() { 202 Ok(num) if num > 0 => num, 203 Ok(0) => break, // Exit if the user enters 0 204 _ => { 205 println!("Please enter a valid positive integer."); 206 
continue; 207 } 208 }; 209 210 // Check if n is a perfect square 211 let sqrt_n = (n as f64).sqrt() as u64; 212 if sqrt_n * sqrt_n == n { 213 println!("n = {} is a perfect square. The Pell equation has no non-trivial solutions.", n); 214 continue; 215 } 216 217 println!("How many solutions would you like to find?"); 218 let mut input = String::new(); 219 io::stdin().read_line(&mut input).expect("Failed to read input"); 220 221 let count: usize = match input.trim().parse() { 222 Ok(num) if num > 0 => num, 223 _ => { 224 println!("Using default count of 5."); 225 5 226 } 227 }; 228 229 // Solve the Pell equation 230 solve_pell(n, count); 231 } 232 233 println!("Goodbye!"); 234 } 235 236 #[cfg(test)] 237 mod tests { 238 use super::*; 239 240 #[test] 241 fn test_continued_fraction_sqrt() { 242 let (a0, period) = continued_fraction_sqrt(92); 243 assert_eq!(a0, 9); 244 assert_eq!(period, vec![1, 1, 2, 4, 2, 1, 1, 18]); 245 } 246 247 #[test] 248 fn test_is_pell_solution() { 249 let x = BigUint::from(1151u64); 250 let y = BigUint::from(120u64); 251 assert!(is_pell_solution(&x, &y, 92)); 252 253 let x = BigUint::from(10u64); 254 let y = BigUint::from(1u64); 255 assert!(!is_pell_solution(&x, &y, 92)); 256 } 257 258 #[test] 259 fn test_find_pell_solutions() { 260 let solutions = find_pell_solutions(92, 1); 261 assert_eq!(solutions.len(), 1); 262 assert_eq!(solutions[0], (BigUint::from(1151u64), BigUint::from(120u64))); 263 } 264 265 #[test] 266 fn test_generate_more_solutions() { 267 let x1 = BigUint::from(1151u64); 268 let y1 = BigUint::from(120u64); 269 let solutions = generate_more_solutions(&x1, &y1, 92, 3); 270 271 assert_eq!(solutions.len(), 3); 272 assert_eq!(solutions[0], (x1.clone(), y1.clone())); 273 assert_eq!(solutions[1], (BigUint::from(2649601u64), BigUint::from(276240u64))); 274 assert_eq!(solutions[2], (BigUint::from(6099380351u64), BigUint::from(635904360u64))); 275 } 276 }