FSTMockDatastore.mm 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. /*
  2. * Copyright 2017 Google
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #import "Firestore/Example/Tests/SpecTests/FSTMockDatastore.h"
  17. #include <map>
  18. #include <memory>
  19. #include <queue>
  20. #include <utility>
  21. #include "Firestore/core/src/core/database_info.h"
  22. #include "Firestore/core/src/credentials/credentials_provider.h"
  23. #include "Firestore/core/src/credentials/empty_credentials_provider.h"
  24. #include "Firestore/core/src/local/target_data.h"
  25. #include "Firestore/core/src/model/database_id.h"
  26. #include "Firestore/core/src/model/mutation.h"
  27. #include "Firestore/core/src/remote/connectivity_monitor.h"
  28. #include "Firestore/core/src/remote/datastore.h"
  29. #include "Firestore/core/src/remote/firebase_metadata_provider.h"
  30. #include "Firestore/core/src/remote/grpc_connection.h"
  31. #include "Firestore/core/src/remote/serializer.h"
  32. #include "Firestore/core/src/remote/stream.h"
  33. #include "Firestore/core/src/util/async_queue.h"
  34. #include "Firestore/core/src/util/log.h"
  35. #include "Firestore/core/src/util/string_apple.h"
  36. #include "Firestore/core/test/unit/remote/create_noop_connectivity_monitor.h"
  37. #include "absl/memory/memory.h"
  38. #include "grpcpp/completion_queue.h"
  39. NS_ASSUME_NONNULL_BEGIN
  40. using firebase::firestore::core::DatabaseInfo;
  41. using firebase::firestore::credentials::AppCheckCredentialsProvider;
  42. using firebase::firestore::credentials::AuthCredentialsProvider;
  43. using firebase::firestore::credentials::EmptyCredentialsProvider;
  44. using firebase::firestore::local::TargetData;
  45. using firebase::firestore::model::DatabaseId;
  46. using firebase::firestore::model::Mutation;
  47. using firebase::firestore::model::MutationResult;
  48. using firebase::firestore::model::SnapshotVersion;
  49. using firebase::firestore::model::TargetId;
  50. using firebase::firestore::remote::ConnectivityMonitor;
  51. using firebase::firestore::remote::FirebaseMetadataProvider;
  52. using firebase::firestore::remote::GrpcConnection;
  53. using firebase::firestore::remote::WatchChange;
  54. using firebase::firestore::remote::WatchStream;
  55. using firebase::firestore::remote::WatchTargetChange;
  56. using firebase::firestore::remote::WriteStream;
  57. using firebase::firestore::util::AsyncQueue;
  58. using firebase::firestore::util::Status;
  59. namespace firebase {
  60. namespace firestore {
  61. namespace remote {
  62. class MockWatchStream : public WatchStream {
  63. public:
  64. MockWatchStream(const std::shared_ptr<AsyncQueue>& worker_queue,
  65. std::shared_ptr<AuthCredentialsProvider> auth_credentials_provider,
  66. std::shared_ptr<AppCheckCredentialsProvider> app_check_credentials_provider,
  67. Serializer serializer,
  68. GrpcConnection* grpc_connection,
  69. WatchStreamCallback* callback,
  70. MockDatastore* datastore)
  71. : WatchStream{worker_queue,
  72. auth_credentials_provider,
  73. app_check_credentials_provider,
  74. std::move(serializer),
  75. grpc_connection,
  76. callback},
  77. datastore_{datastore},
  78. callback_{callback} {
  79. }
  80. const std::unordered_map<TargetId, TargetData>& ActiveTargets() const {
  81. return active_targets_;
  82. }
  83. void Start() override {
  84. HARD_ASSERT(!open_, "Trying to start already started watch stream");
  85. open_ = true;
  86. callback_->OnWatchStreamOpen();
  87. }
  88. void Stop() override {
  89. WatchStream::Stop();
  90. open_ = false;
  91. active_targets_.clear();
  92. }
  93. bool IsStarted() const override {
  94. return open_;
  95. }
  96. bool IsOpen() const override {
  97. return open_;
  98. }
  99. void WatchQuery(const TargetData& query) override {
  100. LOG_DEBUG("WatchQuery: %s: %s, %s", query.target_id(), query.target().ToString(),
  101. query.resume_token().ToString());
  102. // Snapshot version is ignored on the wire
  103. TargetData sentTargetData =
  104. query.WithResumeToken(query.resume_token(), SnapshotVersion::None());
  105. if (query.expected_count().has_value()) {
  106. sentTargetData = sentTargetData.WithExpectedCount(query.expected_count().value());
  107. }
  108. datastore_->IncrementWatchStreamRequests();
  109. active_targets_[query.target_id()] = sentTargetData;
  110. }
  111. void UnwatchTargetId(model::TargetId target_id) override {
  112. LOG_DEBUG("UnwatchTargetId: %s", target_id);
  113. active_targets_.erase(target_id);
  114. }
  115. void FailStream(const Status& error) {
  116. open_ = false;
  117. callback_->OnWatchStreamClose(error);
  118. }
  119. void WriteWatchChange(const WatchChange& change, SnapshotVersion snap) {
  120. if (change.type() == WatchChange::Type::TargetChange) {
  121. const auto& targetChange = static_cast<const WatchTargetChange&>(change);
  122. if (!targetChange.cause().ok()) {
  123. for (TargetId target_id : targetChange.target_ids()) {
  124. auto found = active_targets_.find(target_id);
  125. if (found == active_targets_.end()) {
  126. // Technically removing an unknown target is valid (e.g. it could race with a
  127. // server-side removal), but we want to pay extra careful attention in tests
  128. // that we only remove targets we listened to.
  129. HARD_FAIL("Removing a non-active target");
  130. }
  131. active_targets_.erase(found);
  132. }
  133. }
  134. if (!targetChange.target_ids().empty()) {
  135. // If the list of target IDs is not empty, we reset the snapshot version to NONE as
  136. // done in `Serializer::DecodeVersion:`.
  137. snap = SnapshotVersion::None();
  138. }
  139. }
  140. callback_->OnWatchStreamChange(change, snap);
  141. }
  142. private:
  143. bool open_ = false;
  144. std::unordered_map<TargetId, TargetData> active_targets_;
  145. MockDatastore* datastore_ = nullptr;
  146. WatchStreamCallback* callback_ = nullptr;
  147. };
  148. class MockWriteStream : public WriteStream {
  149. public:
  150. MockWriteStream(const std::shared_ptr<AsyncQueue>& worker_queue,
  151. std::shared_ptr<AuthCredentialsProvider> auth_credentials_provider,
  152. std::shared_ptr<AppCheckCredentialsProvider> app_check_credentials_provider,
  153. Serializer serializer,
  154. GrpcConnection* grpc_connection,
  155. WriteStreamCallback* callback,
  156. MockDatastore* datastore)
  157. : WriteStream{worker_queue,
  158. auth_credentials_provider,
  159. app_check_credentials_provider,
  160. std::move(serializer),
  161. grpc_connection,
  162. callback},
  163. datastore_{datastore},
  164. callback_{callback} {
  165. }
  166. void Start() override {
  167. HARD_ASSERT(!open_, "Trying to start already started write stream");
  168. open_ = true;
  169. sent_mutations_ = {};
  170. callback_->OnWriteStreamOpen();
  171. }
  172. void Stop() override {
  173. datastore_->IncrementWriteStreamRequests();
  174. WriteStream::Stop();
  175. sent_mutations_ = {};
  176. open_ = false;
  177. SetHandshakeComplete(false);
  178. }
  179. bool IsStarted() const override {
  180. return open_;
  181. }
  182. bool IsOpen() const override {
  183. return open_;
  184. }
  185. void WriteHandshake() override {
  186. datastore_->IncrementWriteStreamRequests();
  187. SetHandshakeComplete();
  188. callback_->OnWriteStreamHandshakeComplete();
  189. }
  190. void WriteMutations(const std::vector<Mutation>& mutations) override {
  191. datastore_->IncrementWriteStreamRequests();
  192. sent_mutations_.push(mutations);
  193. }
  194. /** Injects a write ack as though it had come from the backend in response to a write. */
  195. void AckWrite(const SnapshotVersion& commitVersion, std::vector<MutationResult> results) {
  196. callback_->OnWriteStreamMutationResult(commitVersion, std::move(results));
  197. }
  198. /** Injects a failed write response as though it had come from the backend. */
  199. void FailStream(const Status& error) {
  200. open_ = false;
  201. callback_->OnWriteStreamClose(error);
  202. }
  203. /**
  204. * Returns the next write that was "sent to the backend", failing if there are no queued sent
  205. */
  206. std::vector<Mutation> NextSentWrite() {
  207. HARD_ASSERT(!sent_mutations_.empty(),
  208. "Writes need to happen before you can call NextSentWrite.");
  209. std::vector<Mutation> result = std::move(sent_mutations_.front());
  210. sent_mutations_.pop();
  211. return result;
  212. }
  213. /**
  214. * Returns the number of mutations that have been sent to the backend but not retrieved via
  215. * nextSentWrite yet.
  216. */
  217. int sent_mutations_count() const {
  218. return static_cast<int>(sent_mutations_.size());
  219. }
  220. private:
  221. bool open_ = false;
  222. std::queue<std::vector<Mutation>> sent_mutations_;
  223. MockDatastore* datastore_ = nullptr;
  224. WriteStreamCallback* callback_ = nullptr;
  225. };
  226. MockDatastore::MockDatastore(
  227. const core::DatabaseInfo& database_info,
  228. const std::shared_ptr<util::AsyncQueue>& worker_queue,
  229. std::shared_ptr<credentials::AuthCredentialsProvider> auth_credentials,
  230. std::shared_ptr<credentials::AppCheckCredentialsProvider> app_check_credentials,
  231. ConnectivityMonitor* connectivity_monitor,
  232. FirebaseMetadataProvider* firebase_metadata_provider)
  233. : Datastore{database_info, worker_queue, auth_credentials,
  234. app_check_credentials, connectivity_monitor, firebase_metadata_provider},
  235. database_info_{&database_info},
  236. worker_queue_{worker_queue},
  237. app_check_credentials_{app_check_credentials},
  238. auth_credentials_{auth_credentials} {
  239. }
  240. std::shared_ptr<WatchStream> MockDatastore::CreateWatchStream(WatchStreamCallback* callback) {
  241. watch_stream_ = std::make_shared<MockWatchStream>(
  242. worker_queue_, auth_credentials_, app_check_credentials_,
  243. Serializer{database_info_->database_id()}, grpc_connection(), callback, this);
  244. return watch_stream_;
  245. }
  246. std::shared_ptr<WriteStream> MockDatastore::CreateWriteStream(WriteStreamCallback* callback) {
  247. write_stream_ = std::make_shared<MockWriteStream>(
  248. worker_queue_, auth_credentials_, app_check_credentials_,
  249. Serializer{database_info_->database_id()}, grpc_connection(), callback, this);
  250. return write_stream_;
  251. }
  252. void MockDatastore::WriteWatchChange(const WatchChange& change, const SnapshotVersion& snap) {
  253. watch_stream_->WriteWatchChange(change, snap);
  254. }
  255. void MockDatastore::FailWatchStream(const Status& error) {
  256. watch_stream_->FailStream(error);
  257. }
  258. const std::unordered_map<TargetId, TargetData>& MockDatastore::ActiveTargets() const {
  259. return watch_stream_->ActiveTargets();
  260. }
  261. bool MockDatastore::IsWatchStreamOpen() const {
  262. return watch_stream_->IsOpen();
  263. }
  264. std::vector<Mutation> MockDatastore::NextSentWrite() {
  265. return write_stream_->NextSentWrite();
  266. }
  267. int MockDatastore::WritesSent() const {
  268. return write_stream_->sent_mutations_count();
  269. }
  270. void MockDatastore::AckWrite(const SnapshotVersion& version, std::vector<MutationResult> results) {
  271. write_stream_->AckWrite(version, std::move(results));
  272. }
  273. void MockDatastore::FailWrite(const Status& error) {
  274. write_stream_->FailStream(error);
  275. }
  276. } // namespace remote
  277. } // namespace firestore
  278. } // namespace firebase
  279. NS_ASSUME_NONNULL_END