use crate::body::{Body, HttpBody};
use crate::error::Error;
use crate::status::infer_grpc_status;
use crate::Status;
use bytes::{Buf, BufMut, Bytes, BytesMut, IntoBuf};
use futures::{try_ready, Async, Poll, Stream};
use http::{HeaderMap, StatusCode};
use log::{debug, trace, warn};
use std::collections::VecDeque;
use std::fmt;
/// Buffer type obtained from `Bytes::into_buf`; this is the data chunk type
/// yielded by the [`Encode`] body.
type BytesBuf = <Bytes as IntoBuf>::Buf;
/// Pairs an [`Encoder`] for outgoing messages with a [`Decoder`] for
/// incoming messages.
pub trait Codec {
/// Type of message this codec encodes.
type Encode;
/// Encoder producing the payload bytes for `Self::Encode` messages.
type Encoder: Encoder<Item = Self::Encode>;
/// Type of message this codec decodes.
type Decode;
/// Decoder turning payload bytes into `Self::Decode` messages.
type Decoder: Decoder<Item = Self::Decode>;
/// Returns an encoder for outgoing messages.
fn encoder(&mut self) -> Self::Encoder;
/// Returns a decoder for incoming messages.
fn decoder(&mut self) -> Self::Decoder;
}
/// Serializes one message into an [`EncodeBuf`].
pub trait Encoder {
/// Message type accepted by `encode`.
type Item;
/// Content type associated with this encoder.
// NOTE(review): not read in this chunk; presumably used for the HTTP
// `content-type` header elsewhere — confirm.
const CONTENT_TYPE: &'static str;
/// Writes the payload of `item` into `buf`.
///
/// Only the payload is written here; the 5-byte gRPC frame header is
/// added around it by `EncodeInner::poll_encode`.
fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Status>;
}
/// Deserializes one message from a [`DecodeBuf`].
pub trait Decoder {
/// Message type produced by `decode`.
type Item;
/// Decodes one message. `buf` exposes exactly the bytes of that message;
/// any bytes left unread are skipped (with a warning) when `buf` is dropped.
fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Self::Item, Status>;
}
/// HTTP body that encodes each item of a message stream into a
/// length-prefixed gRPC frame.
///
/// How encoding failures surface depends on `role`: a client body returns
/// the error from `poll_data`, while a server body ends its data and reports
/// the status in its trailers.
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct Encode<T, U> {
/// Encoding state: active stream, empty body, or a recorded error.
inner: EncodeInner<T, U>,
/// Scratch buffer each frame is written into before being split off.
buf: BytesMut,
/// Whether this body is sent by a client or by a server.
role: Role,
}
/// Encoding state of an [`Encode`] body.
#[derive(Debug)]
enum EncodeInner<T, U> {
/// Still producing frames: items come from `inner` and are serialized
/// with `encoder`.
Ok {
encoder: T,
inner: U,
},
/// Body with no messages: reported as end-of-stream and sends no trailers.
Empty,
/// A server-side encoding failure; this status is sent in the trailers.
Err(Status),
}
/// Which side of the call an [`Encode`] body is sent from.
#[derive(Debug)]
enum Role {
/// Request body: encode errors are returned from `poll_data` and no
/// trailers are produced.
Client,
/// Response body: encode errors end the data and the final status is
/// delivered through trailers.
Server,
}
/// A `Stream` of messages decoded from the gRPC frames of an HTTP body.
#[must_use = "futures do nothing unless polled"]
pub struct Streaming<T, B: Body> {
/// Decodes each frame payload into a message.
decoder: T,
/// Underlying body supplying data chunks and trailers.
inner: B,
/// Received data chunks not yet consumed by decoding.
bufs: BufList<B::Data>,
/// Position within the current frame.
state: State,
/// Controls whether trailers are inspected once the data ends.
direction: Direction,
}
/// Kind of body a [`Streaming`] decodes; controls end-of-stream handling.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Direction {
/// A request body; trailers are not inspected.
Request,
/// A response body with this HTTP status. After the data ends, the
/// trailers and this status are passed to `infer_grpc_status`.
Response(StatusCode),
/// A response whose trailers are not inspected by `Streaming::poll`.
// NOTE(review): presumably used for responses known to carry no
// messages — confirm at the construction site.
EmptyResponse,
}
/// Decoder progress through the current gRPC frame.
#[derive(Debug)]
enum State {
/// Waiting for the 5-byte frame header (1-byte compression flag followed
/// by a big-endian u32 payload length).
ReadHeader,
/// Header consumed; waiting for `len` payload bytes. `compression` is
/// currently always `false` because compressed frames are rejected.
ReadBody { compression: bool, len: usize },
/// The body's data has ended cleanly.
Done,
}
/// Writable view of the [`Encode`] frame buffer passed to
/// [`Encoder::encode`].
#[derive(Debug)]
pub struct EncodeBuf<'a> {
bytes: &'a mut BytesMut,
}
/// Read view passed to [`Decoder::decode`], limited to the bytes of a single
/// message. Bytes still unread when it is dropped are skipped with a warning.
pub struct DecodeBuf<'a> {
/// The full buffered input; it may extend past the current message.
bufs: &'a mut dyn Buf,
/// Bytes of the current message not yet consumed.
len: usize,
}
/// A queue of received data chunks that is read as a single `Buf`.
#[derive(Debug)]
pub struct BufList<B> {
/// Chunks in arrival order; the front chunk is read first.
bufs: VecDeque<B>,
}
impl<T, U> Encode<T, U>
where
    T: Encoder<Item = U::Item>,
    U: Stream,
    U::Error: Into<Error>,
{
    /// Shared constructor: pairs `encoder` with the message stream `inner`
    /// and starts with an empty framing buffer.
    fn new(encoder: T, inner: U, role: Role) -> Self {
        let state = EncodeInner::Ok { encoder, inner };
        Self {
            inner: state,
            buf: BytesMut::new(),
            role,
        }
    }

    /// Builds a request body. Encoding errors are returned from `poll_data`.
    pub(crate) fn request(encoder: T, inner: U) -> Self {
        Self::new(encoder, inner, Role::Client)
    }

    /// Builds a response body. Encoding errors are reported through trailers.
    pub(crate) fn response(encoder: T, inner: U) -> Self {
        Self::new(encoder, inner, Role::Server)
    }

    /// Builds a server body that carries no messages and no trailers.
    pub(crate) fn empty() -> Self {
        Self {
            role: Role::Server,
            buf: BytesMut::new(),
            inner: EncodeInner::Empty,
        }
    }
}
impl<T, U> HttpBody for Encode<T, U>
where
    T: Encoder<Item = U::Item>,
    U: Stream,
    U::Error: Into<Error>,
{
    type Data = BytesBuf;
    type Error = Status;

    /// Only an `Empty` body is known up front to have nothing to send.
    fn is_end_stream(&self) -> bool {
        match self.inner {
            EncodeInner::Empty => true,
            EncodeInner::Ok { .. } | EncodeInner::Err(_) => false,
        }
    }

    /// Yields the next encoded frame.
    ///
    /// A client propagates encoding errors directly. A server records the
    /// error so it can be sent in the trailers, and ends the data stream.
    fn poll_data(&mut self) -> Poll<Option<Self::Data>, Status> {
        let status = match self.inner.poll_encode(&mut self.buf) {
            Err(status) => status,
            polled => return polled,
        };

        match self.role {
            Role::Client => Err(status),
            Role::Server => {
                self.inner = EncodeInner::Err(status);
                Ok(Async::Ready(None))
            }
        }
    }

    /// Produces the status trailers for a server body.
    ///
    /// Clients and empty bodies send no trailers. A server reports `Ok` after
    /// a normal stream, or the recorded status after a failure.
    fn poll_trailers(&mut self) -> Poll<Option<HeaderMap>, Status> {
        let trailers = match (&self.role, &self.inner) {
            (Role::Client, _) => return Ok(Async::Ready(None)),
            (Role::Server, EncodeInner::Empty) => return Ok(Async::Ready(None)),
            (Role::Server, EncodeInner::Ok { .. }) => {
                Status::new(crate::Code::Ok, "").to_header_map()?
            }
            (Role::Server, EncodeInner::Err(status)) => status.to_header_map()?,
        };
        Ok(Async::Ready(Some(trailers)))
    }
}
impl<T, U> EncodeInner<T, U>
where
T: Encoder<Item = U::Item>,
U: Stream,
U::Error: Into<Error>,
{
fn poll_encode(&mut self, buf: &mut BytesMut) -> Poll<Option<BytesBuf>, Status> {
match self {
EncodeInner::Ok {
ref mut inner,
ref mut encoder,
} => {
let item = try_ready!(inner.poll().map_err(|err| {
let err = err.into();
debug!("encoder inner stream error: {:?}", err);
Status::from_error(&*err)
}));
let item = if let Some(item) = item {
buf.reserve(5);
unsafe {
buf.advance_mut(5);
}
encoder.encode(item, &mut EncodeBuf { bytes: buf })?;
let len = buf.len() - 5;
assert!(len <= ::std::u32::MAX as usize);
{
let mut cursor = ::std::io::Cursor::new(&mut buf[..5]);
cursor.put_u8(0);
cursor.put_u32_be(len as u32);
}
Some(buf.split_to(len + 5).freeze().into_buf())
} else {
None
};
return Ok(Async::Ready(item));
}
_ => return Ok(Async::Ready(None)),
}
}
}
impl<T, U> Streaming<T, U>
where
    T: Decoder,
    U: Body,
{
    /// Creates a decoder over `inner` that starts by expecting a frame header.
    pub(crate) fn new(decoder: T, inner: U, direction: Direction) -> Self {
        let bufs = BufList {
            bufs: VecDeque::new(),
        };
        Streaming {
            decoder,
            inner,
            bufs,
            state: State::ReadHeader,
            direction,
        }
    }

    /// Tries to decode one complete message from the buffered bytes.
    ///
    /// Returns `Ok(None)` when more data is needed. Rejects compressed frames
    /// (`Unimplemented`) and unknown compression flags (`Internal`).
    fn decode(&mut self) -> Result<Option<T::Item>, crate::Status> {
        if let State::ReadHeader = self.state {
            if self.bufs.remaining() < 5 {
                return Ok(None);
            }

            let compression = match self.bufs.get_u8() {
                0 => false,
                1 => {
                    trace!("message compressed, compression not supported yet");
                    return Err(crate::Status::new(
                        crate::Code::Unimplemented,
                        "Message compressed, compression not supported yet.".to_string(),
                    ));
                }
                flag => {
                    trace!("unexpected compression flag");
                    return Err(crate::Status::new(
                        crate::Code::Internal,
                        format!("Unexpected compression flag: {}", flag),
                    ));
                }
            };

            let len = self.bufs.get_u32_be() as usize;
            self.state = State::ReadBody { compression, len };
        }

        let len = match self.state {
            State::ReadBody { len, .. } => len,
            _ => return Ok(None),
        };

        if self.bufs.remaining() < len {
            return Ok(None);
        }

        // The DecodeBuf limits the decoder to this message's bytes and skips
        // any it leaves unread when it is dropped.
        let message = self.decoder.decode(&mut DecodeBuf {
            bufs: &mut self.bufs,
            len,
        })?;
        self.state = State::ReadHeader;
        Ok(Some(message))
    }
}
impl<T, U> Stream for Streaming<T, U>
where
    T: Decoder,
    U: Body,
{
    type Item = T::Item;
    type Error = Status;

    /// Yields the next decoded message, reading more body data as needed.
    ///
    /// When the body's data ends, it must end on a frame boundary. Leftover
    /// buffered bytes, or a consumed frame header whose payload never
    /// arrived, produce an `Internal` "Unexpected EOF" error. For
    /// `Direction::Response`, the trailers and HTTP status are then passed to
    /// `infer_grpc_status`, and a non-OK result is returned as the error.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        loop {
            if let State::Done = self.state {
                break;
            }

            if let Some(message) = self.decode()? {
                return Ok(Async::Ready(Some(message)));
            }

            let chunk = try_ready!(self.inner.poll_data().map_err(|err| {
                let err = err.into();
                debug!("decoder inner stream error: {:?}", err);
                Status::from_error(&*err)
            }));

            match chunk {
                Some(data) => self.bufs.bufs.push_back(data.into_buf()),
                None => {
                    // A header that was already consumed (state `ReadBody`)
                    // means a frame was cut short, even when no payload bytes
                    // are buffered.
                    let mid_frame = match self.state {
                        State::ReadBody { .. } => true,
                        State::ReadHeader | State::Done => false,
                    };
                    if mid_frame || self.bufs.has_remaining() {
                        trace!("unexpected EOF decoding stream");
                        return Err(crate::Status::new(
                            crate::Code::Internal,
                            "Unexpected EOF decoding stream.".to_string(),
                        ));
                    }
                    self.state = State::Done;
                    break;
                }
            }
        }

        match self.direction {
            Direction::Response(status_code) => {
                let trailers = try_ready!(self.inner.poll_trailers().map_err(|err| {
                    let err = err.into();
                    debug!("decoder inner trailers error: {:?}", err);
                    Status::from_error(&*err)
                }));
                infer_grpc_status(trailers, status_code).map(|_| Async::Ready(None))
            }
            Direction::Request | Direction::EmptyResponse => Ok(Async::Ready(None)),
        }
    }
}
impl<T, B> fmt::Debug for Streaming<T, B>
where
    T: fmt::Debug,
    B: Body + fmt::Debug,
    B::Data: fmt::Debug,
{
    /// Prints only the type name; field contents are deliberately omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Equivalent to a field-less `debug_struct("Streaming").finish()`.
        f.write_str("Streaming")
    }
}
impl<'a> EncodeBuf<'a> {
#[inline]
pub fn reserve(&mut self, capacity: usize) {
self.bytes.reserve(capacity);
}
}
impl<'a> BufMut for EncodeBuf<'a> {
#[inline]
fn remaining_mut(&self) -> usize {
self.bytes.remaining_mut()
}
#[inline]
unsafe fn advance_mut(&mut self, cnt: usize) {
self.bytes.advance_mut(cnt)
}
#[inline]
unsafe fn bytes_mut(&mut self) -> &mut [u8] {
self.bytes.bytes_mut()
}
}
impl<'a> Buf for DecodeBuf<'a> {
    /// Bytes left in the current message (not in the whole buffered input).
    #[inline]
    fn remaining(&self) -> usize {
        self.len
    }

    /// Front chunk of the underlying buffer, clipped at the message boundary.
    #[inline]
    fn bytes(&self) -> &[u8] {
        let chunk = self.bufs.bytes();
        let end = chunk.len().min(self.len);
        &chunk[..end]
    }

    /// Consumes `cnt` bytes of the current message.
    ///
    /// # Panics
    ///
    /// Panics if `cnt` exceeds the bytes left in the message.
    #[inline]
    fn advance(&mut self, cnt: usize) {
        assert!(cnt <= self.len);
        let inner = &mut self.bufs;
        inner.advance(cnt);
        self.len -= cnt;
    }
}
impl<'a> fmt::Debug for DecodeBuf<'a> {
    /// Prints only the type name; the borrowed buffer is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Equivalent to a field-less `debug_struct("DecodeBuf").finish()`.
        f.write_str("DecodeBuf")
    }
}
impl<'a> Drop for DecodeBuf<'a> {
    /// Skips any bytes of the message the decoder did not consume, so the
    /// next frame header is read from the right position.
    fn drop(&mut self) {
        if self.len == 0 {
            return;
        }
        warn!("DecodeBuf was not advanced to end");
        let unread = self.len;
        self.bufs.advance(unread);
    }
}
impl<T: Buf> Buf for BufList<T> {
    /// Total unread bytes across every queued chunk.
    #[inline]
    fn remaining(&self) -> usize {
        self.bufs.iter().map(|buf| buf.remaining()).sum()
    }

    /// Unread bytes of the first chunk that still has data.
    ///
    /// Empty chunks (e.g. from zero-length data frames) are skipped. The
    /// default `Buf` helpers assume a non-empty slice whenever
    /// `remaining() > 0`: `get_u8` indexes `bytes()[0]`, and `copy_to_slice`
    /// (used by `get_u32_be` across chunk boundaries) would spin on
    /// `advance(0)` if it saw an empty front chunk.
    #[inline]
    fn bytes(&self) -> &[u8] {
        self.bufs
            .iter()
            .map(|buf| buf.bytes())
            .find(|chunk| !chunk.is_empty())
            .unwrap_or(&[])
    }

    /// Consumes `cnt` bytes, popping chunks as they are exhausted (including
    /// any empty chunks encountered on the way).
    ///
    /// # Panics
    ///
    /// Panics if `cnt` exceeds `remaining()`.
    #[inline]
    fn advance(&mut self, mut cnt: usize) {
        while cnt > 0 {
            let front = self
                .bufs
                .front_mut()
                .expect("BufList::advance called past the end of buffered data");
            let available = front.remaining();
            if available > cnt {
                front.advance(cnt);
                return;
            }
            cnt -= available;
            self.bufs.pop_front();
        }
    }
}