tower/retry/
backoff.rs

//! This module contains generic [backoff] utilities to be used with the retry
//! layer.
//!
//! The [`Backoff`] trait is a generic way to represent backoffs that can use
//! any timer type.
//!
//! [`ExponentialBackoffMaker`] implements the maker type for
//! [`ExponentialBackoff`], which implements the [`Backoff`] trait and provides
//! a batteries-included exponential backoff and jitter strategy.
//!
//! [backoff]: https://en.wikipedia.org/wiki/Exponential_backoff
12
13use std::fmt::Display;
14use std::future::Future;
15use std::time::Duration;
16use tokio::time;
17
18use crate::util::rng::{HasherRng, Rng};
19
20/// Trait used to construct [`Backoff`] trait implementors.
21pub trait MakeBackoff {
22    /// The backoff type produced by this maker.
23    type Backoff: Backoff;
24
25    /// Constructs a new backoff type.
26    fn make_backoff(&mut self) -> Self::Backoff;
27}
28
/// A backoff trait where a single mutable reference represents a single
/// backoff session.
///
/// Each call to [`Backoff::next_backoff`] advances the session by one step.
/// This trait does not require [`Clone`], and cloning an implementor is not
/// guaranteed to reset it. To start a new session, construct a fresh value,
/// for example via [`MakeBackoff::make_backoff`].
pub trait Backoff {
    /// The future associated with each backoff. This usually will be some sort
    /// of timer.
    type Future: Future<Output = ()>;

    /// Initiate the next backoff in the sequence.
    fn next_backoff(&mut self) -> Self::Future;
}
40
41/// A maker type for [`ExponentialBackoff`].
42#[derive(Debug, Clone)]
43pub struct ExponentialBackoffMaker<R = HasherRng> {
44    /// The minimum amount of time to wait before resuming an operation.
45    min: time::Duration,
46    /// The maximum amount of time to wait before resuming an operation.
47    max: time::Duration,
48    /// The ratio of the base timeout that may be randomly added to a backoff.
49    ///
50    /// Must be greater than or equal to 0.0.
51    jitter: f64,
52    rng: R,
53}
54
55/// A jittered [exponential backoff] strategy.
56///
57/// The backoff duration will increase exponentially for every subsequent
58/// backoff, up to a maximum duration. A small amount of [random jitter] is
59/// added to each backoff duration, in order to avoid retry spikes.
60///
61/// [exponential backoff]: https://en.wikipedia.org/wiki/Exponential_backoff
62/// [random jitter]: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
63#[derive(Debug, Clone)]
64pub struct ExponentialBackoff<R = HasherRng> {
65    min: time::Duration,
66    max: time::Duration,
67    jitter: f64,
68    rng: R,
69    iterations: u32,
70}
71
72impl<R> ExponentialBackoffMaker<R>
73where
74    R: Rng,
75{
76    /// Create a new `ExponentialBackoff`.
77    ///
78    /// # Error
79    ///
80    /// Returns a config validation error if:
81    /// - `min` > `max`
82    /// - `max` == 0
83    /// - `jitter` < `0.0`
84    /// - `jitter` > `100.0`
85    /// - `jitter` is not finite
86    pub fn new(
87        min: time::Duration,
88        max: time::Duration,
89        jitter: f64,
90        rng: R,
91    ) -> Result<Self, InvalidBackoff> {
92        if min > max {
93            return Err(InvalidBackoff("maximum must not be less than minimum"));
94        }
95        if max == time::Duration::from_millis(0) {
96            return Err(InvalidBackoff("maximum must be non-zero"));
97        }
98        if jitter < 0.0 {
99            return Err(InvalidBackoff("jitter must not be negative"));
100        }
101        if jitter > 100.0 {
102            return Err(InvalidBackoff("jitter must not be greater than 100"));
103        }
104        if !jitter.is_finite() {
105            return Err(InvalidBackoff("jitter must be finite"));
106        }
107
108        Ok(ExponentialBackoffMaker {
109            min,
110            max,
111            jitter,
112            rng,
113        })
114    }
115}
116
117impl<R> MakeBackoff for ExponentialBackoffMaker<R>
118where
119    R: Rng + Clone,
120{
121    type Backoff = ExponentialBackoff<R>;
122
123    fn make_backoff(&mut self) -> Self::Backoff {
124        ExponentialBackoff {
125            max: self.max,
126            min: self.min,
127            jitter: self.jitter,
128            rng: self.rng.clone(),
129            iterations: 0,
130        }
131    }
132}
133
134impl<R: Rng> ExponentialBackoff<R> {
135    fn base(&self) -> time::Duration {
136        debug_assert!(
137            self.min <= self.max,
138            "maximum backoff must not be less than minimum backoff"
139        );
140        debug_assert!(
141            self.max > time::Duration::from_millis(0),
142            "Maximum backoff must be non-zero"
143        );
144        self.min
145            .checked_mul(2_u32.saturating_pow(self.iterations))
146            .unwrap_or(self.max)
147            .min(self.max)
148    }
149
150    /// Returns a random, uniform duration on `[0, base*self.jitter]` no greater
151    /// than `self.max`.
152    fn jitter(&mut self, base: time::Duration) -> time::Duration {
153        if self.jitter == 0.0 {
154            time::Duration::default()
155        } else {
156            let jitter_factor = self.rng.next_f64();
157            debug_assert!(
158                jitter_factor > 0.0,
159                "rng returns values between 0.0 and 1.0"
160            );
161            let rand_jitter = jitter_factor * self.jitter;
162            let secs = (base.as_secs() as f64) * rand_jitter;
163            let nanos = (base.subsec_nanos() as f64) * rand_jitter;
164            let remaining = self.max - base;
165            time::Duration::new(secs as u64, nanos as u32).min(remaining)
166        }
167    }
168}
169
170impl<R> Backoff for ExponentialBackoff<R>
171where
172    R: Rng,
173{
174    type Future = tokio::time::Sleep;
175
176    fn next_backoff(&mut self) -> Self::Future {
177        let base = self.base();
178        let next = base + self.jitter(base);
179
180        self.iterations += 1;
181
182        tokio::time::sleep(next)
183    }
184}
185
186impl Default for ExponentialBackoffMaker {
187    fn default() -> Self {
188        ExponentialBackoffMaker::new(
189            Duration::from_millis(50),
190            Duration::from_millis(u64::MAX),
191            0.99,
192            HasherRng::default(),
193        )
194        .expect("Unable to create ExponentialBackoff")
195    }
196}
197
/// Error returned by [`ExponentialBackoffMaker::new`] when the supplied
/// backoff configuration is rejected.
///
/// The wrapped message names the constraint that was violated.
#[derive(Debug)]
pub struct InvalidBackoff(&'static str);

impl Display for InvalidBackoff {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid backoff: ")?;
        f.write_str(self.0)
    }
}

impl std::error::Error for InvalidBackoff {}
209
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    quickcheck! {
        fn backoff_base_first(min_ms: u64, max_ms: u64) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            let rng = HasherRng::default();
            let mut maker = match ExponentialBackoffMaker::new(min, max, 0.0, rng) {
                Ok(maker) => maker,
                Err(_) => return TestResult::discard(),
            };

            // A fresh session has not backed off yet, so its base is `min`.
            let first = maker.make_backoff().base();
            TestResult::from_bool(first == min)
        }

        fn backoff_base(min_ms: u64, max_ms: u64, iterations: u32) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            let rng = HasherRng::default();
            let mut maker = match ExponentialBackoffMaker::new(min, max, 0.0, rng) {
                Ok(maker) => maker,
                Err(_) => return TestResult::discard(),
            };
            let mut session = maker.make_backoff();

            // Jump straight to an arbitrary attempt; the base must stay in [min, max].
            session.iterations = iterations;
            let delay = session.base();
            TestResult::from_bool((min..=max).contains(&delay))
        }

        fn backoff_jitter(base_ms: u64, max_ms: u64, jitter: f64) -> TestResult {
            let base = time::Duration::from_millis(base_ms);
            let max = time::Duration::from_millis(max_ms);
            let rng = HasherRng::default();
            let mut maker = match ExponentialBackoffMaker::new(base, max, jitter, rng) {
                Ok(maker) => maker,
                Err(_) => return TestResult::discard(),
            };
            let mut session = maker.make_backoff();

            let j = session.jitter(base);
            let zero = time::Duration::default();
            // No jitter is possible with a zero ratio, a zero base, or no headroom.
            let expect_zero = jitter == 0.0 || base_ms == 0 || max_ms == base_ms;
            TestResult::from_bool(if expect_zero { j == zero } else { j > zero })
        }
    }
}