/ src / main.rs
main.rs
  1  use std::collections::HashSet;
  2  use std::io;
  3  use clap::Parser;
  4  use num_bigint::BigUint;
  5  use num_traits::One;
  6  
  /// Computes the continued fraction expansion of `sqrt(n)`.
///
/// Returns `(a0, period)`, where `a0 = floor(sqrt(n))` and `period` is one full
/// repeating block `[a1, ..., ar]` of partial quotients. For a perfect square
/// (including `0` and `1`) the expansion stops at `a0`, so `period` is empty.
fn continued_fraction_sqrt(n: u64) -> (u64, Vec<u64>) {
    // Exact integer floor(sqrt(n)). `n as f64` rounds once n exceeds 2^53, so the
    // float estimate can be one too large (which would underflow `n - m * m`
    // below) or one too small; correct it with exact checks. `checked_mul` keeps
    // the squaring from overflowing when the estimate reaches 2^32.
    let square_fits = |r: u64| matches!(r.checked_mul(r), Some(sq) if sq <= n);
    let mut a0 = (n as f64).sqrt() as u64;
    while !square_fits(a0) {
        a0 -= 1;
    }
    while square_fits(a0 + 1) {
        a0 += 1;
    }

    // If n is a perfect square, sqrt(n) is rational and has no periodic part.
    if a0 * a0 == n {
        return (a0, vec![]);
    }

    // Standard (m, d, a) recurrence for the expansion of sqrt(n). Throughout it,
    // m <= a0 and d, a, d * a <= 2 * a0, while m * m <= a0 * a0 <= n, so the u64
    // arithmetic below can neither overflow nor underflow.
    let mut m = 0;
    let mut d = 1;
    let mut a = a0;

    let mut expansion = vec![];
    let mut seen_states = HashSet::new();

    loop {
        m = d * a - m;
        d = (n - m * m) / d;
        a = (a0 + m) / d;

        // The state (m, d, a) determines every later term, so the first repeated
        // state marks the end of one full period.
        if !seen_states.insert((m, d, a)) {
            return (a0, expansion);
        }
        expansion.push(a);
    }
}
 41  
 42  /// Computes the convergents of the continued fraction expansion
 43  fn compute_convergents(a0: u64, period: &[u64], count: usize) -> Vec<(BigUint, BigUint)> {
 44      // Initialize with special cases for first two convergents
 45      let mut h = vec![BigUint::from(a0), BigUint::from(a0) * BigUint::from(period[0]) + BigUint::one()];
 46      let mut k = vec![BigUint::one(), BigUint::from(period[0])];
 47  
 48      for i in 1..count {
 49          let idx = i % period.len();
 50          let a = BigUint::from(period[idx]);
 51  
 52          let next_h = &a * &h[i] + &h[i - 1];
 53          let next_k = &a * &k[i] + &k[i - 1];
 54  
 55          h.push(next_h);
 56          k.push(next_k);
 57      }
 58  
 59      h.into_iter().zip(k.into_iter()).collect()
 60  }
 61  
 62  /// Checks if a pair (x, y) is a solution to the Pell equation x² - ny² = 1
 63  fn is_pell_solution(x: &BigUint, y: &BigUint, n: u64) -> bool {
 64      let x_squared = x * x;
 65      let ny_squared = BigUint::from(n) * y * y;
 66  
 67      x_squared == ny_squared + BigUint::one()
 68  }
 69  
 70  /// Finds solutions to the Pell equation x² - ny² = 1
 71  fn find_pell_solutions(n: u64, count: usize) -> Vec<(BigUint, BigUint)> {
 72      // If n is a perfect square, there are no solutions
 73      let sqrt_n = (n as f64).sqrt() as u64;
 74      if sqrt_n * sqrt_n == n {
 75          return vec![];
 76      }
 77  
 78      // Get the continued fraction expansion of sqrt(n)
 79      let (a0, period) = continued_fraction_sqrt(n);
 80  
 81      // If the period is empty, return early
 82      if period.is_empty() {
 83          return vec![];
 84      }
 85  
 86      // Compute convergents until we have enough solutions
 87      let period_length = period.len();
 88      let mut num_convergents = period_length * 2; // Usually sufficient to find at least one solution
 89  
 90      let max_iterations = 1000;
 91      let mut iterations = 0;
 92  
 93      loop {
 94          iterations += 1;
 95          if iterations > max_iterations {
 96              println!("Reached maximum iterations without finding a solution.");
 97              return vec![];
 98          }
 99  
100          let convergents = compute_convergents(a0, &period, num_convergents);
101  
102          // Check which convergents are solutions to the Pell equation
103          let solutions: Vec<(BigUint, BigUint)> = convergents
104              .into_iter()
105              .filter(|(x, y)| is_pell_solution(x, y, n))
106              .collect();
107  
108          if solutions.len() >= count {
109              return solutions.into_iter().take(count).collect();
110          }
111  
112          // If we don't have enough solutions, compute more convergents
113          num_convergents *= 2;
114      }
115  }
116  
117  /// Generate more solutions using the fundamental solution
118  fn generate_more_solutions(x1: &BigUint, y1: &BigUint, n: u64, count: usize) -> Vec<(BigUint, BigUint)> {
119      let mut solutions = vec![(x1.clone(), y1.clone())];
120      let mut x_prev = x1.clone();
121      let mut y_prev = y1.clone();
122  
123      for _ in 1..count {
124          let x_next = x1 * &x_prev + BigUint::from(n) * y1 * &y_prev;
125          let y_next = x1 * &y_prev + y1 * &x_prev;
126  
127          solutions.push((x_next.clone(), y_next.clone()));
128          x_prev = x_next;
129          y_prev = y_next;
130      }
131  
132      solutions
133  }
134  
// Command-line arguments for the solver, parsed by clap's derive.
// Plain `//` comments are used here deliberately: clap turns `///` doc comments
// into `--help` text, so only the field doc comments below are user-visible.
#[derive(Parser)]
struct Cli {
    /// The value of n in the Pell equation x² - n·y² = 1
    // Optional positional argument; when absent, `main` starts interactive mode.
    n: Option<u64>,

    /// The number of solutions to find
    #[clap(short = 'c', long)] // Accepted as `-c` or `--count`; `main` uses 5 if omitted.
    count: Option<usize>,
}
144  
145  fn main() {
146      let args = Cli::parse();
147  
148      // If n is provided as a command-line argument, use it
149      if let Some(n) = args.n {
150          let count = args.count.unwrap_or(5);
151          solve_pell(n, count);
152      } else {
153          // Otherwise, run in interactive mode
154          interactive_mode();
155      }
156  }
157  
158  /// Solves the Pell equation for a given n and count
159  fn solve_pell(n: u64, count: usize) {
160      println!("Pell equation solver: x² - {}·y² = 1", n);
161  
162      // Check if n is a perfect square
163      let sqrt_n = (n as f64).sqrt() as u64;
164      if sqrt_n * sqrt_n == n {
165          println!("n = {} is a perfect square. The Pell equation has no non-trivial solutions.", n);
166          return;
167      }
168  
169      // Find the fundamental solution
170      println!("Finding solutions...");
171      let solutions = find_pell_solutions(n, 1);
172  
173      if solutions.is_empty() {
174          println!("Could not find solutions for n = {}", n);
175          return;
176      }
177  
178      let (x1, y1) = &solutions[0];
179      println!("\nFundamental solution (x, y) = ({}, {})", x1, y1);
180  
181      // Generate more solutions
182      println!("Generating additional solutions...");
183      let all_solutions = generate_more_solutions(x1, y1, n, count);
184  
185      println!("\nSolutions to x² - {}·y² = 1:", n);
186      for (i, (x, y)) in all_solutions.iter().enumerate() {
187          println!("Solution {}: (x, y) = ({}, {})", i + 1, x, y);
188      }
189  }
190  
191  /// Runs the program in interactive mode
192  fn interactive_mode() {
193      println!("Pell equation solver: x² - n·y² = 1");
194  
195      loop {
196          println!("\nEnter a positive non-square integer n (or 0 to exit):");
197  
198          let mut input = String::new();
199          io::stdin().read_line(&mut input).expect("Failed to read input");
200  
201          let n: u64 = match input.trim().parse() {
202              Ok(num) if num > 0 => num,
203              Ok(0) => break, // Exit if the user enters 0
204              _ => {
205                  println!("Please enter a valid positive integer.");
206                  continue;
207              }
208          };
209  
210          // Check if n is a perfect square
211          let sqrt_n = (n as f64).sqrt() as u64;
212          if sqrt_n * sqrt_n == n {
213              println!("n = {} is a perfect square. The Pell equation has no non-trivial solutions.", n);
214              continue;
215          }
216  
217          println!("How many solutions would you like to find?");
218          let mut input = String::new();
219          io::stdin().read_line(&mut input).expect("Failed to read input");
220  
221          let count: usize = match input.trim().parse() {
222              Ok(num) if num > 0 => num,
223              _ => {
224                  println!("Using default count of 5.");
225                  5
226              }
227          };
228  
229          // Solve the Pell equation
230          solve_pell(n, count);
231      }
232  
233      println!("Goodbye!");
234  }
235  
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_continued_fraction_sqrt() {
        let (a0, period) = continued_fraction_sqrt(92);
        assert_eq!(a0, 9);
        assert_eq!(period, vec![1, 1, 2, 4, 2, 1, 1, 18]);
    }

    /// Small expansions (periods of length 1, 4 and 5) and a perfect square.
    #[test]
    fn test_continued_fraction_sqrt_small_and_square() {
        assert_eq!(continued_fraction_sqrt(2), (1, vec![2]));
        assert_eq!(continued_fraction_sqrt(7), (2, vec![1, 1, 1, 4]));
        assert_eq!(continued_fraction_sqrt(13), (3, vec![1, 1, 1, 1, 6]));
        assert_eq!(continued_fraction_sqrt(16), (4, vec![]));
    }

    #[test]
    fn test_is_pell_solution() {
        let x = BigUint::from(1151u64);
        let y = BigUint::from(120u64);
        assert!(is_pell_solution(&x, &y, 92));

        let x = BigUint::from(10u64);
        let y = BigUint::from(1u64);
        assert!(!is_pell_solution(&x, &y, 92));
    }

    #[test]
    fn test_find_pell_solutions() {
        let solutions = find_pell_solutions(92, 1);
        assert_eq!(solutions.len(), 1);
        assert_eq!(solutions[0], (BigUint::from(1151u64), BigUint::from(120u64)));
    }

    #[test]
    fn test_find_pell_solutions_perfect_square() {
        assert!(find_pell_solutions(16, 3).is_empty());
    }

    /// sqrt(13) has an odd period (length 5), so the fundamental solution is the
    /// convergent at index 9 rather than index 4 (which solves x² - 13y² = -1).
    #[test]
    fn test_find_pell_solutions_odd_period() {
        let solutions = find_pell_solutions(13, 1);
        assert_eq!(solutions, vec![(BigUint::from(649u64), BigUint::from(180u64))]);
    }

    #[test]
    fn test_find_pell_solutions_multiple() {
        let solutions = find_pell_solutions(2, 3);
        let expected: Vec<(BigUint, BigUint)> = [(3u64, 2u64), (17, 12), (99, 70)]
            .iter()
            .map(|&(x, y)| (BigUint::from(x), BigUint::from(y)))
            .collect();
        assert_eq!(solutions, expected);
    }

    /// Every reported solution must satisfy the equation, and the solutions must be
    /// strictly increasing.
    #[test]
    fn test_found_solutions_are_valid_and_increasing() {
        for n in 2..=50u64 {
            let solutions = find_pell_solutions(n, 3);
            if continued_fraction_sqrt(n).1.is_empty() {
                assert!(solutions.is_empty(), "n = {} is a perfect square", n);
                continue;
            }
            assert_eq!(solutions.len(), 3, "n = {}", n);
            for (x, y) in &solutions {
                assert!(is_pell_solution(x, y, n), "n = {}: ({}, {})", n, x, y);
            }
            assert!(solutions.windows(2).all(|w| w[0].0 < w[1].0), "n = {}", n);
        }
    }

    #[test]
    fn test_generate_more_solutions() {
        let x1 = BigUint::from(1151u64);
        let y1 = BigUint::from(120u64);
        let solutions = generate_more_solutions(&x1, &y1, 92, 3);

        assert_eq!(solutions.len(), 3);
        assert_eq!(solutions[0], (x1.clone(), y1.clone()));
        assert_eq!(solutions[1], (BigUint::from(2649601u64), BigUint::from(276240u64)));
        assert_eq!(solutions[2], (BigUint::from(6099380351u64), BigUint::from(635904360u64)));
    }

    #[test]
    fn test_generate_more_solutions_single() {
        let x1 = BigUint::from(3u64);
        let y1 = BigUint::from(2u64);
        assert_eq!(generate_more_solutions(&x1, &y1, 2, 1), vec![(x1.clone(), y1.clone())]);
    }
}