/ src / modules / MouseWithoutBorders / App / Core / Encryption.cs
Encryption.cs
  1  // Copyright (c) Microsoft Corporation
  2  // The Microsoft Corporation licenses this file to you under the MIT license.
  3  // See the LICENSE file in the project root for more information.
  4  
  5  using System;
  6  using System.Collections.Concurrent;
  7  using System.Globalization;
  8  using System.IO;
  9  using System.Linq;
 10  using System.Security.Cryptography;
 11  using System.Threading.Tasks;
 12  
 13  // <summary>
 14  //     Encrypt/decrypt implementation.
 15  // </summary>
 16  // <history>
 17  //     2008 created by Truong Do (ductdo).
 18  //     2009-... modified by Truong Do (TruongDo).
 19  //     2023- Included in PowerToys.
 20  // </history>
 21  namespace MouseWithoutBorders.Core;
 22  
 23  internal static class Encryption
 24  {
 25  #pragma warning disable SYSLIB0021
 26      private static AesCryptoServiceProvider symAl;
 27  #pragma warning restore SYSLIB0021
 28  #pragma warning disable SA1307 // Accessible fields should begin with upper-case letter
 29      internal static string myKey;
 30  #pragma warning restore SA1307
 31      private static uint magicNumber;
 32      private static Random ran = new(); // Used for non encryption related functionality.
 33      internal const int SymAlBlockSize = 16;
 34  
 35      /// <summary>
 36      /// This is used for the first encryption block, the following blocks will be combined with the cipher text of the previous block.
 37      /// Thus identical blocks in the socket stream would be encrypted to different cipher text blocks.
 38      /// The first block is a handshake one containing random data.
 39      /// Related Unit Test: TestEncryptDecrypt
 40      /// </summary>
 41      private static readonly string InitialIV = ulong.MaxValue.ToString(CultureInfo.InvariantCulture);
 42  
 43      internal static Random Ran
 44      {
 45          get => Encryption.ran ??= new Random();
 46          set => Encryption.ran = value;
 47      }
 48  
 49      internal static uint MagicNumber
 50      {
 51          get => Encryption.magicNumber;
 52          set => Encryption.magicNumber = value;
 53      }
 54  
 55      internal static string MyKey
 56      {
 57          get => Encryption.myKey;
 58  
 59          set
 60          {
 61              if (Encryption.myKey != value)
 62              {
 63                  Encryption.myKey = value;
 64                  _ = Task.Factory.StartNew(
 65                      () => Encryption.GenLegalKey(),
 66                      System.Threading.CancellationToken.None,
 67                      TaskCreationOptions.None,
 68                      TaskScheduler.Default); // Cache the key to improve UX.
 69              }
 70          }
 71      }
 72  
 73      private static string KeyDisplayedText(string key)
 74      {
 75          string displayedValue = string.Empty;
 76          int i = 0;
 77  
 78          do
 79          {
 80              int length = Math.Min(4, key.Length - i);
 81              displayedValue += string.Concat(key.AsSpan(i, length), "  ");
 82              i += 4;
 83          }
 84          while (i < key.Length - 1);
 85  
 86          return displayedValue.Trim();
 87      }
 88  
 89      internal static bool GeneratedKey { get; set; }
 90  
 91      internal static bool KeyCorrupted { get; set; }
 92  
 93      internal static void InitEncryption()
 94      {
 95          try
 96          {
 97              if (symAl == null)
 98              {
 99  #pragma warning disable SYSLIB0021 // No proper replacement for now
100                  symAl = new AesCryptoServiceProvider();
101  #pragma warning restore SYSLIB0021
102                  symAl.KeySize = 256;
103                  symAl.BlockSize = SymAlBlockSize * 8;
104                  symAl.Padding = PaddingMode.Zeros;
105                  symAl.Mode = CipherMode.CBC;
106                  symAl.GenerateIV();
107              }
108          }
109          catch (Exception e)
110          {
111              Logger.Log(e);
112          }
113      }
114  
115      private static readonly ConcurrentDictionary<string, byte[]> LegalKeyDictionary = new(StringComparer.OrdinalIgnoreCase);
116  
117      private static byte[] GenLegalKey()
118      {
119          byte[] rv;
120          string myKey = Encryption.MyKey;
121  
122          if (!LegalKeyDictionary.TryGetValue(myKey, out byte[] value))
123          {
124              Rfc2898DeriveBytes key = new(
125                  myKey,
126                  Common.GetBytesU(InitialIV),
127                  50000,
128                  HashAlgorithmName.SHA512);
129              rv = key.GetBytes(32);
130              _ = LegalKeyDictionary.AddOrUpdate(myKey, rv, (k, v) => rv);
131          }
132          else
133          {
134              rv = value;
135          }
136  
137          return rv;
138      }
139  
140      private static byte[] GenLegalIV()
141      {
142          string st = InitialIV;
143          int ivLength = symAl.IV.Length;
144          if (st.Length > ivLength)
145          {
146              st = st[..ivLength];
147          }
148          else if (st.Length < ivLength)
149          {
150              st = st.PadRight(ivLength, ' ');
151          }
152  
153          return Common.GetBytes(st);
154      }
155  
156      internal static Stream GetEncryptedStream(Stream encryptedStream)
157      {
158          ICryptoTransform encryptor;
159          encryptor = symAl.CreateEncryptor(GenLegalKey(), GenLegalIV());
160          return new CryptoStream(encryptedStream, encryptor, CryptoStreamMode.Write);
161      }
162  
163      internal static Stream GetDecryptedStream(Stream encryptedStream)
164      {
165          ICryptoTransform decryptor;
166          decryptor = symAl.CreateDecryptor(GenLegalKey(), GenLegalIV());
167          return new CryptoStream(encryptedStream, decryptor, CryptoStreamMode.Read);
168      }
169  
170      internal static uint Get24BitHash(string st)
171      {
172          if (string.IsNullOrEmpty(st))
173          {
174              return 0;
175          }
176  
177          byte[] bytes = new byte[Package.PACKAGE_SIZE];
178          for (int i = 0; i < Package.PACKAGE_SIZE; i++)
179          {
180              if (i < st.Length)
181              {
182                  bytes[i] = (byte)st[i];
183              }
184          }
185  
186          var hash = SHA512.Create();
187          byte[] hashValue = hash.ComputeHash(bytes);
188  
189          for (int i = 0; i < 50000; i++)
190          {
191              hashValue = hash.ComputeHash(hashValue);
192          }
193  
194          Logger.LogDebug(string.Format(CultureInfo.CurrentCulture, "magic: {0},{1},{2}", hashValue[0], hashValue[1], hashValue[^1]));
195          hash.Clear();
196          return (uint)((hashValue[0] << 23) + (hashValue[1] << 16) + (hashValue[^1] << 8) + hashValue[2]);
197      }
198  
199      internal static string GetDebugInfo(string st)
200      {
201          return string.IsNullOrEmpty(st) ? st : ((byte)(Common.GetBytesU(st).Sum(value => value) % 256)).ToString(CultureInfo.InvariantCulture);
202      }
203  
204      internal static string CreateDefaultKey()
205      {
206          return CreateRandomKey();
207      }
208  
209      private const int PW_LENGTH = 16;
210  
211      internal static string CreateRandomKey()
212      {
213          // Not including characters like "'`O0& since they are confusing to users.
214          string[] chars = new[] { "abcdefghjkmnpqrstuvxyz", "ABCDEFGHJKMNPQRSTUVXYZ", "123456789", "~!@#$%^*()_-+=:;<,>.?/\\|[]" };
215          char[][] charactersUsedForKey = chars.Select(charset => Enumerable.Range(0, charset.Length - 1).Select(i => charset[i]).ToArray()).ToArray();
216          byte[] randomData = new byte[1];
217          string key = string.Empty;
218  
219          do
220          {
221              foreach (string set in chars)
222              {
223                  randomData = RandomNumberGenerator.GetBytes(1);
224                  key += set[randomData[0] % set.Length];
225  
226                  if (key.Length >= PW_LENGTH)
227                  {
228                      break;
229                  }
230              }
231          }
232          while (key.Length < PW_LENGTH);
233  
234          return key;
235      }
236  
237      internal static bool IsKeyValid(string key, out string error)
238      {
239          error = string.IsNullOrEmpty(key) || key.Length < 16
240              ? "Key must have at least 16 characters in length (spaces are discarded). Key must be auto generated in one of the machines."
241              : null;
242  
243          return error == null;
244      }
245  }