/ Sources / FoundationNetworking / DataURLProtocol.swift
DataURLProtocol.swift
  1  // This source file is part of the Swift.org open source project
  2  //
  3  // Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors
  4  // Licensed under Apache License v2.0 with Runtime Library Exception
  5  //
  6  // See https://swift.org/LICENSE.txt for license information
  7  // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
  8  //
  9  
 10  // Protocol implementation of data: URL scheme
 11  
 12  #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
 13  import SwiftFoundation
 14  #else
 15  import Foundation
 16  #endif
 17  
 18  
 19  // Iterate through a SubString validating that the input is ASCII and converting any %xx
 20  // percent endcoded hex sequences to a UInt8 byte.
 21  private struct _PercentDecoder: IteratorProtocol {
 22  
 23      enum Element {
 24          case asciiCharacter(Character)
 25          case decodedByte(UInt8)
 26          case invalid                    // Not ASCII or hex encoded
 27      }
 28  
 29      private let subString: Substring
 30      private var currentIndex: String.Index
 31      var remainingString: Substring { subString[currentIndex...] }
 32  
 33  
 34      init(subString: Substring) {
 35          self.subString = subString
 36          currentIndex = subString.startIndex
 37      }
 38  
 39      mutating private func nextChar() -> Character? {
 40          guard currentIndex < subString.endIndex else { return nil }
 41          let ch = subString[currentIndex]
 42          currentIndex = subString.index(after: currentIndex)
 43          return ch
 44      }
 45  
 46      mutating func next() -> _PercentDecoder.Element? {
 47          guard let ch = nextChar() else { return nil }
 48  
 49          guard let asciiValue = ch.asciiValue else { return .invalid }
 50  
 51          guard asciiValue == UInt8(ascii: "%") else {
 52              return .asciiCharacter(ch)
 53          }
 54  
 55          // Decode the %xx value
 56          guard let hiNibble = nextChar(), hiNibble.isASCII,
 57              let hiNibbleValue = hiNibble.hexDigitValue else {
 58                  return .invalid
 59          }
 60  
 61          guard let loNibble = nextChar(), loNibble.isASCII,
 62              let loNibbleValue = loNibble.hexDigitValue else {
 63                  return .invalid
 64          }
 65          let byte = UInt8(hiNibbleValue) << 4 | UInt8(loNibbleValue)
 66          return .decodedByte(byte)
 67      }
 68  }
 69  
 70  
 71  internal class _DataURLProtocol: URLProtocol {
 72  
 73      override class func canInit(with request: URLRequest) -> Bool {
 74          return request.url?.scheme == "data"
 75      }
 76  
 77      override class func canInit(with task: URLSessionTask) -> Bool {
 78          return task.currentRequest?.url?.scheme == "data"
 79      }
 80  
 81      override class func canonicalRequest(for request: URLRequest) -> URLRequest {
 82          return request
 83      }
 84  
 85      override func startLoading() {
 86          guard let urlClient = self.client else { fatalError("No URLProtocol client set") }
 87  
 88          if let (response, decodedData) = decodeURI() {
 89              urlClient.urlProtocol(self, didReceive: response, cacheStoragePolicy: .allowed)
 90              urlClient.urlProtocol(self, didLoad: decodedData)
 91              urlClient.urlProtocolDidFinishLoading(self)
 92          } else {
 93              let error = NSError(domain: NSURLErrorDomain, code: NSURLErrorBadURL)
 94              if let session = self.task?.session as? URLSession, let delegate = session.delegate as? URLSessionTaskDelegate,
 95                  let task = self.task {
 96                  delegate.urlSession(session, task: task, didCompleteWithError: error)
 97              }
 98          }
 99      }
100  
101  
102      private func decodeURI() -> (URLResponse, Data)? {
103          guard let url = self.request.url else {
104              return nil
105          }
106          let dataBody = url.absoluteString
107          guard dataBody.hasPrefix("data:") else {
108              return nil
109          }
110  
111          let startIdx = dataBody.index(dataBody.startIndex, offsetBy: 5)
112          var iterator = _PercentDecoder(subString: dataBody[startIdx...])
113  
114          var mimeType: String?
115          var charSet: String?
116          var base64 = false
117  
118          // Simple validation that the mime type has only one '/' and its not at the start or end.
119          func validate(mimeType: String) -> Bool {
120              if mimeType.hasPrefix("/") { return false }
121              var count = 0
122              var lastChar: Character!
123  
124              for ch in mimeType {
125                  if ch == "/" { count += 1 }
126                  if count > 1 { return false }
127                  lastChar = ch
128              }
129              guard count == 1 else { return false }
130              return lastChar != "/"
131          }
132  
133          // Determine optional mime type, optional charset and whether ;base64 flag is just before a comma.
134          func decodeHeader() -> Bool {
135              let defaultMimeType = "text/plain"
136  
137              var part = ""
138              var foundCharsetKey = false
139  
140               while let element = iterator.next() {
141                  switch element {
142                      case .asciiCharacter(let ch) where ch == Character(","):
143                          // ";base64 must be the last part just before the ',' that seperates the header from the data
144                          if foundCharsetKey {
145                              charSet = part
146                          } else {
147                              base64 = (part == ";base64")
148                          }
149                          if mimeType == nil || !validate(mimeType: mimeType!) {
150                              mimeType = defaultMimeType
151                          }
152                          return true
153  
154  
155                      case .asciiCharacter(let ch) where ch == Character(";"):
156                          // First item is the mimeType if there is a '/' in the string
157                          if mimeType == nil {
158                              if part.contains("/") {
159                                  mimeType = part
160                              } else {
161                                  mimeType = defaultMimeType // default value
162                              }
163                          }
164                          if foundCharsetKey {
165                              charSet = part
166                              foundCharsetKey = false
167                          }
168                          part = ";"
169  
170                      case .asciiCharacter(let ch) where ch == Character("="):
171                          if mimeType == nil {
172                              mimeType = defaultMimeType
173                          } else if part == ";charset" && charSet == nil {
174                              foundCharsetKey = true
175                              part = ""
176                          }
177  
178                      case .asciiCharacter(let ch):
179                          part += String(ch)
180  
181                      case .decodedByte(_), .invalid:
182                          // Dont allow percent encoded bytes in the header.
183                          return false
184                  }
185              }
186              // No comma found.
187              return false
188          }
189  
190          // Convert any percent encoding to bytes then pass the whole String to be Base64 decoded.
191          // Let the Base64 decoder take care of input validation.
192          func decodeBase64Body() -> Data? {
193              var base64encoded = ""
194              base64encoded.reserveCapacity(iterator.remainingString.count)
195  
196              while let element = iterator.next() {
197                  switch element {
198                      case .asciiCharacter(let ch):
199                          base64encoded += String(ch)
200  
201                      case .decodedByte(let value) where UnicodeScalar(value).isASCII:
202                          base64encoded += String(Character(UnicodeScalar(value)))
203  
204                      default: return nil
205                  }
206              }
207              return Data(base64Encoded: base64encoded)
208          }
209  
210          // Convert any percent encoding to bytes and append to a `Data` instance. The bytes may
211          // be valid in the specified charset in the header and not necessarily UTF-8.
212          func decodeStringBody() -> Data? {
213              var data = Data()
214              data.reserveCapacity(iterator.remainingString.count)
215  
216              while let ch = iterator.next() {
217                  switch ch {
218                      case .asciiCharacter(let ch): data.append(ch.asciiValue!)
219                      case .decodedByte(let value): data.append(value)
220                      default: return nil
221                  }
222              }
223              return data
224          }
225  
226          guard decodeHeader() else { return nil }
227          guard let decodedData = base64 ? decodeBase64Body() : decodeStringBody() else {
228              return nil
229          }
230  
231          let response = URLResponse(url: url, mimeType: mimeType, expectedContentLength: decodedData.count, textEncodingName: charSet)
232          return (response, decodedData)
233      }
234  
235      // Nothing to do here.
236      override func stopLoading() {
237      }
238  }