Skip to content
245 changes: 235 additions & 10 deletions rodbus/src/client/task.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::num::NonZeroUsize;
use std::time::Duration;

use tracing::Instrument;
Expand All @@ -21,6 +22,8 @@ pub(crate) enum SessionError {
BadFrame,
/// channel was disabled
Disabled,
/// maximum number of consecutive response timeouts reached
MaxTimeouts(usize),
/// the mpsc is closed (dropped) on the sender side
Shutdown,
}
Expand Down Expand Up @@ -58,6 +61,9 @@ impl std::fmt::Display for SessionError {
SessionError::Shutdown => {
write!(f, "Shutdown was requested")
}
SessionError::MaxTimeouts(max) => {
write!(f, "Maximum number ({max}) of consecutive timeouts reached")
}
}
}
}
Expand All @@ -73,11 +79,58 @@ impl SessionError {
}
}

enum TimeoutCounterState {
Disabled,
Enabled { current: usize, max: usize },
}

struct TimeoutCounter {
state: TimeoutCounterState,
}

impl TimeoutCounter {
fn new(max_timeouts: Option<NonZeroUsize>) -> Self {
Self {
state: match max_timeouts {
None => TimeoutCounterState::Disabled,
Some(max) => TimeoutCounterState::Enabled {
current: 0,
max: max.get(),
},
},
}
}

fn reset(&mut self) {
match &mut self.state {
TimeoutCounterState::Disabled => {}
TimeoutCounterState::Enabled { current, .. } => {
*current = 0;
}
}
}

fn increment(&mut self) -> Result<(), SessionError> {
match &mut self.state {
TimeoutCounterState::Disabled => Ok(()),
TimeoutCounterState::Enabled { current, max } => {
*current = current.saturating_add(1);
if current >= max {
Err(SessionError::MaxTimeouts(*max))
} else {
Ok(())
}
}
}
}
}

pub(crate) struct ClientLoop {
rx: crate::channel::Receiver<Command>,
writer: FrameWriter,
reader: FramedReader,
tx_id: TxId,
timeout_counter: TimeoutCounter,
decode: DecodeLevel,
enabled: bool,
}
Expand All @@ -88,12 +141,14 @@ impl ClientLoop {
writer: FrameWriter,
reader: FramedReader,
decode: DecodeLevel,
max_timeouts: Option<NonZeroUsize>,
) -> Self {
Self {
rx,
writer,
reader,
tx_id: TxId::default(),
timeout_counter: TimeoutCounter::new(max_timeouts),
decode,
enabled: false,
}
Expand Down Expand Up @@ -129,6 +184,7 @@ impl ClientLoop {
}

pub(crate) async fn run(&mut self, io: &mut PhysLayer) -> SessionError {
self.timeout_counter.reset();
loop {
if let Err(err) = self.poll(io).await {
tracing::warn!("ending session: {err}");
Expand Down Expand Up @@ -169,16 +225,29 @@ impl ClientLoop {
.instrument(tracing::info_span!("Transaction", tx_id = %tx_id))
.await;

if let Err(err) = result {
// Fail the request in ONE place. If the whole future
// gets dropped, then the request gets failed with Shutdown
tracing::warn!("request error: {}", err);
request.details.fail(err);
match result {
Ok(()) => self.timeout_counter.reset(),
Err(err) => {
// Fail the request in ONE place. If the whole future
// gets dropped, then the request gets failed with Shutdown
tracing::warn!("request error: {}", err);
request.details.fail(err);

// some request errors are a session error that will
// bubble up and close the session
if let Some(err) = SessionError::from_request_err(err) {
return Err(err);
}

// some request errors are a session error that will
// bubble up and close the session
if let Some(err) = SessionError::from_request_err(err) {
return Err(err);
if err == RequestError::ResponseTimeout {
// if we reach the maximum number of consecutive timeouts,
// this can also terminate the session
self.timeout_counter.increment()?;
} else {
// all other errors reset the response timeout counter,
// e.g. a Modbus exception
self.timeout_counter.reset();
}
}
}

Expand Down Expand Up @@ -307,7 +376,9 @@ mod tests {

use sfio_tokio_mock_io::Event;

fn spawn_client_loop() -> (
fn spawn_client_loop_with_max_timeouts(
max_timeouts: Option<NonZeroUsize>,
) -> (
Channel,
tokio::task::JoinHandle<SessionError>,
sfio_tokio_mock_io::Handle,
Expand All @@ -319,6 +390,7 @@ mod tests {
FrameWriter::tcp(),
FramedReader::tcp(),
DecodeLevel::default().application(AppDecodeLevel::DataValues),
max_timeouts,
);
let join_handle = tokio::spawn(async move {
let mut phys = PhysLayer::new_mock(mock);
Expand All @@ -328,6 +400,14 @@ mod tests {
(channel, join_handle, io_handle)
}

fn spawn_client_loop() -> (
Channel,
tokio::task::JoinHandle<SessionError>,
sfio_tokio_mock_io::Handle,
) {
spawn_client_loop_with_max_timeouts(None)
}

fn get_framed_adu<T>(function: FunctionCode, payload: &T) -> Vec<u8>
where
T: Serialize + Loggable + Sized,
Expand Down Expand Up @@ -464,4 +544,149 @@ mod tests {
vec![Indexed::new(7, true), Indexed::new(8, false)]
);
}

#[tokio::test]
async fn terminates_after_max_consecutive_timeouts() {
let (channel, task, mut io) = spawn_client_loop_with_max_timeouts(NonZeroUsize::new(3));

tokio::time::pause();

let range = AddressRange::try_from(7, 2).unwrap();

// spawn 3 requests that will all timeout
for _ in 0..3 {
let mut ch = channel.clone();
tokio::spawn(async move {
ch.read_coils(
RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
range,
)
.await
});

// wait for write, don't care about exact tx_id
match io.next_event().await {
Event::Write(_) => {}
other => panic!("Expected Write, got {:?}", other),
}
}

// session should terminate with MaxTimeouts(3)
assert_eq!(task.await.unwrap(), SessionError::MaxTimeouts(3));
}

#[tokio::test]
async fn disabled_when_none_allows_unlimited_timeouts() {
let (channel, task, mut io) = spawn_client_loop_with_max_timeouts(None);

tokio::time::pause();

let range = AddressRange::try_from(7, 2).unwrap();

// send 10 requests that all timeout
for _ in 0..10 {
let mut ch = channel.clone();
tokio::spawn(async move {
ch.read_coils(
RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
range,
)
.await
});

match io.next_event().await {
Event::Write(_) => {}
other => panic!("Expected Write, got {:?}", other),
}
}

// task should still be running
assert!(!task.is_finished());
}

#[tokio::test]
async fn counter_resets_on_successful_request() {
let (channel, task, mut io) = spawn_client_loop_with_max_timeouts(NonZeroUsize::new(3));

tokio::time::pause();

let range = AddressRange::try_from(7, 2).unwrap();

// Pattern: timeout -> timeout -> success -> timeout -> timeout
// With max=3, this should NOT terminate because the success resets the counter

// First two timeouts
for _ in 0..2 {
let mut ch = channel.clone();
tokio::spawn(async move {
ch.read_coils(
RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
range,
)
.await
});
match io.next_event().await {
Event::Write(_) => {}
other => panic!("Expected Write, got {:?}", other),
}
}

// Successful request
let success_task = tokio::spawn({
let mut ch = channel.clone();
async move {
ch.read_coils(
RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
range,
)
.await
}
});

// Get the request and respond with matching tx_id
let request_bytes = match io.next_event().await {
Event::Write(bytes) => bytes,
other => panic!("Expected Write, got {:?}", other),
};

let mut response = get_framed_adu(
FunctionCode::ReadCoils,
&BitWriter::new(ReadBitsRange { inner: range }, |idx| match idx {
7 => Ok(true),
8 => Ok(false),
_ => Err(ExceptionCode::IllegalDataAddress),
}),
);
response[0] = request_bytes[0];
response[1] = request_bytes[1];

io.read(&response);

// The response will be read by the client loop
match io.next_event().await {
Event::Read => {} // Expected - client loop reads our response
other => panic!("Expected Read after providing response, got {:?}", other),
}

assert!(success_task.await.unwrap().is_ok());

// Two more timeouts - should NOT terminate since counter was reset
for _ in 0..2 {
let mut ch = channel.clone();
tokio::spawn(async move {
ch.read_coils(
RequestParam::new(UnitId::new(1), Duration::from_secs(1)),
range,
)
.await
});
match io.next_event().await {
Event::Write(_) => {}
other => panic!("Expected Write, got {:?}", other),
}
}

// Task should still be running (only 2 consecutive timeouts, not 3)
assert!(!task.is_finished());
}
}
6 changes: 5 additions & 1 deletion rodbus/src/serial/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ impl SerialChannelTask {
FrameWriter::rtu(),
FramedReader::rtu_response(),
decode,
None,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The max_timeouts option is hardcoded to None for the serial client, which disables the new consecutive timeout limit feature. The PR description and the addition of SessionError::MaxTimeouts handling for serial connections suggest this feature should be available for serial clients as well. To fix this, the max_timeouts value should be configurable, likely by passing ClientOptions to SerialChannelTask::new similarly to how it's done for the TCP client.

),
listener,
}
Expand Down Expand Up @@ -83,7 +84,10 @@ impl SerialChannelTask {
// don't wait, we're disabled
SessionError::Disabled => Ok(()),
// wait before retrying
SessionError::IoError(_) | SessionError::BadFrame => {
SessionError::IoError(_)
| SessionError::BadFrame
| SessionError::MaxTimeouts(_) => {
drop(phys);
let delay = self.retry.after_disconnect();
self.listener.update(PortState::Wait(delay)).get().await;
tracing::warn!("waiting {} ms to re-open port", delay.as_millis());
Expand Down
6 changes: 5 additions & 1 deletion rodbus/src/tcp/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ impl TcpChannelTask {
FrameWriter::tcp(),
FramedReader::tcp(),
options.decode_level,
options.max_timeouts,
),
listener,
channel_logging: options.channel_logging,
Expand Down Expand Up @@ -175,8 +176,11 @@ impl TcpChannelTask {
match self.client_loop.run(&mut phys).await {
// the mpsc was closed, end the task
SessionError::Shutdown => Err(StateChange::Shutdown),
// don't wait, we're disabled
SessionError::Disabled => Ok(()),
// re-establish the connection
SessionError::Disabled | SessionError::IoError(_) | SessionError::BadFrame => {
SessionError::IoError(_) | SessionError::BadFrame | SessionError::MaxTimeouts(_) => {
drop(phys);
let delay = self.connect_retry.after_disconnect();
log_channel_event!(self.channel_logging, "waiting {:?} to reconnect", delay);
self.listener
Expand Down
Loading