1pub mod error;
4pub mod future;
5pub mod spawn;
6
7pub use spawn::Spawn;
8
9use crate::mock::{error::Error, future::ResponseFuture};
10use core::task::Waker;
11
12use tokio::sync::{mpsc, oneshot};
13use tower_layer::Layer;
14use tower_service::Service;
15
16use std::{
17 collections::HashMap,
18 future::Future,
19 sync::{Arc, Mutex},
20 task::{Context, Poll},
21};
22
23pub fn spawn_layer<T, U, L>(layer: L) -> (Spawn<L::Service>, Handle<T, U>)
25where
26 L: Layer<Mock<T, U>>,
27{
28 let (inner, handle) = pair();
29 let svc = layer.layer(inner);
30
31 (Spawn::new(svc), handle)
32}
33
34pub fn spawn<T, U>() -> (Spawn<Mock<T, U>>, Handle<T, U>) {
36 let (svc, handle) = pair();
37
38 (Spawn::new(svc), handle)
39}
40
41pub fn spawn_with<T, U, F, S>(f: F) -> (Spawn<S>, Handle<T, U>)
43where
44 F: Fn(Mock<T, U>) -> S,
45{
46 let (svc, handle) = pair();
47
48 let svc = f(svc);
49
50 (Spawn::new(svc), handle)
51}
52
/// A mock `Service` whose requests are forwarded to a paired [`Handle`].
///
/// Clones share the same request channel and shared state. Each clone gets its
/// own `id` so that a waker parked in `poll_ready` can be tracked per clone.
#[derive(Debug)]
pub struct Mock<T, U> {
    /// Key of this clone in the shared waker table (0 for the instance built by `pair`).
    id: u64,
    /// Sender half of the request channel.
    /// NOTE(review): presumably wrapped in a `Mutex` so `Mock` is `Sync`; confirm.
    tx: Mutex<Tx<T, U>>,
    /// State shared with the `Handle` and with every other clone.
    state: Arc<Mutex<State>>,
    /// Set when `poll_ready` grants readiness; cleared again by `call`.
    can_send: bool,
}
61
/// Test-side controller for a [`Mock`].
///
/// It receives the mock's requests and controls its readiness (request budget,
/// injected errors). Dropping it marks the mock as closed.
#[derive(Debug)]
pub struct Handle<T, U> {
    /// Receiver half of the request channel fed by every `Mock` clone.
    rx: Rx<T, U>,
    /// State shared with all `Mock` clones.
    state: Arc<Mutex<State>>,
}
68
/// A request received by the [`Handle`], paired with the responder used to answer it.
type Request<T, U> = (T, SendResponse<U>);
70
/// Responder for a single request: completes the caller's response future
/// with either a value or an error.
#[derive(Debug)]
pub struct SendResponse<T> {
    /// One-shot channel whose receiver is held by the request's response future.
    tx: oneshot::Sender<Result<T, Error>>,
}
76
/// State shared between a [`Handle`] and all of its [`Mock`] clones.
#[derive(Debug)]
struct State {
    /// How many more requests `poll_ready` may admit; overwritten by
    /// `Handle::allow` and decremented by `call`.
    rem: u64,

    /// Wakers of `Mock` clones (keyed by clone id) that returned `Pending`
    /// from `poll_ready` because `rem` was zero.
    tasks: HashMap<u64, Waker>,

    /// Set once the `Handle` is dropped; the mock then reports `Closed`.
    is_closed: bool,

    /// Id handed to the next `Mock` clone.
    next_clone_id: u64,

    /// Error to be returned (and consumed) by the next `poll_ready`, set by
    /// `Handle::send_error`.
    err_with: Option<Error>,
}
94
/// Sending half of the request channel, held (behind a mutex) by every `Mock`.
type Tx<T, U> = mpsc::UnboundedSender<Request<T, U>>;
/// Receiving half of the request channel, held by the `Handle`.
type Rx<T, U> = mpsc::UnboundedReceiver<Request<T, U>>;
97
98pub fn pair<T, U>() -> (Mock<T, U>, Handle<T, U>) {
100 let (tx, rx) = mpsc::unbounded_channel();
101 let tx = Mutex::new(tx);
102
103 let state = Arc::new(Mutex::new(State::new()));
104
105 let mock = Mock {
106 id: 0,
107 tx,
108 state: state.clone(),
109 can_send: false,
110 };
111
112 let handle = Handle { rx, state };
113
114 (mock, handle)
115}
116
117impl<T, U> Service<T> for Mock<T, U> {
118 type Response = U;
119 type Error = Error;
120 type Future = ResponseFuture<U>;
121
122 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
123 let mut state = self.state.lock().unwrap();
124
125 if state.is_closed {
126 return Poll::Ready(Err(error::Closed::new().into()));
127 }
128
129 if let Some(e) = state.err_with.take() {
130 return Poll::Ready(Err(e));
131 }
132
133 if self.can_send {
134 return Poll::Ready(Ok(()));
135 }
136
137 if state.rem > 0 {
138 assert!(!state.tasks.contains_key(&self.id));
139
140 self.can_send = true;
142
143 Poll::Ready(Ok(()))
144 } else {
145 *state
147 .tasks
148 .entry(self.id)
149 .or_insert_with(|| cx.waker().clone()) = cx.waker().clone();
150
151 Poll::Pending
152 }
153 }
154
155 fn call(&mut self, request: T) -> Self::Future {
156 let mut state = self.state.lock().unwrap();
158
159 if state.is_closed {
160 return ResponseFuture::closed();
161 }
162
163 if !self.can_send {
164 panic!("service not ready; poll_ready must be called first");
165 }
166
167 self.can_send = false;
168
169 if state.rem > 0 {
171 state.rem -= 1;
172 }
173
174 let (tx, rx) = oneshot::channel();
175 let send_response = SendResponse { tx };
176
177 match self.tx.lock().unwrap().send((request, send_response)) {
178 Ok(_) => {}
179 Err(_) => {
180 return ResponseFuture::closed();
182 }
183 }
184
185 ResponseFuture::new(rx)
186 }
187}
188
189impl<T, U> Clone for Mock<T, U> {
190 fn clone(&self) -> Self {
191 let id = {
192 let mut state = self.state.lock().unwrap();
193 let id = state.next_clone_id;
194
195 state.next_clone_id += 1;
196
197 id
198 };
199
200 let tx = Mutex::new(self.tx.lock().unwrap().clone());
201
202 Mock {
203 id,
204 tx,
205 state: self.state.clone(),
206 can_send: false,
207 }
208 }
209}
210
211impl<T, U> Drop for Mock<T, U> {
212 fn drop(&mut self) {
213 let mut state = match self.state.lock() {
214 Ok(v) => v,
215 Err(e) => {
216 if ::std::thread::panicking() {
217 return;
218 }
219
220 panic!("{:?}", e);
221 }
222 };
223
224 state.tasks.remove(&self.id);
225 }
226}
227
228impl<T, U> Handle<T, U> {
231 pub fn poll_request(&mut self) -> Poll<Option<Request<T, U>>> {
233 tokio_test::task::spawn(()).enter(|cx, _| Box::pin(self.rx.recv()).as_mut().poll(cx))
234 }
235
236 pub async fn next_request(&mut self) -> Option<Request<T, U>> {
238 self.rx.recv().await
239 }
240
241 pub fn allow(&mut self, num: u64) {
243 let mut state = self.state.lock().unwrap();
244 state.rem = num;
245
246 if num > 0 {
247 for (_, task) in state.tasks.drain() {
248 task.wake();
249 }
250 }
251 }
252
253 pub fn send_error<E: Into<Error>>(&mut self, e: E) {
255 let mut state = self.state.lock().unwrap();
256 state.err_with = Some(e.into());
257
258 for (_, task) in state.tasks.drain() {
259 task.wake();
260 }
261 }
262}
263
264impl<T, U> Drop for Handle<T, U> {
265 fn drop(&mut self) {
266 let mut state = match self.state.lock() {
267 Ok(v) => v,
268 Err(e) => {
269 if ::std::thread::panicking() {
270 return;
271 }
272
273 panic!("{:?}", e);
274 }
275 };
276
277 state.is_closed = true;
278
279 for (_, task) in state.tasks.drain() {
280 task.wake();
281 }
282 }
283}
284
285impl<T> SendResponse<T> {
288 pub fn send_response(self, response: T) {
290 let _ = self.tx.send(Ok(response));
292 }
293
294 pub fn send_error<E: Into<Error>>(self, err: E) {
296 let _ = self.tx.send(Err(err.into()));
298 }
299}
300
301impl State {
304 fn new() -> State {
305 State {
306 rem: u64::MAX,
307 tasks: HashMap::new(),
308 is_closed: false,
309 next_clone_id: 1,
310 err_with: None,
311 }
312 }
313}