WireFormatReader.swift 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. // Sources/SwiftProtobuf/WireFormatReader.swift - Low-level reader for protobuf binary format
  2. //
  3. // Copyright (c) 2014 - 2025 Apple Inc. and the project authors
  4. // Licensed under Apache License v2.0 with Runtime Library Exception
  5. //
  6. // See LICENSE.txt for license information:
  7. // https://github.com/apple/swift-protobuf/blob/main/LICENSE.txt
  8. //
  9. // -----------------------------------------------------------------------------
  10. ///
  11. /// A low-level reader for data in protobuf binary format.
  12. ///
  13. // -----------------------------------------------------------------------------
  14. /// Wraps a memory buffer and provides low-level APIs to read protobuf wire format values from it.
  15. struct WireFormatReader {
  16. /// The pointer to the location in the buffer from which the next value will be read.
  17. private var pointer: UnsafeRawPointer?
  18. /// The pointer to the location in the buffer where the last tag was read started.
  19. private var lastTagPointer: UnsafeRawPointer?
  20. /// The number of bytes that are still available to read from the buffer.
  21. private var available: Int
  22. /// The maximum number of times that the reader can recurse via sub-readers (for message and
  23. /// group parsing) before throwing an error.
  24. private var recursionBudget: Int
  25. /// If non-zero, this indicates that the reader is being used to read a group with the given
  26. /// field number.
  27. ///
  28. /// When the corresponding end-group tag is encountered, the reader will automatically update
  29. /// its internal state such that `hasAvailableData` returns false.
  30. private var trackingGroupFieldNumber: UInt32
  31. /// Creates a new `WireFormatReader` that reads from the given buffer.
  32. ///
  33. /// - Parameters:
  34. /// - buffer: The raw buffer from which binary-formatted data will be read.
  35. /// - recursionBudget: The maximum number of times that the reader can recurse via
  36. /// sub-readers (for message and group parsing) before throwing an error.
  37. init(buffer: UnsafeRawBufferPointer, recursionBudget: Int) {
  38. self.init(
  39. buffer: buffer,
  40. recursionBudget: recursionBudget,
  41. trackingGroupFieldNumber: 0
  42. )
  43. }
  44. /// Creates a new `WireFormatReader` with the given buffer and tracking information.
  45. ///
  46. /// - Parameters:
  47. /// - buffer: The raw buffer from which binary-formatted data will be read.
  48. /// - recursionBudget: The maximum number of times that the reader can recurse via
  49. /// sub-readers (for message and group parsing) before throwing an error.
  50. /// - trackingGroupFieldNumber: If non-zero, this indicates that the reader is currently
  51. /// reading a group with the given field number.
  52. private init(
  53. buffer: UnsafeRawBufferPointer,
  54. recursionBudget: Int,
  55. trackingGroupFieldNumber: UInt32
  56. ) {
  57. if buffer.count > 0, let pointer = buffer.baseAddress {
  58. self.pointer = pointer
  59. self.lastTagPointer = pointer
  60. self.available = buffer.count
  61. } else {
  62. self.pointer = nil
  63. self.lastTagPointer = nil
  64. self.available = 0
  65. }
  66. self.recursionBudget = recursionBudget
  67. self.trackingGroupFieldNumber = trackingGroupFieldNumber
  68. }
  69. /// Indicates whether there is still data available to be read in the buffer.
  70. var hasAvailableData: Bool {
  71. available > 0
  72. }
  73. /// Indicates whether the reader is currently inside a group.
  74. var isTrackingGroup: Bool {
  75. trackingGroupFieldNumber != 0
  76. }
  77. /// Reads the next varint from the input and returns the equivalent `FieldTag`, updating the
  78. /// reader's internal state to track the pointer to this tag as the last one that was read.
  79. ///
  80. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input or while
  81. /// converting to a tag.
  82. mutating func nextTag() throws -> FieldTag {
  83. self.lastTagPointer = pointer
  84. let tag = try nextTagWithoutUpdatingLastTagPointer()
  85. if tag.wireFormat == .endGroup {
  86. // Verify that we're tracking a group and the field number matches the same field
  87. // number.
  88. guard trackingGroupFieldNumber == UInt32(tag.fieldNumber) else {
  89. throw BinaryDecodingError.malformedProtobuf
  90. }
  91. available = 0
  92. trackingGroupFieldNumber = 0
  93. }
  94. return tag
  95. }
  96. /// Reads the next varint from the input and returns the equivalent `FieldTag` without updating
  97. /// the reader's state that tracks the pointer to the last tag that was read.
  98. ///
  99. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input or while
  100. /// converting to a tag.
  101. private mutating func nextTagWithoutUpdatingLastTagPointer() throws -> FieldTag {
  102. let varint = try nextVarint()
  103. guard varint < UInt64(UInt32.max), let tag = FieldTag(rawValue: UInt32(truncatingIfNeeded: varint)) else {
  104. throw BinaryDecodingError.malformedProtobuf
  105. }
  106. return tag
  107. }
  108. /// Reads and returns the next varint from the input.
  109. ///
  110. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input or if the
  111. /// varint was malformed.
  112. mutating func nextVarint() throws -> UInt64 {
  113. if available < 1 {
  114. throw BinaryDecodingError.truncated
  115. }
  116. var start = pointer!
  117. var length = available
  118. var byte = start.load(fromByteOffset: 0, as: UInt8.self)
  119. start += 1
  120. length &-= 1
  121. if byte & 0x80 == 0 {
  122. pointer = start
  123. available = length
  124. return UInt64(byte)
  125. }
  126. var value = UInt64(byte & 0x7f)
  127. var shift = UInt64(7)
  128. while true {
  129. if length < 1 || shift > 63 {
  130. throw BinaryDecodingError.malformedProtobuf
  131. }
  132. byte = start.load(fromByteOffset: 0, as: UInt8.self)
  133. start += 1
  134. length &-= 1
  135. value |= UInt64(byte & 0x7f) &<< shift
  136. if byte & 0x80 == 0 {
  137. pointer = start
  138. available = length
  139. return value
  140. }
  141. shift &+= 7
  142. }
  143. }
  144. /// Reads and returns the next 32-bit little endian integer from the input.
  145. ///
  146. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input.
  147. mutating func nextLittleEndianUInt32() throws -> UInt32 {
  148. let size = MemoryLayout<UInt32>.size
  149. guard available >= size else { throw BinaryDecodingError.truncated }
  150. defer { advance(by: size) }
  151. return UInt32(littleEndian: pointer!.loadUnaligned(as: UInt32.self))
  152. }
  153. /// Reads and returns the next 64-bit little endian integer from the input.
  154. ///
  155. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input.
  156. mutating func nextLittleEndianUInt64() throws -> UInt64 {
  157. let size = MemoryLayout<UInt64>.size
  158. guard available >= size else { throw BinaryDecodingError.truncated }
  159. defer { advance(by: size) }
  160. return UInt64(littleEndian: pointer!.loadUnaligned(as: UInt64.self))
  161. }
  162. /// Reads the next varint from the input and validates that it falls within the permitted range
  163. /// for length delimited fields.
  164. ///
  165. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input or if the
  166. /// value read was greater than 2GB.
  167. mutating func nextVarintAsValidatedDelimitedLength() throws -> Int {
  168. let length = try nextVarint()
  169. // Length-delimited fields (bytes, strings, and messages) have a maximum size of 2GB. The
  170. // upstream C++ implementation does the same sort of enforcement (see `parse_context`,
  171. // `delimited_message_util`, `message_lite`, etc.).
  172. // https://protobuf.dev/programming-guides/encoding/#cheat-sheet
  173. //
  174. // This function also gets called for some packed repeated fields
  175. // that is length delimited on the wire, so the spec would imply
  176. // the limit still applies.
  177. guard length < 0x7fff_ffff else {
  178. // Reuse existing error to avoid breaking change of changing thrown error
  179. throw BinaryDecodingError.malformedProtobuf
  180. }
  181. guard length <= UInt64(available) else {
  182. throw BinaryDecodingError.truncated
  183. }
  184. return Int(length)
  185. }
  186. /// Reads the next varint from the input that indicates the length of the subsequent data and
  187. /// then returns a rebased `UnsafeRawBufferPointer` representing that slice of the buffer,
  188. /// advancing the reader past it.
  189. ///
  190. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input.
  191. mutating func nextLengthDelimitedSlice() throws -> UnsafeRawBufferPointer {
  192. let length = try nextVarintAsValidatedDelimitedLength()
  193. defer { advance(by: length) }
  194. return UnsafeRawBufferPointer(start: pointer, count: length)
  195. }
  196. /// Calls the given function, passing it a reader that is configured only to read the slice of
  197. /// the buffer whose length is defined by reading the next varint.
  198. ///
  199. /// After the function returns, this reader is advanced past that length to continue reading.
  200. ///
  201. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input.
  202. mutating func withReaderForNextLengthDelimitedSlice(perform: (inout WireFormatReader) throws -> Void) throws {
  203. guard recursionBudget > 0 else {
  204. throw BinaryDecodingError.messageDepthLimit
  205. }
  206. let length = try nextVarintAsValidatedDelimitedLength()
  207. let subBuffer = UnsafeRawBufferPointer(start: pointer, count: length)
  208. var subReader = WireFormatReader(
  209. buffer: subBuffer,
  210. recursionBudget: recursionBudget &- 1,
  211. trackingGroupFieldNumber: 0
  212. )
  213. try perform(&subReader)
  214. advance(by: length)
  215. }
  216. /// Calls the given function, passing it a reader that is configured only to read the slice of
  217. /// the buffer up to the next end-group tag that has the same field number.
  218. ///
  219. /// After the function returns, this reader is advanced past that end-group tag to continue
  220. /// reading.
  221. ///
  222. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input.
  223. mutating func withReaderForNextGroup(
  224. withFieldNumber fieldNumber: UInt32,
  225. perform: (inout WireFormatReader) throws -> Void
  226. ) throws {
  227. guard recursionBudget > 0 else {
  228. throw BinaryDecodingError.messageDepthLimit
  229. }
  230. // Unlike a length-delimited field, we don't know exactly where the group will end. Have
  231. // the sub-reader start with the rest of the buffer, and after the the closure returns,
  232. // we'll know how far to advance ourselves.
  233. let subBuffer = UnsafeRawBufferPointer(start: pointer, count: available)
  234. var subReader = WireFormatReader(
  235. buffer: subBuffer,
  236. recursionBudget: recursionBudget &- 1,
  237. trackingGroupFieldNumber: fieldNumber
  238. )
  239. try perform(&subReader)
  240. guard let subPointer = subReader.pointer else {
  241. throw BinaryDecodingError.truncated
  242. }
  243. advance(by: subPointer - pointer!)
  244. }
  245. /// Advances the reader past the field at the current location, assuming it had the given tag,
  246. /// and returns a rebased `UnsafeRawBufferPointer` representing the slice of the buffer that
  247. /// includes the tag and subsequent data for that field.
  248. ///
  249. /// - Throws: `BinaryDecodingError` if an error occurred while reading from the input.
  250. mutating func sliceBySkippingField(tag: FieldTag) throws -> UnsafeRawBufferPointer {
  251. switch tag.wireFormat {
  252. case .varint:
  253. _ = try nextVarint()
  254. case .lengthDelimited:
  255. advance(by: try nextVarintAsValidatedDelimitedLength())
  256. case .fixed32:
  257. let size = MemoryLayout<UInt32>.size
  258. guard available >= size else { throw BinaryDecodingError.truncated }
  259. advance(by: size)
  260. case .fixed64:
  261. let size = MemoryLayout<UInt64>.size
  262. guard available >= size else { throw BinaryDecodingError.truncated }
  263. advance(by: size)
  264. case .startGroup:
  265. guard recursionBudget > 0 else {
  266. throw BinaryDecodingError.messageDepthLimit
  267. }
  268. var sawEndGroup = false
  269. recursionBudget &-= 1
  270. while hasAvailableData {
  271. let innerTag = try nextTagWithoutUpdatingLastTagPointer()
  272. if innerTag.wireFormat == .endGroup {
  273. guard innerTag.fieldNumber == tag.fieldNumber else {
  274. // The binary data is invalid if the startGroup/endGroup field numbers are
  275. // not balanced.
  276. throw BinaryDecodingError.malformedProtobuf
  277. }
  278. recursionBudget &+= 1
  279. sawEndGroup = true
  280. break
  281. } else {
  282. _ = try sliceBySkippingField(tag: innerTag)
  283. }
  284. }
  285. if !sawEndGroup && !hasAvailableData {
  286. throw BinaryDecodingError.truncated
  287. }
  288. case .endGroup:
  289. throw BinaryDecodingError.truncated
  290. }
  291. return UnsafeRawBufferPointer(start: lastTagPointer, count: pointer! - lastTagPointer!)
  292. }
  293. /// Advances the input the given number of bytes.
  294. ///
  295. /// - Precondition: There must be at least `length` amount of data in the buffer. It is the
  296. /// caller's responsibility to verify this and throw if there is not.
  297. private mutating func advance(by length: Int) {
  298. precondition(
  299. available >= length,
  300. "Internal error: tried to advance \(length) bytes with only \(available) available"
  301. )
  302. available &-= length
  303. pointer! += length
  304. }
  305. }