1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
use crate::rt::{object, Access, Location, Synchronize, VersionVec};
use std::collections::VecDeque;
use std::sync::atomic::Ordering::{Acquire, Release};

/// Handle to the execution-tracked state of a single `mpsc` channel.
///
/// The channel's bookkeeping (`State`) is stored in the execution's object
/// store; this handle only keeps a reference to that entry.
#[derive(Debug)]
pub(crate) struct Channel {
    /// Reference to this channel's `State` inside the execution object store.
    state: object::Ref<State>,
}

/// Per-channel bookkeeping kept in the execution object store: the number of
/// queued messages, the last send/receive accesses, and the synchronization
/// points that give messages their happens-before edges.
#[derive(Debug)]
pub(super) struct State {
    /// Count of messages currently in the channel (sent but not yet received).
    msg_cnt: usize,

    /// Last access that was a send operation.
    last_send_access: Option<Access>,
    /// Last access that was a receive operation.
    last_recv_access: Option<Access>,

    /// A synchronization point for synchronizing the sending threads and the
    /// channel.
    ///
    /// The `mpsc` channels guarantee that messages will be received in the
    /// same order in which they were sent. Therefore, if thread `t1` managed
    /// to send `m1` before `t2` sent `m2`, the thread that received `m2` can
    /// be sure that `m1` was already sent and received. In other words, it is
    /// sound for the receiver of `m2` to know that `m1` happened before `m2`.
    /// That is why we have a single `sender_synchronize` for all senders,
    /// which we use to "timestamp" each message put in the channel.
    /// However, in our example, the receiver of `m1` does not know whether `m2`
    /// was already sent or not and, therefore, by reading from the channel it
    /// should not learn any facts about `happens_before(send(m2), recv(m1))`.
    /// That is why we cannot use a single `Synchronize` for the entire channel,
    /// and on the receiver side we need to use one `Synchronize` per message.
    sender_synchronize: Synchronize,
    /// A synchronization point per message synchronizing the receiving thread
    /// with the channel state at the point when the received message was sent.
    ///
    /// Each entry is a snapshot of `sender_synchronize` taken when the
    /// corresponding message was sent. Sends push to the back and receives pop
    /// from the front, so the front entry always belongs to the oldest unread
    /// message.
    receiver_synchronize: VecDeque<Synchronize>,

    /// Location at which the channel was created; included in the panic
    /// message produced by `check_for_leaks` when it was captured.
    created: Location,
}

/// Actions performed on the Channel.
///
/// Each kind of action has its own "last access" slot in `State`
/// (`last_send_access` / `last_recv_access`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(super) enum Action {
    /// Send a message
    MsgSend,
    /// Receive a message
    MsgRecv,
}

impl Channel {
    /// Creates a new, empty channel and registers its state in the current
    /// execution's object store.
    ///
    /// `location` is recorded as the channel's creation site; it is reported
    /// by `State::check_for_leaks` if messages are left in the channel.
    pub(crate) fn new(location: Location) -> Self {
        super::execution(|execution| {
            let state = execution.objects.insert(State {
                msg_cnt: 0,
                last_send_access: None,
                last_recv_access: None,
                sender_synchronize: Synchronize::new(),
                receiver_synchronize: VecDeque::new(),
                created: location,
            });

            tracing::trace!(?state, %location, "mpsc::channel");
            Self { state }
        })
    }

    /// Sends one message through the channel.
    ///
    /// The send is recorded as a branch point. Then the message count is
    /// incremented, and the sending thread's clock is published into the
    /// shared sender `Synchronize` with `Release` ordering. A snapshot of that
    /// `Synchronize` is queued for whichever thread receives this message.
    ///
    /// If the channel was empty before this send, every other thread whose
    /// pending operation targets this channel is made runnable again.
    ///
    /// # Panics
    ///
    /// Panics if the message count overflows `usize`.
    pub(crate) fn send(&self, location: Location) {
        self.state.branch_action(Action::MsgSend, location);
        super::execution(|execution| {
            let state = self.state.get_mut(&mut execution.objects);
            state.msg_cnt = state.msg_cnt.checked_add(1).expect("overflow");

            state
                .sender_synchronize
                .sync_store(&mut execution.threads, Release);
            // Queue a copy of the senders' combined clock as of this message,
            // so its receiver does not learn about sends that happen later.
            state
                .receiver_synchronize
                .push_back(state.sender_synchronize);

            // A count of 1 means the channel just went from empty to non-empty.
            if state.msg_cnt == 1 {
                // Unblock all threads that are blocked waiting on this channel
                let thread_id = execution.threads.active_id();
                for (id, thread) in execution.threads.iter_mut() {
                    if id == thread_id {
                        continue;
                    }

                    let obj = thread
                        .operation
                        .as_ref()
                        .map(|operation| operation.object());

                    if obj == Some(self.state.erase()) {
                        thread.set_runnable();
                    }
                }
            }
        })
    }

    /// Receives one message from the channel.
    ///
    /// The receive is registered as a branch point, with `self.is_empty()`
    /// passed as the disable flag. Presumably this keeps the receive from
    /// being scheduled while the channel is empty; see `branch_disable`.
    ///
    /// The message count is then decremented, and the receiving thread
    /// acquires (`Acquire`) the sender clock captured when this particular
    /// message was sent. If the channel becomes empty, every other thread with
    /// a pending receive on this channel is marked blocked.
    ///
    /// # Panics
    ///
    /// Panics if the channel holds no messages.
    pub(crate) fn recv(&self, location: Location) {
        self.state
            .branch_disable(Action::MsgRecv, self.is_empty(), location);
        super::execution(|execution| {
            let state = self.state.get_mut(&mut execution.objects);
            let thread_id = execution.threads.active_id();
            state.msg_cnt = state
                .msg_cnt
                .checked_sub(1)
                .expect("expected to be able to read the message");
            // `send` pushes exactly one entry per message, so the queue cannot
            // be empty once the count check above has succeeded.
            let mut synchronize = state
                .receiver_synchronize
                .pop_front()
                .expect("receiver_synchronize must hold one entry per unread message");
            // Synchronize only with the senders' state as of this message's send.
            synchronize.sync_load(&mut execution.threads, Acquire);
            if state.msg_cnt == 0 {
                // Block all **other** threads attempting to read from the channel
                for (id, thread) in execution.threads.iter_mut() {
                    if id == thread_id {
                        continue;
                    }

                    if let Some(operation) = thread.operation.as_ref() {
                        if operation.object() == self.state.erase()
                            && operation.action() == object::Action::Channel(Action::MsgRecv)
                        {
                            let location = operation.location();
                            thread.set_blocked(location);
                        }
                    }
                }
            }
        })
    }

    /// Returns `true` if the channel is currently empty
    pub(crate) fn is_empty(&self) -> bool {
        super::execution(|execution| self.get_state(&mut execution.objects).msg_cnt == 0)
    }

    /// Looks up this channel's `State` in the given object store.
    fn get_state<'a>(&self, objects: &'a mut object::Store) -> &'a mut State {
        self.state.get_mut(objects)
    }
}

impl State {
    /// Panics when the channel still holds unreceived messages.
    ///
    /// The panic message contains `index` and the number of leftover
    /// messages. It also contains the channel's creation location whenever
    /// that location was captured.
    pub(super) fn check_for_leaks(&self, index: usize) {
        if self.msg_cnt == 0 {
            return;
        }

        if !self.created.is_captured() {
            panic!(
                "Messages leaked.\n     Index: {}\n  Messages: {}",
                index, self.msg_cnt
            );
        }

        panic!(
            "Messages leaked.\n  \
            Channel created: {}\n            \
            Index: {}\n        \
            Messages: {}",
            self.created, index, self.msg_cnt
        );
    }

    /// Returns the most recently recorded access of the same kind as `action`
    /// (the last send for `MsgSend`, the last receive for `MsgRecv`), if any.
    pub(super) fn last_dependent_access(&self, action: Action) -> Option<&Access> {
        let slot = match action {
            Action::MsgSend => &self.last_send_access,
            Action::MsgRecv => &self.last_recv_access,
        };
        slot.as_ref()
    }

    /// Records (`path_id`, `version`) as the latest access for the slot that
    /// corresponds to `action`, creating the entry if it does not exist yet.
    pub(super) fn set_last_access(&mut self, action: Action, path_id: usize, version: &VersionVec) {
        let slot = match action {
            Action::MsgSend => &mut self.last_send_access,
            Action::MsgRecv => &mut self.last_recv_access,
        };
        Access::set_or_create(slot, path_id, version);
    }
}