From 02e80217356d8af83076a33510a13bd353a33d4f Mon Sep 17 00:00:00 2001 From: Zack Scholl Date: Tue, 25 Sep 2018 16:14:41 -0700 Subject: [PATCH] fix recipient --- src/croc/sending.go | 6 +-- src/models/constants.go | 5 ++ src/recipient/recipient.go | 107 +++++++++++++++++++++++++------------ 3 files changed, 82 insertions(+), 36 deletions(-) diff --git a/src/croc/sending.go b/src/croc/sending.go index 1edf3b9..2869644 100644 --- a/src/croc/sending.go +++ b/src/croc/sending.go @@ -30,7 +30,7 @@ func (c *Croc) Send(fname, codephrase string) (err error) { if !c.LocalOnly { go func() { // atttempt to connect to public relay - errChan <- c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPort, fname, codephrase, true, false) + errChan <- c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPorts, fname, codephrase, true, false) }() } else { waitingFor = 1 @@ -106,7 +106,7 @@ func (c *Croc) Receive(codephrase string) (err error) { if err == nil { if resp.StatusCode == http.StatusOK { // we connected, so use this - return c.sendReceive(discovered[0].Address, strings.TrimSpace(ports[0]), strings.TrimSpace(strings.Split(ports[1], ",")), "", codephrase, false, true) + return c.sendReceive(discovered[0].Address, strings.TrimSpace(ports[0]), strings.Split(strings.TrimSpace(ports[1]), ","), "", codephrase, false, true) } } else { log.Debugf("could not connect: %s", err.Error()) @@ -119,7 +119,7 @@ func (c *Croc) Receive(codephrase string) (err error) { // use public relay if !c.LocalOnly { log.Debug("using public relay") - return c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPort, "", codephrase, false, false) + return c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPorts, "", codephrase, false, false) } return errors.New("must use local or public relay") diff --git a/src/models/constants.go b/src/models/constants.go index 0c23ef9..6ec61d7 100644 --- a/src/models/constants.go +++ b/src/models/constants.go @@ -2,3 +2,8 @@ package models const WEBSOCKET_BUFFER_SIZE = 1024 * 1024 * 32 const TCP_BUFFER_SIZE = 1024 * 64 + +type BytesAndLocation struct { + Bytes []byte `json:"b"` + Location int64 `json:"l"` +} diff --git a/src/recipient/recipient.go b/src/recipient/recipient.go index 1df0a93..2154ed0 100644 --- a/src/recipient/recipient.go +++ b/src/recipient/recipient.go @@ -31,9 +31,9 @@ import ( var DebugLevel string // Receive is the async operation to receive a file -func Receive(forceSend int, serverAddress, serverTCP string, isLocal bool, done chan struct{}, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) { +func Receive(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, done chan struct{}, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) { logger.SetLogLevel(DebugLevel) - err := receive(forceSend, serverAddress, serverTCP, isLocal, c, codephrase, noPrompt, useStdout) + err := receive(forceSend, serverAddress, tcpPorts, isLocal, c, codephrase, noPrompt, useStdout) if err != nil { if !strings.HasPrefix(err.Error(), "websocket: close 100") { fmt.Fprintf(os.Stderr, "\n"+err.Error()) @@ -42,13 +42,13 @@ func Receive(forceSend int, serverAddress, serverTCP string, isLocal bool, done done <- struct{}{} } -func receive(forceSend int, serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) (err error) { +func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) (err error) { var fstats models.FileStats var sessionKey []byte var transferTime time.Duration var hash256 []byte var otherIP string - var tcpConnection comm.Comm + var tcpConnections []comm.Comm dataChan := make(chan []byte, 1024*1024) useWebsockets := true @@ -174,12 +174,15 @@ func receive(forceSend int, serverAddress, serverTCP string, isLocal bool, c *we // connect to TCP to receive file if !useWebsockets { log.Debugf("connecting to server") - tcpConnection, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%x", sessionKey)), serverAddress+":"+serverTCP) - if err != nil { - log.Error(err) - return err + tcpConnections := make([]comm.Comm, len(tcpPorts)) + for i, tcpPort := range tcpPorts { + tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%x", sessionKey)), serverAddress+":"+tcpPort) + if err != nil { + log.Error(err) + return err + } + defer tcpConnections[i].Close() } - defer tcpConnection.Close() } // await file @@ -219,8 +222,20 @@ func receive(forceSend int, serverAddress, serverTCP string, isLocal bool, c *we decrypted = compress.Decompress(decrypted) } - // write to file - n, err := f.Write(decrypted) + var n int + if !useWebsockets { + var bl models.BytesAndLocation + err = json.Unmarshal(decrypted, &bl) + if err != nil { + log.Error(err) + return err + } + n, err = f.WriteAt(bl.Bytes, bl.Location) + } else { + // write to file + n, err = f.Write(decrypted) + } + if err != nil { return err } @@ -238,35 +253,61 @@ func receive(forceSend int, serverAddress, serverTCP string, isLocal bool, c *we }(finished, dataChan) c.WriteMessage(websocket.BinaryMessage, []byte("ready")) startTime := time.Now() - for { - if useWebsockets { + if useWebsockets { + for { var messageType int // read from websockets messageType, message, err = c.ReadMessage() if messageType != websocket.BinaryMessage { continue } - } else { - // read from TCP connection - message, _, _, err = tcpConnection.Read() - // log.Debugf("message: %s", message) + if err != nil { + log.Error(err) + return err + } + if bytes.Equal(message, []byte("magic")) { + log.Debug("got magic") + break + } + select { + case dataChan <- message: + continue + default: + log.Debug("blocked") + // no message sent + // block + dataChan <- message + } } - if err != nil { - log.Error(err) - return err - } - if bytes.Equal(message, []byte("magic")) { - log.Debug("got magic") - break - } - select { - case dataChan <- message: - continue - default: - log.Debug("blocked") - // no message sent - // block - dataChan <- message + _ = <-finished + + } else { + // using TCP + for i := range tcpConnections { + go func(tcpConnection comm.Comm) { + for { + // read from TCP connection + message, _, _, err = tcpConnection.Read() + // log.Debugf("message: %s", message) + if err != nil { + log.Error(err) + return + } + if bytes.Equal(message, []byte("magic")) { + log.Debug("got magic") + break + } + select { + case dataChan <- message: + continue + default: + log.Debug("blocked") + // no message sent + // block + dataChan <- message + } + } + }(tcpConnections[i]) } }