1use std::collections::VecDeque;
4use std::ffi::CStr;
5use std::os::unix::io::{BorrowedFd, OwnedFd, RawFd};
6
7use crate::protocol::{Argument, ArgumentType, Message};
8
9use smallvec::SmallVec;
10
/// Reasons why [`write_to_buffers`] can fail to serialize a message.
#[derive(Debug)]
pub enum MessageWriteError {
    /// The output buffer cannot hold the encoded message.
    BufferTooSmall,
    /// A file-descriptor argument could not be duplicated; carries the OS error.
    DupFdFailed(::std::io::Error),
}
19
20impl std::error::Error for MessageWriteError {}
21
22impl std::fmt::Display for MessageWriteError {
23 #[cfg_attr(coverage, coverage(off))]
24 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
25 match self {
26 Self::BufferTooSmall => {
27 f.write_str("The provided buffer is too small to hold message content.")
28 }
29 Self::DupFdFailed(e) => {
30 write!(
31 f,
32 "The message contains a file descriptor that could not be dup()-ed ({}).",
33 e
34 )
35 }
36 }
37 }
38}
39
/// Reasons why [`parse_message`] can fail to deserialize a message.
///
/// The enum carries no data, so it also derives `PartialEq`/`Eq`. This lets callers
/// compare against a specific variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageParseError {
    /// The message needs a file descriptor but the FD queue does not hold enough.
    MissingFD,
    /// More bytes are needed to decode the message.
    MissingData,
    /// The message content is invalid and cannot be decoded.
    Malformed,
}
50
51impl std::error::Error for MessageParseError {}
52
53impl std::fmt::Display for MessageParseError {
54 #[cfg_attr(coverage, coverage(off))]
55 fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> {
56 match *self {
57 Self::MissingFD => {
58 f.write_str("The message references a FD but the buffer FD is empty.")
59 }
60 Self::MissingData => f.write_str("More data is needed to deserialize the message"),
61 Self::Malformed => f.write_str("The message is malformed and cannot be parsed"),
62 }
63 }
64}
65
66pub fn write_to_buffers(
72 msg: &Message<u32, RawFd>,
73 payload: &mut [u8],
74 fds: &mut Vec<OwnedFd>,
75) -> Result<usize, MessageWriteError> {
76 let orig_payload_len = payload.len();
77 fn write_buf(u: u32, payload: &mut [u8]) -> Result<&mut [u8], MessageWriteError> {
79 if payload.len() >= 4 {
80 let (head, tail) = payload.split_at_mut(4);
81 head.copy_from_slice(&u.to_ne_bytes());
82 Ok(tail)
83 } else {
84 Err(MessageWriteError::BufferTooSmall)
85 }
86 }
87
88 fn write_array_to_payload<'a>(
90 array: &[u8],
91 payload: &'a mut [u8],
92 ) -> Result<&'a mut [u8], MessageWriteError> {
93 let payload = write_buf(array.len() as u32, payload)?;
95
96 let len = next_multiple_of(array.len(), 4);
98
99 if payload.len() < len {
100 return Err(MessageWriteError::BufferTooSmall);
101 }
102
103 let (buffer_slice, rest) = payload.split_at_mut(len);
104 buffer_slice[..array.len()].copy_from_slice(array);
105 Ok(rest)
106 }
107
108 let free_size = payload.len();
109 if free_size < 2 * 4 {
110 return Err(MessageWriteError::BufferTooSmall);
111 }
112
113 let (header, mut payload) = payload.split_at_mut(2 * 4);
114
115 for arg in &msg.args {
117 payload = match *arg {
118 Argument::Int(i) => write_buf(i as u32, payload)?,
119 Argument::Uint(u) => write_buf(u, payload)?,
120 Argument::Fixed(f) => write_buf(f as u32, payload)?,
121 Argument::Str(Some(ref s)) => write_array_to_payload(s.as_bytes_with_nul(), payload)?,
122 Argument::Str(None) => write_array_to_payload(&[], payload)?,
123 Argument::Object(o) => write_buf(o, payload)?,
124 Argument::NewId(n) => write_buf(n, payload)?,
125 Argument::Array(ref a) => write_array_to_payload(a, payload)?,
126 Argument::Fd(fd) => {
127 let dup_fd = unsafe { BorrowedFd::borrow_raw(fd) }
128 .try_clone_to_owned()
129 .map_err(MessageWriteError::DupFdFailed)?;
130 fds.push(dup_fd);
131 payload
132 }
133 };
134 }
135
136 let wrote_size = free_size - payload.len();
137 header[..4].copy_from_slice(&msg.sender_id.to_ne_bytes());
138 header[4..]
139 .copy_from_slice(&(((wrote_size as u32) << 16) | u32::from(msg.opcode)).to_ne_bytes());
140 Ok(orig_payload_len - payload.len())
141}
142
143#[allow(clippy::type_complexity)]
151pub fn parse_message<'a>(
152 raw: &'a [u8],
153 signature: &[ArgumentType],
154 fds: &mut VecDeque<OwnedFd>,
155) -> Result<(Message<u32, OwnedFd>, &'a [u8]), MessageParseError> {
156 fn read_array_from_payload(
158 array_len: usize,
159 payload: &[u8],
160 ) -> Result<(&[u8], &[u8]), MessageParseError> {
161 let len = next_multiple_of(array_len, 4);
162 if len > payload.len() {
163 return Err(MessageParseError::MissingData);
164 }
165 Ok((&payload[..array_len], &payload[len..]))
166 }
167
168 if raw.len() < 2 * 4 {
169 return Err(MessageParseError::MissingData);
170 }
171
172 let sender_id = u32::from_ne_bytes([raw[0], raw[1], raw[2], raw[3]]);
173 let word_2 = u32::from_ne_bytes([raw[4], raw[5], raw[6], raw[7]]);
174 let opcode = (word_2 & 0x0000_FFFF) as u16;
175 let len = (word_2 >> 16) as usize;
176
177 if len < 2 * 4 {
178 return Err(MessageParseError::Malformed);
179 } else if len > raw.len() {
180 return Err(MessageParseError::MissingData);
181 }
182
183 let fd_len = signature.iter().filter(|x| matches!(x, ArgumentType::Fd)).count();
184 if fd_len > fds.len() {
185 return Err(MessageParseError::MissingFD);
186 }
187
188 let (mut payload, rest) = raw.split_at(len);
189 payload = &payload[2 * 4..];
190
191 let arguments = signature
192 .iter()
193 .map(|argtype| {
194 if let ArgumentType::Fd = *argtype {
195 if let Some(front) = fds.pop_front() {
197 Ok(Argument::Fd(front))
198 } else {
199 Err(MessageParseError::MissingFD)
200 }
201 } else if payload.len() >= 4 {
202 let (front, mut tail) = payload.split_at(4);
203 let front = u32::from_ne_bytes(front.try_into().unwrap());
204 let arg = match *argtype {
205 ArgumentType::Int => Ok(Argument::Int(front as i32)),
206 ArgumentType::Uint => Ok(Argument::Uint(front)),
207 ArgumentType::Fixed => Ok(Argument::Fixed(front as i32)),
208 ArgumentType::Str(_) => {
209 read_array_from_payload(front as usize, tail).and_then(|(v, rest)| {
210 tail = rest;
211 if !v.is_empty() {
212 match CStr::from_bytes_with_nul(v) {
213 Ok(s) => Ok(Argument::Str(Some(Box::new(s.into())))),
214 Err(_) => Err(MessageParseError::Malformed),
215 }
216 } else {
217 Ok(Argument::Str(None))
218 }
219 })
220 }
221 ArgumentType::Object(_) => Ok(Argument::Object(front)),
222 ArgumentType::NewId => Ok(Argument::NewId(front)),
223 ArgumentType::Array => {
224 read_array_from_payload(front as usize, tail).map(|(v, rest)| {
225 tail = rest;
226 Argument::Array(Box::new(v.into()))
227 })
228 }
229 ArgumentType::Fd => unreachable!(),
230 };
231 payload = tail;
232 arg
233 } else {
234 Err(MessageParseError::MissingData)
235 }
236 })
237 .collect::<Result<SmallVec<_>, MessageParseError>>()?;
238
239 let msg = Message { sender_id, opcode, args: arguments };
240 Ok((msg, rest))
241}
242
/// Round `lhs` up to the nearest multiple of `rhs` (`lhs` itself if already a multiple).
///
/// # Panics
///
/// Panics if `rhs` is zero. In debug builds, it also panics if the rounded value
/// overflows `usize`.
fn next_multiple_of(lhs: usize, rhs: usize) -> usize {
    let remainder = lhs % rhs;
    if remainder == 0 {
        lhs
    } else {
        lhs + (rhs - remainder)
    }
}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::AllowNull;
    use smallvec::smallvec;
    use std::{ffi::CString, os::unix::io::IntoRawFd};

    /// A message serialized by `write_to_buffers` must parse back to an equal message.
    #[test]
    fn into_from_raw_cycle() {
        let original = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };
        let signature = [
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        let mut wire = vec![0; 1024];
        let mut outgoing_fds = Vec::new();
        write_to_buffers(&original, &mut wire[..], &mut outgoing_fds).unwrap();

        let mut incoming_fds: VecDeque<_> = outgoing_fds.into_iter().collect();
        let (decoded, _remaining) =
            parse_message(&wire[..], &signature, &mut incoming_fds).unwrap();

        assert_eq!(decoded.map_fd(IntoRawFd::into_raw_fd), original);
    }
}