VectorIntegrationTests.swift 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. /*
  2. * Copyright 2023 Google LLC
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Combine
  17. import FirebaseFirestore
  18. import Foundation
  /// Integration tests for Firestore vector fields (`VectorValue` / `FieldValue.vector`).
  ///
  /// The tests write vector values, read them back, compare the ordering and filtering of
  /// an `embedding` field between listener, server and offline results, and round-trip
  /// vectors through `Codable` models.
  @available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *)
  class VectorIntegrationTests: FSTIntegrationTestCase {
    // MARK: - Shared fixture

    /// Builds documents whose `embedding` field mixes plain arrays, vectors and maps.
    ///
    /// The documents are listed in the sort order the tests expect from
    /// `order(by: "embedding")`: plain arrays first, then vectors (shorter vectors before
    /// longer ones, equal-length vectors by component values), then maps.
    ///
    /// - Returns: The fixture documents, in expected sort order.
    private func makeEmbeddingDocsInSortOrder() -> [[String: Any]] {
      return [
        ["embedding": [1, 2, 3, 4, 5, 6]],
        ["embedding": [100]],
        ["embedding": FieldValue.vector([-Double.infinity])],
        ["embedding": FieldValue.vector([-100.0])],
        ["embedding": FieldValue.vector([100.0])],
        ["embedding": FieldValue.vector([Double.infinity])],
        ["embedding": FieldValue.vector([1, 2.0])],
        ["embedding": FieldValue.vector([2, 2.0])],
        ["embedding": FieldValue.vector([1, 2, 3.0])],
        ["embedding": FieldValue.vector([1, 2, 3, 4.0])],
        ["embedding": FieldValue.vector([1, 2, 3, 4, 5.0])],
        ["embedding": FieldValue.vector([1, 2, 100, 4, 4.0])],
        ["embedding": FieldValue.vector([100, 2, 3, 4, 5.0])],
        ["embedding": ["HELLO": "WORLD"]],
        ["embedding": ["hello": "world"]],
      ]
    }

    /// Writes the shared fixture to `collection`, one document at a time.
    ///
    /// - Parameter collection: The collection that receives the fixture documents.
    /// - Returns: The IDs of the created documents, in the fixture's (expected sort) order.
    /// - Throws: Any error reported by `addDocument(data:)`.
    private func addEmbeddingDocsInSortOrder(to collection: CollectionReference) async throws
      -> [String] {
      var docIds: [String] = []
      for data in makeEmbeddingDocsInSortOrder() {
        let docRef = try await collection.addDocument(data: data)
        docIds.append(docRef.documentID)
      }
      return docIds
    }

    // MARK: - Tests

    /// Writes vector fields with `addDocument`, `setData` and `updateData`, then verifies
    /// that every vector reads back equal to the value that was written.
    func testWriteAndReadVectorEmbeddings() async throws {
      let collection = collectionRef()

      let ref = try await collection.addDocument(data: [
        "vector0": FieldValue.vector([0.0]),
        "vector1": FieldValue.vector([1, 2, 3.99]),
      ])

      try await ref.setData([
        "vector0": FieldValue.vector([0.0]),
        "vector1": FieldValue.vector([1, 2, 3.99]),
        "vector2": FieldValue.vector([0, 0, 0] as [Double]),
      ])

      try await ref.updateData([
        "vector3": FieldValue.vector([-1, -200, -999] as [Double]),
      ])

      let snapshot = try await ref.getDocument()
      XCTAssertEqual(snapshot.get("vector0") as? VectorValue, FieldValue.vector([0.0]))
      XCTAssertEqual(snapshot.get("vector1") as? VectorValue, FieldValue.vector([1, 2, 3.99]))
      XCTAssertEqual(
        snapshot.get("vector2") as? VectorValue,
        FieldValue.vector([0, 0, 0] as [Double])
      )
      XCTAssertEqual(
        snapshot.get("vector3") as? VectorValue,
        FieldValue.vector([-1, -200, -999] as [Double])
      )
    }

    /// Verifies that a snapshot listener, a server fetch and the fixture's expected order
    /// all agree on the ordering of the `embedding` field.
    ///
    /// iOS 15 is required by the test implementation (it awaits `Future.value`), not by
    /// the vector feature itself.
    @available(iOS 15, tvOS 15, macOS 12.0, macCatalyst 13, watchOS 7, *)
    func testSdkOrdersVectorFieldSameWayAsBackend() async throws {
      let collection = collectionRef()
      let docIds = try await addEmbeddingDocsInSortOrder(to: collection)

      // We validate that the SDK orders the vector field the same way as the backend
      // by comparing the sort order of vector fields from getDocsFromServer and
      // onSnapshot. onSnapshot will return sort order of the SDK,
      // and getDocsFromServer will return sort order of the backend.
      let orderedQuery = collection.order(by: "embedding")

      // Keep the registration so the listener is detached when the test exits, whether it
      // passes, fails or throws.
      var registration: ListenerRegistration?
      defer { registration?.remove() }

      // The future is fulfilled by the first listener event; later events are ignored
      // because a Future can only be fulfilled once.
      let watchSnapshot = try await Future<QuerySnapshot, Error> { promise in
        registration = orderedQuery.addSnapshotListener { snapshot, error in
          if let error {
            promise(.failure(error))
          } else if let snapshot {
            promise(.success(snapshot))
          }
        }
      }.value

      let getSnapshot = try await orderedQuery.getDocuments(source: .server)

      let watchIds = watchSnapshot.documents.map(\.documentID)
      let serverIds = getSnapshot.documents.map(\.documentID)

      // Compare the snapshot (including sort order) of a snapshot
      // from Query.onSnapshot() to an actual snapshot from Query.get()
      XCTAssertEqual(watchIds, serverIds)

      // Compare the snapshot (including sort order) of a snapshot
      // from Query.onSnapshot() to the expected sort order from
      // the backend.
      XCTAssertEqual(watchIds, docIds)
    }

    /// Verifies that ordering by `embedding` yields the fixture order both online and
    /// offline, as checked by `checkOnlineAndOfflineQuery`.
    func testSdkOrdersVectorFieldSameWayOnlineAndOffline() async throws {
      let collection = collectionRef()
      let docIds = try await addEmbeddingDocsInSortOrder(to: collection)

      checkOnlineAndOfflineQuery(collection.order(by: "embedding"), matchesResult: docIds)
    }

    /// Verifies that inequality filters against a vector value select the same documents
    /// online and offline.
    func testSdkFiltersVectorFieldSameWayOnlineAndOffline() async throws {
      let collection = collectionRef()
      let docIds = try await addEmbeddingDocsInSortOrder(to: collection)

      // The pivot vector [1, 2, 100, 4, 4.0] is fixture entry 11. The expected results
      // contain only vector entries: 2...10 sort below the pivot and 11...12 at or above
      // it. The plain arrays (0...1) and the maps (13...14) are expected in neither.
      checkOnlineAndOfflineQuery(
        collection.order(by: "embedding")
          .whereField("embedding", isLessThan: FieldValue.vector([1, 2, 100, 4, 4.0])),
        matchesResult: Array(docIds[2 ... 10])
      )

      checkOnlineAndOfflineQuery(
        collection.order(by: "embedding")
          .whereField("embedding", isGreaterThanOrEqualTo: FieldValue.vector([1, 2, 100, 4, 4.0])),
        matchesResult: Array(docIds[11 ... 12])
      )
    }

    /// Writes a `Codable` model containing a `VectorValue`, queries for it by vector
    /// equality, and verifies that the decoded model carries the same vector.
    func testQueryVectorValueWrittenByCodable() async throws {
      let collection = collectionRef()

      struct Model: Codable {
        var name: String
        var embedding: VectorValue
      }

      let model = Model(
        name: "name",
        embedding: FieldValue.vector([0.1, 0.3, 0.4])
      )

      // NOTE(review): this write is not awaited; the query below is presumably expected to
      // observe it through the SDK's pending local writes — confirm.
      try collection.document().setData(from: model)

      let querySnap: QuerySnapshot = try await collection.whereField(
        "embedding",
        isEqualTo: FieldValue.vector([0.1, 0.3, 0.4])
      ).getDocuments()

      XCTAssertEqual(1, querySnap.count)

      // Unwrap instead of indexing so an empty result fails the test rather than trapping.
      let document = try XCTUnwrap(querySnap.documents.first)
      let returnedModel = try document.data(as: Model.self)
      XCTAssertEqual(returnedModel.embedding, VectorValue([0.1, 0.3, 0.4]))

      let vectorData: [Double] = returnedModel.embedding.array
      XCTAssertEqual(vectorData, [0.1, 0.3, 0.4])
    }

    /// Writes a model with an extra `distance` field and decodes it through a generic
    /// wrapper that reads `distance` itself and decodes the remaining fields as `Model`.
    func testQueryVectorValueWrittenByCodableClass() async throws {
      let collection = collectionRef()

      struct Model: Codable {
        var name: String
        var embedding: VectorValue
      }

      struct ModelWithDistance: Codable {
        var name: String
        var embedding: VectorValue
        var distance: Double
      }

      /// Decodes `distance` from the keyed container and `data` from the same decoder, so
      /// both are read from one flat document.
      struct WithDistance<T: Decodable>: Decodable {
        var distance: Double
        var data: T

        private enum CodingKeys: String, CodingKey {
          case distance
        }

        init(from decoder: Decoder) throws {
          let container = try decoder.container(keyedBy: CodingKeys.self)
          distance = try container.decode(Double.self, forKey: .distance)
          data = try T(from: decoder)
        }
      }

      let model = ModelWithDistance(
        name: "name",
        embedding: FieldValue.vector([0.1, 0.3, 0.4]),
        distance: 0.2
      )

      // NOTE(review): this write is not awaited; the query below is presumably expected to
      // observe it through the SDK's pending local writes — confirm.
      try collection.document().setData(from: model)

      let querySnap: QuerySnapshot = try await collection.getDocuments()

      XCTAssertEqual(1, querySnap.count)

      // Unwrap instead of indexing so an empty result fails the test rather than trapping.
      let document = try XCTUnwrap(querySnap.documents.first)
      let returnedModel = try document.data(as: WithDistance<Model>.self)
      XCTAssertEqual(returnedModel.data.embedding, VectorValue([0.1, 0.3, 0.4]))
      XCTAssertEqual(returnedModel.distance, 0.2)

      let vectorData: [Double] = returnedModel.data.embedding.array
      XCTAssertEqual(vectorData, [0.1, 0.3, 0.4])
    }
  }