1use std::fmt::Display;
14use std::future::Future;
15use std::time::Duration;
16use tokio::time;
17
18use crate::util::rng::{HasherRng, Rng};
19
20pub trait MakeBackoff {
22 type Backoff: Backoff;
24
25 fn make_backoff(&mut self) -> Self::Backoff;
27}
28
/// A source of delays to wait for between retry attempts.
pub trait Backoff {
    /// Returns a future representing the next delay.
    ///
    /// Takes `&mut self`, so an implementation may advance internal state
    /// (for example, an attempt counter) on every call.
    fn next_backoff(&mut self) -> Self::Future;

    /// The future returned by `next_backoff`.
    ///
    /// It produces no value (`Output = ()`); only its completion matters.
    type Future: Future<Output = ()>;
}
40
41#[derive(Debug, Clone)]
43pub struct ExponentialBackoffMaker<R = HasherRng> {
44 min: time::Duration,
46 max: time::Duration,
48 jitter: f64,
52 rng: R,
53}
54
55#[derive(Debug, Clone)]
64pub struct ExponentialBackoff<R = HasherRng> {
65 min: time::Duration,
66 max: time::Duration,
67 jitter: f64,
68 rng: R,
69 iterations: u32,
70}
71
72impl<R> ExponentialBackoffMaker<R>
73where
74 R: Rng,
75{
76 pub fn new(
87 min: time::Duration,
88 max: time::Duration,
89 jitter: f64,
90 rng: R,
91 ) -> Result<Self, InvalidBackoff> {
92 if min > max {
93 return Err(InvalidBackoff("maximum must not be less than minimum"));
94 }
95 if max == time::Duration::from_millis(0) {
96 return Err(InvalidBackoff("maximum must be non-zero"));
97 }
98 if jitter < 0.0 {
99 return Err(InvalidBackoff("jitter must not be negative"));
100 }
101 if jitter > 100.0 {
102 return Err(InvalidBackoff("jitter must not be greater than 100"));
103 }
104 if !jitter.is_finite() {
105 return Err(InvalidBackoff("jitter must be finite"));
106 }
107
108 Ok(ExponentialBackoffMaker {
109 min,
110 max,
111 jitter,
112 rng,
113 })
114 }
115}
116
117impl<R> MakeBackoff for ExponentialBackoffMaker<R>
118where
119 R: Rng + Clone,
120{
121 type Backoff = ExponentialBackoff<R>;
122
123 fn make_backoff(&mut self) -> Self::Backoff {
124 ExponentialBackoff {
125 max: self.max,
126 min: self.min,
127 jitter: self.jitter,
128 rng: self.rng.clone(),
129 iterations: 0,
130 }
131 }
132}
133
134impl<R: Rng> ExponentialBackoff<R> {
135 fn base(&self) -> time::Duration {
136 debug_assert!(
137 self.min <= self.max,
138 "maximum backoff must not be less than minimum backoff"
139 );
140 debug_assert!(
141 self.max > time::Duration::from_millis(0),
142 "Maximum backoff must be non-zero"
143 );
144 self.min
145 .checked_mul(2_u32.saturating_pow(self.iterations))
146 .unwrap_or(self.max)
147 .min(self.max)
148 }
149
150 fn jitter(&mut self, base: time::Duration) -> time::Duration {
153 if self.jitter == 0.0 {
154 time::Duration::default()
155 } else {
156 let jitter_factor = self.rng.next_f64();
157 debug_assert!(
158 jitter_factor > 0.0,
159 "rng returns values between 0.0 and 1.0"
160 );
161 let rand_jitter = jitter_factor * self.jitter;
162 let secs = (base.as_secs() as f64) * rand_jitter;
163 let nanos = (base.subsec_nanos() as f64) * rand_jitter;
164 let remaining = self.max - base;
165 time::Duration::new(secs as u64, nanos as u32).min(remaining)
166 }
167 }
168}
169
170impl<R> Backoff for ExponentialBackoff<R>
171where
172 R: Rng,
173{
174 type Future = tokio::time::Sleep;
175
176 fn next_backoff(&mut self) -> Self::Future {
177 let base = self.base();
178 let next = base + self.jitter(base);
179
180 self.iterations += 1;
181
182 tokio::time::sleep(next)
183 }
184}
185
186impl Default for ExponentialBackoffMaker {
187 fn default() -> Self {
188 ExponentialBackoffMaker::new(
189 Duration::from_millis(50),
190 Duration::from_millis(u64::MAX),
191 0.99,
192 HasherRng::default(),
193 )
194 .expect("Unable to create ExponentialBackoff")
195 }
196}
197
/// Error describing why a set of backoff parameters was rejected (for example
/// by `ExponentialBackoffMaker::new`).
///
/// It wraps a static description of the failed check. `Display` renders it as
/// `invalid backoff: <description>`.
#[derive(Debug)]
pub struct InvalidBackoff(&'static str);

impl Display for InvalidBackoff {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid backoff: ")?;
        f.write_str(self.0)
    }
}

impl std::error::Error for InvalidBackoff {}
209
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    /// Builds a backoff seeded with `HasherRng::default()`, or `None` when
    /// `ExponentialBackoffMaker::new` rejects the parameters.
    fn build(
        min: time::Duration,
        max: time::Duration,
        jitter: f64,
    ) -> Option<ExponentialBackoff> {
        ExponentialBackoffMaker::new(min, max, jitter, HasherRng::default())
            .ok()
            .map(|mut maker| maker.make_backoff())
    }

    quickcheck! {
        // Before any iteration, the base delay is exactly `min`.
        fn backoff_base_first(min_ms: u64, max_ms: u64) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            match build(min, max, 0.0) {
                Some(backoff) => TestResult::from_bool(backoff.base() == min),
                None => TestResult::discard(),
            }
        }

        // For any iteration count, the base delay stays within `[min, max]`.
        fn backoff_base(min_ms: u64, max_ms: u64, iterations: u32) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            match build(min, max, 0.0) {
                Some(mut backoff) => {
                    backoff.iterations = iterations;
                    let delay = backoff.base();
                    TestResult::from_bool(min <= delay && delay <= max)
                }
                None => TestResult::discard(),
            }
        }

        // Jitter is zero when disabled, when the base is zero, or when there is
        // no room left below `max`; otherwise it is strictly positive.
        fn backoff_jitter(base_ms: u64, max_ms: u64, jitter: f64) -> TestResult {
            let base = time::Duration::from_millis(base_ms);
            let max = time::Duration::from_millis(max_ms);
            let mut backoff = match build(base, max, jitter) {
                Some(backoff) => backoff,
                None => return TestResult::discard(),
            };

            let extra = backoff.jitter(base);
            let zero = time::Duration::default();
            let expect_zero = jitter == 0.0 || base_ms == 0 || max_ms == base_ms;
            TestResult::from_bool(if expect_zero { extra == zero } else { extra > zero })
        }
    }
}