diff --git a/src/comm/comm.go b/src/comm/comm.go index c52f925..7e11989 100644 --- a/src/comm/comm.go +++ b/src/comm/comm.go @@ -15,8 +15,23 @@ type Comm struct { connection net.Conn } +func (c *Comm) IsClosed() bool { + one := []byte{} + c.connection.SetReadDeadline(time.Now()) + _, err := c.connection.Read(one) + if err != nil { + fmt.Println(err) + c.connection.Close() + c.connection = nil + return true + } else { + c.connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) + } + return false +} + // NewConnection gets a new comm to a tcp address -func NewConnection(address string) (c Comm, err error) { +func NewConnection(address string) (c *Comm, err error) { connection, err := net.DialTimeout("tcp", address, 3*time.Second) if err != nil { return @@ -26,24 +41,26 @@ func NewConnection(address string) (c Comm, err error) { } // New returns a new comm -func New(c net.Conn) Comm { +func New(c net.Conn) *Comm { c.SetReadDeadline(time.Now().Add(3 * time.Hour)) c.SetDeadline(time.Now().Add(3 * time.Hour)) c.SetWriteDeadline(time.Now().Add(3 * time.Hour)) - return Comm{c} + comm := new(Comm) + comm.connection = c + return comm } // Connection returns the net.Conn connection -func (c Comm) Connection() net.Conn { +func (c *Comm) Connection() net.Conn { return c.connection } // Close closes the connection -func (c Comm) Close() { +func (c *Comm) Close() { c.connection.Close() } -func (c Comm) Write(b []byte) (int, error) { +func (c *Comm) Write(b []byte) (int, error) { header := new(bytes.Buffer) err := binary.Write(header, binary.LittleEndian, uint32(len(b))) if err != nil { @@ -62,7 +79,7 @@ func (c Comm) Write(b []byte) (int, error) { return n, err } -func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { +func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { // read until we get 5 bytes header := make([]byte, 4) n, err := c.connection.Read(header) @@ -99,13 +116,13 @@ func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { } // Send a message -func (c Comm) Send(message []byte) (err error) { +func (c *Comm) Send(message []byte) (err error) { _, err = c.Write(message) return } // Receive a message -func (c Comm) Receive() (b []byte, err error) { +func (c *Comm) Receive() (b []byte, err error) { b, _, _, err = c.Read() return } diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index 41021ea..cb3624a 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -14,8 +14,8 @@ import ( const TCP_BUFFER_SIZE = 1024 * 64 type roomInfo struct { - first comm.Comm - second comm.Comm + first *comm.Comm + second *comm.Comm opened time.Time full bool } @@ -77,7 +77,7 @@ func run(port string) (err error) { } } -func clientCommuncation(port string, c comm.Comm) (err error) { +func clientCommuncation(port string, c *comm.Comm) (err error) { // send ok to tell client they are connected log.Debug("sending ok") err = c.Send([]byte("ok")) @@ -134,7 +134,7 @@ func clientCommuncation(port string, c comm.Comm) (err error) { wg.Add(1) // start piping - go func(com1, com2 comm.Comm, wg *sync.WaitGroup) { + go func(com1, com2 *comm.Comm, wg *sync.WaitGroup) { log.Debug("starting pipes") pipe(com1.Connection(), com2.Connection()) wg.Done() @@ -153,6 +153,7 @@ func clientCommuncation(port string, c comm.Comm) (err error) { log.Debugf("deleting room: %s", room) rooms.rooms[room].first.Close() rooms.rooms[room].second.Close() + rooms.rooms[room] = roomInfo{first: nil, second: nil} delete(rooms.rooms, room) rooms.Unlock() return nil diff --git a/src/tcp/tcp_test.go b/src/tcp/tcp_test.go index d336b3d..5f88b65 100644 --- a/src/tcp/tcp_test.go +++ b/src/tcp/tcp_test.go @@ -20,6 +20,7 @@ func TestTCP(t *testing.T) { _, err = ConnectToTCPServer("localhost:8081", "testRoom") assert.NotNil(t, err) + assert.False(t, c1.IsClosed()) // try sending data assert.Nil(t, c1.Send([]byte("hello, c2"))) data, err := c2.Receive() @@ -32,14 +33,13 @@ func TestTCP(t *testing.T) { assert.Equal(t, []byte("hello, c1"), data) c1.Close() + assert.True(t, c1.IsClosed()) + time.Sleep(200 * time.Millisecond) - err = c2.Send([]byte("test")) - assert.Nil(t, err) - _, err = c2.Receive() - assert.NotNil(t, err) + assert.True(t, c2.IsClosed()) } -func ConnectToTCPServer(address, room string) (c comm.Comm, err error) { +func ConnectToTCPServer(address, room string) (c *comm.Comm, err error) { c, err = comm.NewConnection("localhost:8081") if err != nil { return