diff --git a/src/comm/comm.go b/src/comm/comm.go index 47267b2..7ad3203 100644 --- a/src/comm/comm.go +++ b/src/comm/comm.go @@ -2,10 +2,9 @@ package comm import ( "bytes" + "encoding/binary" "fmt" "net" - "strconv" - "strings" "time" "github.com/pkg/errors" @@ -16,6 +15,19 @@ type Comm struct { connection net.Conn } +// NewConnection gets a new comm to a tcp address +func NewConnection(address string) (c Comm, err error) { + connection, err := net.DialTimeout("tcp", address, 3*time.Hour) + if err != nil { + return + } + connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) + connection.SetDeadline(time.Now().Add(3 * time.Hour)) + connection.SetWriteDeadline(time.Now().Add(3 * time.Hour)) + c = New(connection) + return +} + // New returns a new comm func New(c net.Conn) Comm { c.SetReadDeadline(time.Now().Add(3 * time.Hour)) @@ -35,10 +47,12 @@ func (c Comm) Close() { } func (c Comm) Write(b []byte) (int, error) { - tmpCopy := make([]byte, len(b)+5) - // Copy the buffer so it doesn't get changed while read by the recipient. - copy(tmpCopy[:5], []byte(fmt.Sprintf("%0.5d", len(b)))) - copy(tmpCopy[5:], b) + header := new(bytes.Buffer) + err := binary.Write(header, binary.LittleEndian, uint32(len(b))) + if err != nil { + fmt.Println("binary.Write failed:", err) + } + tmpCopy := append(header.Bytes(), b...) n, err := c.connection.Write(tmpCopy) if n != len(tmpCopy) { if err != nil { @@ -53,68 +67,48 @@ func (c Comm) Write(b []byte) (int, error) { func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { // read until we get 5 bytes - tmp := make([]byte, 5) - n, err := c.connection.Read(tmp) + header := make([]byte, 4) + n, err := c.connection.Read(header) if err != nil { return } - tmpCopy := make([]byte, n) - // Copy the buffer so it doesn't get changed while read by the recipient. - copy(tmpCopy, tmp[:n]) - bs = tmpCopy - - tmp = make([]byte, 1) - for { - // see if we have enough bytes - bs = bytes.Trim(bs, "\x00") - if len(bs) == 5 { - break - } - n, err := c.connection.Read(tmp) - if err != nil { - return nil, 0, nil, err - } - tmpCopy = make([]byte, n) - // Copy the buffer so it doesn't get changed while read by the recipient. - copy(tmpCopy, tmp[:n]) - bs = append(bs, tmpCopy...) + if n < 4 { + err = fmt.Errorf("not enough bytes: %d", n) + return } + // make it so it won't change + header = append([]byte(nil), header...) - numBytes, err = strconv.Atoi(strings.TrimLeft(string(bs), "0")) + var numBytesUint32 uint32 + rbuf := bytes.NewReader(header) + err = binary.Read(rbuf, binary.LittleEndian, &numBytesUint32) if err != nil { - return nil, 0, nil, err + fmt.Println("binary.Read failed:", err) } - buf = []byte{} - tmp = make([]byte, numBytes) + numBytes = int(numBytesUint32) for { - n, err := c.connection.Read(tmp) - if err != nil { - return nil, 0, nil, err + tmp := make([]byte, numBytes) + n, errRead := c.connection.Read(tmp) + if errRead != nil { + err = errRead + return } - tmpCopy = make([]byte, n) - // Copy the buffer so it doesn't get changed while read by the recipient. - copy(tmpCopy, tmp[:n]) - buf = append(buf, bytes.TrimRight(tmpCopy, "\x00")...) - if len(buf) < numBytes { - // shrink the amount we need to read - tmp = tmp[:numBytes-len(buf)] - } else { + buf = append(buf, tmp[:n]...) + if numBytes == len(buf) { break } } - // log.Printf("wanted %d and got %d", numBytes, len(buf)) return } // Send a message -func (c Comm) Send(message string) (err error) { - _, err = c.Write([]byte(message)) +func (c Comm) Send(message []byte) (err error) { + _, err = c.Write(message) return } // Receive a message -func (c Comm) Receive() (s string, err error) { - b, _, _, err := c.Read() - s = string(b) +func (c Comm) Receive() (b []byte, err error) { + b, _, _, err = c.Read() return } diff --git a/src/comm/comm_test.go b/src/comm/comm_test.go new file mode 100644 index 0000000..41430bf --- /dev/null +++ b/src/comm/comm_test.go @@ -0,0 +1,52 @@ +package comm + +import ( + "net" + "testing" + "time" + + log "github.com/cihub/seelog" + "github.com/stretchr/testify/assert" +) + +func TestComm(t *testing.T) { + defer log.Flush() + + port := "8001" + go func() { + log.Debugf("starting TCP server on " + port) + server, err := net.Listen("tcp", "0.0.0.0:"+port) + if err != nil { + log.Error(err) + } + defer server.Close() + // spawn a new goroutine whenever a client connects + for { + connection, err := server.Accept() + if err != nil { + log.Error(err) + } + log.Debugf("client %s connected", connection.RemoteAddr().String()) + go func(port string, connection net.Conn) { + c := New(connection) + err = c.Send([]byte("hello, world")) + assert.Nil(t, err) + data, err := c.Receive() + assert.Nil(t, err) + assert.Equal(t, []byte("hello, computer"), data) + data, err = c.Receive() + assert.Nil(t, err) + assert.Equal(t, []byte{'\x00'}, data) + }(port, connection) + } + }() + + time.Sleep(100 * time.Millisecond) + a, err := NewConnection("localhost:" + port) + assert.Nil(t, err) + data, err := a.Receive() + assert.Equal(t, []byte("hello, world"), data) + assert.Nil(t, err) + assert.Nil(t, a.Send([]byte("hello, computer"))) + assert.Nil(t, a.Send([]byte{'\x00'})) +} diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index 3cece44..599a58b 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -13,8 +13,10 @@ import ( ) type roomInfo struct { - receiver comm.Comm - opened time.Time + first comm.Comm + second comm.Comm + opened time.Time + full bool } type roomMap struct { @@ -77,37 +79,53 @@ func run(port string) (err error) { func clientCommuncation(port string, c comm.Comm) (err error) { // send ok to tell client they are connected log.Debug("sending ok") - err = c.Send("ok") + err = c.Send([]byte("ok")) if err != nil { return } // wait for client to tell me which room they want log.Debug("waiting for answer") - room, err := c.Receive() + roomBytes, err := c.Receive() if err != nil { return } + room := string(roomBytes) rooms.Lock() - // first connection is always the receiver + // create the room if it is new if _, ok := rooms.rooms[room]; !ok { rooms.rooms[room] = roomInfo{ - receiver: c, - opened: time.Now(), + first: c, + opened: time.Now(), } rooms.Unlock() // tell the client that they got the room - err = c.Send("recipient") + err = c.Send([]byte("ok")) if err != nil { log.Error(err) return } - log.Debug("recipient connected") + log.Debugf("room %s has 1", room) return nil } - log.Debug("sender connected") - receiver := rooms.rooms[room].receiver + if rooms.rooms[room].full { + rooms.Unlock() + err = c.Send([]byte("room full")) + if err != nil { + log.Error(err) + return + } + return nil + } + log.Debugf("room %s has 2", room) + rooms.rooms[room] = roomInfo{ + first: rooms.rooms[room].first, + second: c, + opened: rooms.rooms[room].opened, + full: true, + } + otherConnection := rooms.rooms[room].first rooms.Unlock() // second connection is the sender, time to staple connections @@ -120,10 +138,10 @@ func clientCommuncation(port string, c comm.Comm) (err error) { pipe(com1.Connection(), com2.Connection()) wg.Done() log.Debug("done piping") - }(c, receiver, &wg) + }(otherConnection, c, &wg) // tell the sender everything is ready - err = c.Send("sender") + err = c.Send([]byte("ok")) if err != nil { return } @@ -132,6 +150,8 @@ func clientCommuncation(port string, c comm.Comm) (err error) { // delete room rooms.Lock() log.Debugf("deleting room: %s", room) + rooms.rooms[room].first.Close() + rooms.rooms[room].second.Close() delete(rooms.rooms, room) rooms.Unlock() return nil diff --git a/src/tcp/tcp_test.go b/src/tcp/tcp_test.go new file mode 100644 index 0000000..d336b3d --- /dev/null +++ b/src/tcp/tcp_test.go @@ -0,0 +1,68 @@ +package tcp + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/schollz/croc/src/comm" + "github.com/stretchr/testify/assert" +) + +func TestTCP(t *testing.T) { + go Run("debug", "8081") + time.Sleep(100 * time.Millisecond) + c1, err := ConnectToTCPServer("localhost:8081", "testRoom") + assert.Nil(t, err) + c2, err := ConnectToTCPServer("localhost:8081", "testRoom") + assert.Nil(t, err) + _, err = ConnectToTCPServer("localhost:8081", "testRoom") + assert.NotNil(t, err) + + // try sending data + assert.Nil(t, c1.Send([]byte("hello, c2"))) + data, err := c2.Receive() + assert.Nil(t, err) + assert.Equal(t, []byte("hello, c2"), data) + + assert.Nil(t, c2.Send([]byte("hello, c1"))) + data, err = c1.Receive() + assert.Nil(t, err) + assert.Equal(t, []byte("hello, c1"), data) + + c1.Close() + time.Sleep(200 * time.Millisecond) + err = c2.Send([]byte("test")) + assert.Nil(t, err) + _, err = c2.Receive() + assert.NotNil(t, err) +} + +func ConnectToTCPServer(address, room string) (c comm.Comm, err error) { + c, err = comm.NewConnection("localhost:8081") + if err != nil { + return + } + data, err := c.Receive() + if err != nil { + return + } + if !bytes.Equal(data, []byte("ok")) { + err = fmt.Errorf("got bad response: %s", data) + return + } + err = c.Send([]byte(room)) + if err != nil { + return + } + data, err = c.Receive() + if err != nil { + return + } + if !bytes.Equal(data, []byte("ok")) { + err = fmt.Errorf("got bad response: %s", data) + return + } + return +}