tower_test/mock/
mod.rs

1//! Mock `Service` that can be used in tests.
2
3pub mod error;
4pub mod future;
5pub mod spawn;
6
7pub use spawn::Spawn;
8
9use crate::mock::{error::Error, future::ResponseFuture};
10use core::task::Waker;
11
12use tokio::sync::{mpsc, oneshot};
13use tower_layer::Layer;
14use tower_service::Service;
15
16use std::{
17    collections::HashMap,
18    future::Future,
19    sync::{Arc, Mutex},
20    task::{Context, Poll},
21};
22
23/// Spawn a layer onto a mock service.
24pub fn spawn_layer<T, U, L>(layer: L) -> (Spawn<L::Service>, Handle<T, U>)
25where
26    L: Layer<Mock<T, U>>,
27{
28    let (inner, handle) = pair();
29    let svc = layer.layer(inner);
30
31    (Spawn::new(svc), handle)
32}
33
34/// Spawn a Service onto a mock task.
35pub fn spawn<T, U>() -> (Spawn<Mock<T, U>>, Handle<T, U>) {
36    let (svc, handle) = pair();
37
38    (Spawn::new(svc), handle)
39}
40
41/// Spawn a Service via the provided wrapper closure.
42pub fn spawn_with<T, U, F, S>(f: F) -> (Spawn<S>, Handle<T, U>)
43where
44    F: Fn(Mock<T, U>) -> S,
45{
46    let (svc, handle) = pair();
47
48    let svc = f(svc);
49
50    (Spawn::new(svc), handle)
51}
52
53/// A mock service
54#[derive(Debug)]
55pub struct Mock<T, U> {
56    id: u64,
57    tx: Mutex<Tx<T, U>>,
58    state: Arc<Mutex<State>>,
59    can_send: bool,
60}
61
62/// Handle to the `Mock`.
63#[derive(Debug)]
64pub struct Handle<T, U> {
65    rx: Rx<T, U>,
66    state: Arc<Mutex<State>>,
67}
68
69type Request<T, U> = (T, SendResponse<U>);
70
71/// Send a response in reply to a received request.
72#[derive(Debug)]
73pub struct SendResponse<T> {
74    tx: oneshot::Sender<Result<T, Error>>,
75}
76
77#[derive(Debug)]
78struct State {
79    /// Tracks the number of requests that can be sent through
80    rem: u64,
81
82    /// Tasks that are blocked
83    tasks: HashMap<u64, Waker>,
84
85    /// Tracks if the `Handle` dropped
86    is_closed: bool,
87
88    /// Tracks the ID for the next mock clone
89    next_clone_id: u64,
90
91    /// Tracks the next error to yield (if any)
92    err_with: Option<Error>,
93}
94
95type Tx<T, U> = mpsc::UnboundedSender<Request<T, U>>;
96type Rx<T, U> = mpsc::UnboundedReceiver<Request<T, U>>;
97
98/// Create a new `Mock` and `Handle` pair.
99pub fn pair<T, U>() -> (Mock<T, U>, Handle<T, U>) {
100    let (tx, rx) = mpsc::unbounded_channel();
101    let tx = Mutex::new(tx);
102
103    let state = Arc::new(Mutex::new(State::new()));
104
105    let mock = Mock {
106        id: 0,
107        tx,
108        state: state.clone(),
109        can_send: false,
110    };
111
112    let handle = Handle { rx, state };
113
114    (mock, handle)
115}
116
117impl<T, U> Service<T> for Mock<T, U> {
118    type Response = U;
119    type Error = Error;
120    type Future = ResponseFuture<U>;
121
122    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
123        let mut state = self.state.lock().unwrap();
124
125        if state.is_closed {
126            return Poll::Ready(Err(error::Closed::new().into()));
127        }
128
129        if let Some(e) = state.err_with.take() {
130            return Poll::Ready(Err(e));
131        }
132
133        if self.can_send {
134            return Poll::Ready(Ok(()));
135        }
136
137        if state.rem > 0 {
138            assert!(!state.tasks.contains_key(&self.id));
139
140            // Returning `Ready` means the next call to `call` must succeed.
141            self.can_send = true;
142
143            Poll::Ready(Ok(()))
144        } else {
145            // Bit weird... but whatevz
146            *state
147                .tasks
148                .entry(self.id)
149                .or_insert_with(|| cx.waker().clone()) = cx.waker().clone();
150
151            Poll::Pending
152        }
153    }
154
155    fn call(&mut self, request: T) -> Self::Future {
156        // Make sure that the service has capacity
157        let mut state = self.state.lock().unwrap();
158
159        if state.is_closed {
160            return ResponseFuture::closed();
161        }
162
163        if !self.can_send {
164            panic!("service not ready; poll_ready must be called first");
165        }
166
167        self.can_send = false;
168
169        // Decrement the number of remaining requests that can be sent
170        if state.rem > 0 {
171            state.rem -= 1;
172        }
173
174        let (tx, rx) = oneshot::channel();
175        let send_response = SendResponse { tx };
176
177        match self.tx.lock().unwrap().send((request, send_response)) {
178            Ok(_) => {}
179            Err(_) => {
180                // TODO: Can this be reached
181                return ResponseFuture::closed();
182            }
183        }
184
185        ResponseFuture::new(rx)
186    }
187}
188
189impl<T, U> Clone for Mock<T, U> {
190    fn clone(&self) -> Self {
191        let id = {
192            let mut state = self.state.lock().unwrap();
193            let id = state.next_clone_id;
194
195            state.next_clone_id += 1;
196
197            id
198        };
199
200        let tx = Mutex::new(self.tx.lock().unwrap().clone());
201
202        Mock {
203            id,
204            tx,
205            state: self.state.clone(),
206            can_send: false,
207        }
208    }
209}
210
211impl<T, U> Drop for Mock<T, U> {
212    fn drop(&mut self) {
213        let mut state = match self.state.lock() {
214            Ok(v) => v,
215            Err(e) => {
216                if ::std::thread::panicking() {
217                    return;
218                }
219
220                panic!("{:?}", e);
221            }
222        };
223
224        state.tasks.remove(&self.id);
225    }
226}
227
228// ===== impl Handle =====
229
230impl<T, U> Handle<T, U> {
231    /// Asynchronously gets the next request
232    pub fn poll_request(&mut self) -> Poll<Option<Request<T, U>>> {
233        tokio_test::task::spawn(()).enter(|cx, _| Box::pin(self.rx.recv()).as_mut().poll(cx))
234    }
235
236    /// Gets the next request.
237    pub async fn next_request(&mut self) -> Option<Request<T, U>> {
238        self.rx.recv().await
239    }
240
241    /// Allow a certain number of requests
242    pub fn allow(&mut self, num: u64) {
243        let mut state = self.state.lock().unwrap();
244        state.rem = num;
245
246        if num > 0 {
247            for (_, task) in state.tasks.drain() {
248                task.wake();
249            }
250        }
251    }
252
253    /// Make the next poll_ method error with the given error.
254    pub fn send_error<E: Into<Error>>(&mut self, e: E) {
255        let mut state = self.state.lock().unwrap();
256        state.err_with = Some(e.into());
257
258        for (_, task) in state.tasks.drain() {
259            task.wake();
260        }
261    }
262}
263
264impl<T, U> Drop for Handle<T, U> {
265    fn drop(&mut self) {
266        let mut state = match self.state.lock() {
267            Ok(v) => v,
268            Err(e) => {
269                if ::std::thread::panicking() {
270                    return;
271                }
272
273                panic!("{:?}", e);
274            }
275        };
276
277        state.is_closed = true;
278
279        for (_, task) in state.tasks.drain() {
280            task.wake();
281        }
282    }
283}
284
285// ===== impl SendResponse =====
286
287impl<T> SendResponse<T> {
288    /// Resolve the pending request future for the linked request with the given response.
289    pub fn send_response(self, response: T) {
290        // TODO: Should the result be dropped?
291        let _ = self.tx.send(Ok(response));
292    }
293
294    /// Resolve the pending request future for the linked request with the given error.
295    pub fn send_error<E: Into<Error>>(self, err: E) {
296        // TODO: Should the result be dropped?
297        let _ = self.tx.send(Err(err.into()));
298    }
299}
300
301// ===== impl State =====
302
303impl State {
304    fn new() -> State {
305        State {
306            rem: u64::MAX,
307            tasks: HashMap::new(),
308            is_closed: false,
309            next_clone_id: 1,
310            err_with: None,
311        }
312    }
313}