diff --git a/rodbus/src/client/task.rs b/rodbus/src/client/task.rs
index f9c5ee00..b3a00447 100644
--- a/rodbus/src/client/task.rs
+++ b/rodbus/src/client/task.rs
@@ -1,3 +1,4 @@
+use std::num::NonZeroUsize;
 use std::time::Duration;
 
 use tracing::Instrument;
@@ -21,6 +22,8 @@ pub(crate) enum SessionError {
     BadFrame,
     /// channel was disabled
     Disabled,
+    /// maximum number of consecutive response timeouts reached
+    MaxTimeouts(usize),
     /// the mpsc is closed (dropped) on the sender side
     Shutdown,
 }
@@ -58,6 +61,9 @@ impl std::fmt::Display for SessionError {
             SessionError::Shutdown => {
                 write!(f, "Shutdown was requested")
             }
+            SessionError::MaxTimeouts(max) => {
+                write!(f, "Maximum number ({max}) of consecutive timeouts reached")
+            }
         }
     }
 }
@@ -73,11 +79,58 @@ impl SessionError {
     }
 }
 
+enum TimeoutCounterState {
+    Disabled,
+    Enabled { current: usize, max: usize },
+}
+
+struct TimeoutCounter {
+    state: TimeoutCounterState,
+}
+
+impl TimeoutCounter {
+    fn new(max_timeouts: Option<NonZeroUsize>) -> Self {
+        Self {
+            state: match max_timeouts {
+                None => TimeoutCounterState::Disabled,
+                Some(max) => TimeoutCounterState::Enabled {
+                    current: 0,
+                    max: max.get(),
+                },
+            },
+        }
+    }
+
+    fn reset(&mut self) {
+        match &mut self.state {
+            TimeoutCounterState::Disabled => {}
+            TimeoutCounterState::Enabled { current, .. } => {
+                *current = 0;
+            }
+        }
+    }
+
+    fn increment(&mut self) -> Result<(), SessionError> {
+        match &mut self.state {
+            TimeoutCounterState::Disabled => Ok(()),
+            TimeoutCounterState::Enabled { current, max } => {
+                *current = current.saturating_add(1);
+                if current >= max {
+                    Err(SessionError::MaxTimeouts(*max))
+                } else {
+                    Ok(())
+                }
+            }
+        }
+    }
+}
+
 pub(crate) struct ClientLoop {
     rx: crate::channel::Receiver,
     writer: FrameWriter,
     reader: FramedReader,
     tx_id: TxId,
+    timeout_counter: TimeoutCounter,
     decode: DecodeLevel,
     enabled: bool,
 }
@@ -88,12 +141,14 @@ impl ClientLoop {
         writer: FrameWriter,
         reader: FramedReader,
         decode: DecodeLevel,
+        max_timeouts: Option<NonZeroUsize>,
     ) -> Self {
         Self {
             rx,
             writer,
             reader,
             tx_id: TxId::default(),
+            timeout_counter: TimeoutCounter::new(max_timeouts),
             decode,
             enabled: false,
         }
@@ -129,6 +184,7 @@ impl ClientLoop {
     }
 
     pub(crate) async fn run(&mut self, io: &mut PhysLayer) -> SessionError {
+        self.timeout_counter.reset();
         loop {
             if let Err(err) = self.poll(io).await {
                 tracing::warn!("ending session: {err}");
@@ -169,16 +225,29 @@ impl ClientLoop {
             .instrument(tracing::info_span!("Transaction", tx_id = %tx_id))
             .await;
 
-        if let Err(err) = result {
-            // Fail the request in ONE place. If the whole future
-            // gets dropped, then the request gets failed with Shutdown
-            tracing::warn!("request error: {}", err);
-            request.details.fail(err);
+        match result {
+            Ok(()) => self.timeout_counter.reset(),
+            Err(err) => {
+                // Fail the request in ONE place. If the whole future
+                // gets dropped, then the request gets failed with Shutdown
+                tracing::warn!("request error: {}", err);
+                request.details.fail(err);
+
+                // some request errors are a session error that will
+                // bubble up and close the session
+                if let Some(err) = SessionError::from_request_err(err) {
+                    return Err(err);
+                }
 
-            // some request errors are a session error that will
-            // bubble up and close the session
-            if let Some(err) = SessionError::from_request_err(err) {
-                return Err(err);
+                if err == RequestError::ResponseTimeout {
+                    // if we reach the maximum number of consecutive timeouts,
+                    // this can also terminate the session
+                    self.timeout_counter.increment()?;
+                } else {
+                    // all other errors reset the response timeout counter,
+                    // e.g. a Modbus exception
+                    self.timeout_counter.reset();
+                }
             }
         }
 
@@ -307,7 +376,9 @@ mod tests {
 
     use sfio_tokio_mock_io::Event;
 
-    fn spawn_client_loop() -> (
+    fn spawn_client_loop_with_max_timeouts(
+        max_timeouts: Option<NonZeroUsize>,
+    ) -> (
         Channel,
         tokio::task::JoinHandle<SessionError>,
         sfio_tokio_mock_io::Handle,
@@ -319,6 +390,7 @@ mod tests {
             FrameWriter::tcp(),
             FramedReader::tcp(),
             DecodeLevel::default().application(AppDecodeLevel::DataValues),
+            max_timeouts,
         );
         let join_handle = tokio::spawn(async move {
             let mut phys = PhysLayer::new_mock(mock);
@@ -328,6 +400,14 @@ mod tests {
         (channel, join_handle, io_handle)
     }
 
+    fn spawn_client_loop() -> (
+        Channel,
+        tokio::task::JoinHandle<SessionError>,
+        sfio_tokio_mock_io::Handle,
+    ) {
+        spawn_client_loop_with_max_timeouts(None)
+    }
+
     fn get_framed_adu<T>(function: FunctionCode, payload: &T) -> Vec<u8>
     where
         T: Serialize + Loggable + Sized,
@@ -464,4 +544,149 @@ mod tests {
             vec![Indexed::new(7, true), Indexed::new(8, false)]
         );
     }
+
+    #[tokio::test]
+    async fn terminates_after_max_consecutive_timeouts() {
+        let (channel, task, mut io) = spawn_client_loop_with_max_timeouts(NonZeroUsize::new(3));
+
+        tokio::time::pause();
+
+        let range = AddressRange::try_from(7, 2).unwrap();
+
+        // spawn 3 requests that will all timeout
+        for _ in 0..3 {
+            let mut ch = channel.clone();
+            tokio::spawn(async move {
+                ch.read_coils(
+                    RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
+                    range,
+                )
+                .await
+            });
+
+            // wait for write, don't care about exact tx_id
+            match io.next_event().await {
+                Event::Write(_) => {}
+                other => panic!("Expected Write, got {:?}", other),
+            }
+        }
+
+        // session should terminate with MaxTimeouts(3)
+        assert_eq!(task.await.unwrap(), SessionError::MaxTimeouts(3));
+    }
+
+    #[tokio::test]
+    async fn disabled_when_none_allows_unlimited_timeouts() {
+        let (channel, task, mut io) = spawn_client_loop_with_max_timeouts(None);
+
+        tokio::time::pause();
+
+        let range = AddressRange::try_from(7, 2).unwrap();
+
+        // send 10 requests that all timeout
+        for _ in 0..10 {
+            let mut ch = channel.clone();
+            tokio::spawn(async move {
+                ch.read_coils(
+                    RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
+                    range,
+                )
+                .await
+            });
+
+            match io.next_event().await {
+                Event::Write(_) => {}
+                other => panic!("Expected Write, got {:?}", other),
+            }
+        }
+
+        // task should still be running
+        assert!(!task.is_finished());
+    }
+
+    #[tokio::test]
+    async fn counter_resets_on_successful_request() {
+        let (channel, task, mut io) = spawn_client_loop_with_max_timeouts(NonZeroUsize::new(3));
+
+        tokio::time::pause();
+
+        let range = AddressRange::try_from(7, 2).unwrap();
+
+        // Pattern: timeout -> timeout -> success -> timeout -> timeout
+        // With max=3, this should NOT terminate because the success resets the counter
+
+        // First two timeouts
+        for _ in 0..2 {
+            let mut ch = channel.clone();
+            tokio::spawn(async move {
+                ch.read_coils(
+                    RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
+                    range,
+                )
+                .await
+            });
+            match io.next_event().await {
+                Event::Write(_) => {}
+                other => panic!("Expected Write, got {:?}", other),
+            }
+        }
+
+        // Successful request
+        let success_task = tokio::spawn({
+            let mut ch = channel.clone();
+            async move {
+                ch.read_coils(
+                    RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
+                    range,
+                )
+                .await
+            }
+        });
+
+        // Get the request and respond with matching tx_id
+        let request_bytes = match io.next_event().await {
+            Event::Write(bytes) => bytes,
+            other => panic!("Expected Write, got {:?}", other),
+        };
+
+        let mut response = get_framed_adu(
+            FunctionCode::ReadCoils,
+            &BitWriter::new(ReadBitsRange { inner: range }, |idx| match idx {
+                7 => Ok(true),
+                8 => Ok(false),
+                _ => Err(ExceptionCode::IllegalDataAddress),
+            }),
+        );
+        response[0] = request_bytes[0];
+        response[1] = request_bytes[1];
+
+        io.read(&response);
+
+        // The response will be read by the client loop
+        match io.next_event().await {
+            Event::Read => {} // Expected - client loop reads our response
+            other => panic!("Expected Read after providing response, got {:?}", other),
+        }
+
+        assert!(success_task.await.unwrap().is_ok());
+
+        // Two more timeouts - should NOT terminate since counter was reset
+        for _ in 0..2 {
+            let mut ch = channel.clone();
+            tokio::spawn(async move {
+                ch.read_coils(
+                    RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
+                    range,
+                )
+                .await
+            });
+            match io.next_event().await {
+                Event::Write(_) => {}
+                other => panic!("Expected Write, got {:?}", other),
+            }
+        }
+
+        // Task should still be running (only 2 consecutive timeouts, not 3)
+        assert!(!task.is_finished());
+    }
 }
diff --git a/rodbus/src/serial/client.rs b/rodbus/src/serial/client.rs
index c9226099..b8453e8e 100644
--- a/rodbus/src/serial/client.rs
+++ b/rodbus/src/serial/client.rs
@@ -34,6 +34,7 @@ impl SerialChannelTask {
                 FrameWriter::rtu(),
                 FramedReader::rtu_response(),
                 decode,
+                None,
             ),
             listener,
         }
@@ -83,7 +84,10 @@ impl SerialChannelTask {
             // don't wait, we're disabled
             SessionError::Disabled => Ok(()),
             // wait before retrying
-            SessionError::IoError(_) | SessionError::BadFrame => {
+            SessionError::IoError(_)
+            | SessionError::BadFrame
+            | SessionError::MaxTimeouts(_) => {
+                drop(phys);
                 let delay = self.retry.after_disconnect();
                 self.listener.update(PortState::Wait(delay)).get().await;
                 tracing::warn!("waiting {} ms to re-open port", delay.as_millis());
diff --git a/rodbus/src/tcp/client.rs b/rodbus/src/tcp/client.rs
index d7fd8073..7729d951 100644
--- a/rodbus/src/tcp/client.rs
+++ b/rodbus/src/tcp/client.rs
@@ -106,6 +106,7 @@ impl TcpChannelTask {
                 FrameWriter::tcp(),
                 FramedReader::tcp(),
                 options.decode_level,
+                options.max_timeouts,
             ),
             listener,
             channel_logging: options.channel_logging,
@@ -175,8 +176,11 @@ impl TcpChannelTask {
         match self.client_loop.run(&mut phys).await {
             // the mpsc was closed, end the task
             SessionError::Shutdown => Err(StateChange::Shutdown),
+            // don't wait, we're disabled
+            SessionError::Disabled => Ok(()),
             // re-establish the connection
-            SessionError::Disabled | SessionError::IoError(_) | SessionError::BadFrame => {
+            SessionError::IoError(_) | SessionError::BadFrame | SessionError::MaxTimeouts(_) => {
+                drop(phys);
                 let delay = self.connect_retry.after_disconnect();
                 log_channel_event!(self.channel_logging, "waiting {:?} to reconnect", delay);
                 self.listener
diff --git a/rodbus/src/types.rs b/rodbus/src/types.rs
index 5277a713..5961712d 100644
--- a/rodbus/src/types.rs
+++ b/rodbus/src/types.rs
@@ -1,6 +1,7 @@
 use crate::decode::AppDecodeLevel;
 use crate::error::{AduParseError, InvalidRange};
 use crate::DecodeLevel;
+use std::num::NonZeroUsize;
 
 use scursor::ReadCursor;
 
@@ -401,10 +402,13 @@ pub struct ClientOptions {
     pub(crate) channel_logging: ChannelLoggingMode,
     pub(crate) max_queued_requests: usize,
     pub(crate) decode_level: DecodeLevel,
+    pub(crate) max_timeouts: Option<NonZeroUsize>,
 }
 
 impl ClientOptions {
     /// Set the channel logging type
+    ///
+    /// Note: defaults to [`ChannelLoggingMode::Verbose`]
     pub fn channel_logging(self, channel_logging: ChannelLoggingMode) -> Self {
         Self {
             channel_logging,
@@ -413,6 +417,8 @@ impl ClientOptions {
     }
 
     /// Set the maximum number of queued requests
+    ///
+    /// Note: defaults to 16
    pub fn max_queued_requests(self, max_queued_requests: usize) -> Self {
         Self {
             max_queued_requests,
@@ -421,12 +427,28 @@ impl ClientOptions {
     }
 
     /// Set the decode level
+    ///
+    /// Note: defaults to [`DecodeLevel::default()`]
     pub fn decode_level(self, decode_level: DecodeLevel) -> Self {
         Self {
             decode_level,
             ..self
         }
     }
+
+    /// Set the maximum number of consecutive response timeouts before forcing a reconnect
+    ///
+    /// Useful for detecting dead TCP connections where the remote device stops responding
+    /// but doesn't send a proper FIN/RST (e.g., due to network issues or third-party interference).
+    /// The counter resets on any successful request.
+    ///
+    /// Defaults to `None` (no limit)
+    pub fn max_response_timeouts(self, max_timeouts: Option<NonZeroUsize>) -> Self {
+        Self {
+            max_timeouts,
+            ..self
+        }
+    }
 }
 
 impl Default for ClientOptions {
@@ -435,6 +457,7 @@ impl Default for ClientOptions {
             channel_logging: ChannelLoggingMode::default(),
             max_queued_requests: 16,
             decode_level: DecodeLevel::default(),
+            max_timeouts: None,
         }
     }
 }
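
Usage sketch (not part of the patch): a minimal example of enabling the new limit
through the builder-style `ClientOptions` API added in rodbus/src/types.rs above.
The import path for `ClientOptions` and the channel constructor that eventually
consumes the options are assumed and omitted here.

    use std::num::NonZeroUsize;

    fn build_options() -> ClientOptions {
        // force a reconnect after 3 consecutive response timeouts;
        // the default (`None`) keeps the previous behavior of never
        // terminating the session because of response timeouts alone
        ClientOptions::default().max_response_timeouts(NonZeroUsize::new(3))
    }

As the tcp/client.rs hunk shows, the TCP channel forwards `options.max_timeouts`
into `ClientLoop::new`, while the serial client currently hard-codes `None`.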