VectorIntegrationTests.swift 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. /*
  2. * Copyright 2023 Google LLC
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. import Combine
  17. import FirebaseFirestore
  18. import Foundation
  /// Integration tests that write `VectorValue` fields to Firestore and check how they are read
  /// back, ordered, filtered, and decoded through `Codable`.
  @available(iOS 13, tvOS 13, macOS 10.15, macCatalyst 13, watchOS 7, *)
  class VectorIntegrationTests: FSTIntegrationTestCase {
    /// Documents whose `embedding` values are listed in the order that
    /// `order(by: "embedding")` is expected to return them: plain arrays first, then vectors
    /// (shorter before longer, equal lengths compared element by element), then maps.
    private static var embeddingDocsInExpectedOrder: [[String: Any]] {
      [
        ["embedding": [1, 2, 3, 4, 5, 6]],
        ["embedding": [100]],
        ["embedding": FieldValue.vector([-Double.infinity])],
        ["embedding": FieldValue.vector([-100.0])],
        ["embedding": FieldValue.vector([100.0])],
        ["embedding": FieldValue.vector([Double.infinity])],
        ["embedding": FieldValue.vector([1, 2.0])],
        ["embedding": FieldValue.vector([2, 2.0])],
        ["embedding": FieldValue.vector([1, 2, 3.0])],
        ["embedding": FieldValue.vector([1, 2, 3, 4.0])],
        ["embedding": FieldValue.vector([1, 2, 3, 4, 5.0])],
        ["embedding": FieldValue.vector([1, 2, 100, 4, 4.0])],
        ["embedding": FieldValue.vector([100, 2, 3, 4, 5.0])],
        ["embedding": ["HELLO": "WORLD"]],
        ["embedding": ["hello": "world"]],
      ]
    }

    /// Adds each entry of `embeddingDocsInExpectedOrder` to `collection`, one at a time, and
    /// returns the generated document IDs in the same (expected sort) order.
    private func addEmbeddingDocs(to collection: CollectionReference) async throws -> [String] {
      var docIds: [String] = []
      for data in Self.embeddingDocsInExpectedOrder {
        let docRef = try await collection.addDocument(data: data)
        docIds.append(docRef.documentID)
      }
      return docIds
    }

    /// Writes vectors with `addDocument`, `setData` and `updateData`, then checks that every
    /// vector field reads back equal to the value that was written.
    func testWriteAndReadVectorEmbeddings() async throws {
      let collection = collectionRef()
      let ref = try await collection.addDocument(data: [
        "vector0": FieldValue.vector([0.0]),
        "vector1": FieldValue.vector([1, 2, 3.99]),
      ])
      // Overwrite the document (re-writing vector0/vector1 and adding vector2), then add vector3
      // with an update.
      try await ref.setData([
        "vector0": FieldValue.vector([0.0]),
        "vector1": FieldValue.vector([1, 2, 3.99]),
        "vector2": FieldValue.vector([0, 0, 0] as [Double]),
      ])
      try await ref.updateData([
        "vector3": FieldValue.vector([-1, -200, -999] as [Double]),
      ])

      let snapshot = try await ref.getDocument()
      XCTAssertEqual(snapshot.get("vector0") as? VectorValue, FieldValue.vector([0.0]))
      XCTAssertEqual(snapshot.get("vector1") as? VectorValue, FieldValue.vector([1, 2, 3.99]))
      XCTAssertEqual(
        snapshot.get("vector2") as? VectorValue,
        FieldValue.vector([0, 0, 0] as [Double])
      )
      XCTAssertEqual(
        snapshot.get("vector3") as? VectorValue,
        FieldValue.vector([-1, -200, -999] as [Double])
      )
    }

    /// Checks that the SDK orders the `embedding` field the same way as the backend.
    ///
    /// A snapshot listener yields the SDK's sort order, and `getDocuments(source: .server)`
    /// yields the backend's. Both must agree with each other and with the fixture order.
    // The raised availability comes from the test implementation, not the vector feature:
    // `Future.value` (used to await the first snapshot) requires iOS 15 / tvOS 15 / macOS 12 /
    // Mac Catalyst 15 / watchOS 8.
    @available(iOS 15, tvOS 15, macOS 12.0, macCatalyst 15, watchOS 8, *)
    func testSdkOrdersVectorFieldSameWayAsBackend() async throws {
      let collection = collectionRef()
      let docIds = try await addEmbeddingDocs(to: collection)
      let orderedQuery = collection.order(by: "embedding")

      // Hold on to the registration so the listener is detached when this test exits. The
      // listener may keep firing after the first snapshot; `Future` only keeps the first result.
      var registration: ListenerRegistration?
      defer { registration?.remove() }
      let watchSnapshot = try await Future<QuerySnapshot, Error> { promise in
        registration = orderedQuery.addSnapshotListener { snapshot, error in
          if let error {
            promise(.failure(error))
          } else if let snapshot {
            promise(.success(snapshot))
          }
        }
      }.value
      let getSnapshot = try await orderedQuery.getDocuments(source: .server)

      let watchIds = watchSnapshot.documents.map(\.documentID)
      // SDK order (listener) vs. backend order (server get), including length.
      XCTAssertEqual(watchIds, getSnapshot.documents.map(\.documentID))
      // SDK order vs. the order the fixture documents were declared in.
      XCTAssertEqual(watchIds, docIds)
    }

    /// Checks, via `checkOnlineAndOfflineCollection`, that ordering by a field holding arrays,
    /// vectors and maps yields the fixture order both online and offline.
    func testSdkOrdersVectorFieldSameWayOnlineAndOffline() async throws {
      let collection = collectionRef()
      let docIds = try await addEmbeddingDocs(to: collection)
      checkOnlineAndOfflineCollection(
        collection,
        query: collection.order(by: "embedding"),
        matchesResult: docIds
      )
    }

    /// Checks, via `checkOnlineAndOfflineCollection`, that range filters against a vector give
    /// the same results online and offline.
    func testSdkFiltersVectorFieldSameWayOnlineAndOffline() async throws {
      let collection = collectionRef()
      let docIds = try await addEmbeddingDocs(to: collection)
      // This is the fixture entry at index 11.
      let pivot = FieldValue.vector([1, 2, 100, 4, 4.0])

      // Only vector entries are expected to match: the plain arrays (indices 0-1) and the maps
      // (indices 13-14) appear in neither result.
      checkOnlineAndOfflineCollection(
        collection,
        query: collection.order(by: "embedding")
          .whereField("embedding", isLessThan: pivot),
        matchesResult: Array(docIds[2 ... 10])
      )
      checkOnlineAndOfflineCollection(
        collection,
        query: collection.order(by: "embedding")
          .whereField("embedding", isGreaterThanOrEqualTo: pivot),
        matchesResult: Array(docIds[11 ... 12])
      )
    }

    /// Writes a `Codable` model that contains a `VectorValue`, queries it back by equality on
    /// the vector, and decodes it again.
    func testQueryVectorValueWrittenByCodable() async throws {
      let collection = collectionRef()

      struct Model: Codable {
        var name: String
        var embedding: VectorValue
      }

      let model = Model(
        name: "name",
        embedding: FieldValue.vector([0.1, 0.3, 0.4])
      )
      try collection.document().setData(from: model)

      let querySnap: QuerySnapshot = try await collection.whereField(
        "embedding",
        isEqualTo: FieldValue.vector([0.1, 0.3, 0.4])
      ).getDocuments()
      XCTAssertEqual(1, querySnap.count)

      // XCTUnwrap reports a test failure instead of trapping if the query came back empty.
      let document = try XCTUnwrap(querySnap.documents.first)
      let returnedModel: Model = try document.data(as: Model.self)
      XCTAssertEqual(returnedModel.embedding, VectorValue([0.1, 0.3, 0.4]))
      let vectorData: [Double] = returnedModel.embedding.array
      XCTAssertEqual(vectorData, [0.1, 0.3, 0.4])
    }

    /// Writes a model with an extra `distance` field, then decodes it through a generic wrapper.
    /// The wrapper reads `distance` itself and decodes the remaining fields into `Model`.
    func testQueryVectorValueWrittenByCodableClass() async throws {
      let collection = collectionRef()

      struct Model: Codable {
        var name: String
        var embedding: VectorValue
      }

      struct ModelWithDistance: Codable {
        var name: String
        var embedding: VectorValue
        var distance: Double
      }

      /// Decodes `distance` from the document and decodes `T` from the same decoder.
      struct WithDistance<T: Decodable>: Decodable {
        var distance: Double
        var data: T

        private enum CodingKeys: String, CodingKey {
          case distance
        }

        init(from decoder: Decoder) throws {
          let container = try decoder.container(keyedBy: CodingKeys.self)
          distance = try container.decode(Double.self, forKey: .distance)
          // T sees the whole document. `Model` declares no `distance` key, so its synthesized
          // decoding does not read it.
          data = try T(from: decoder)
        }
      }

      let model = ModelWithDistance(
        name: "name",
        embedding: FieldValue.vector([0.1, 0.3, 0.4]),
        distance: 0.2
      )
      try collection.document().setData(from: model)

      let querySnap: QuerySnapshot = try await collection.getDocuments()
      XCTAssertEqual(1, querySnap.count)

      // XCTUnwrap reports a test failure instead of trapping if the query came back empty.
      let document = try XCTUnwrap(querySnap.documents.first)
      let returnedModel: WithDistance<Model> = try document.data(as: WithDistance<Model>.self)
      XCTAssertEqual(returnedModel.data.embedding, VectorValue([0.1, 0.3, 0.4]))
      XCTAssertEqual(returnedModel.distance, 0.2)
      let vectorData: [Double] = returnedModel.data.embedding.array
      XCTAssertEqual(vectorData, [0.1, 0.3, 0.4])
    }
  }