From e10789b558f080602dfa1a2e4fd875f216466523 Mon Sep 17 00:00:00 2001 From: Zack Scholl Date: Mon, 24 Sep 2018 07:17:35 -0700 Subject: [PATCH] allow forcing using websockets --- src/croc/croc.go | 1 + src/croc/croc_test.go | 5 ++++- src/croc/sending.go | 8 ++++++-- src/recipient/recipient.go | 22 +++++++++++++++++----- src/sender/sender.go | 30 +++++++++++++++++++++--------- 5 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/croc/croc.go b/src/croc/croc.go index 464566a..d0f040e 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -36,6 +36,7 @@ type Croc struct { AllowLocalDiscovery bool NoRecipientPrompt bool Stdout bool + ForceSend int // 0: ignore, 1: websockets, 2: TCP // Parameters for file transfer Filename string diff --git a/src/croc/croc_test.go b/src/croc/croc_test.go index d4dd686..9185bde 100644 --- a/src/croc/croc_test.go +++ b/src/croc/croc_test.go @@ -14,6 +14,7 @@ import ( ) func TestSendReceive(t *testing.T) { + forceSend := 0 var startTime time.Time var durationPerMegabyte float64 generateRandomFile(100) @@ -22,14 +23,16 @@ func TestSendReceive(t *testing.T) { go func() { defer wg.Done() c := Init(true) + c.ForceSend = forceSend assert.Nil(t, c.Send("100mb.file", "test")) }() go func() { defer wg.Done() - time.Sleep(2 * time.Second) + time.Sleep(3 * time.Second) os.MkdirAll("test", 0755) os.Chdir("test") c := Init(true) + c.ForceSend = forceSend startTime = time.Now() assert.Nil(t, c.Receive("test")) durationPerMegabyte = 100.0 / time.Since(startTime).Seconds() diff --git a/src/croc/sending.go b/src/croc/sending.go index 2bcb956..739247b 100644 --- a/src/croc/sending.go +++ b/src/croc/sending.go @@ -15,6 +15,7 @@ import ( "github.com/schollz/croc/src/relay" "github.com/schollz/croc/src/sender" "github.com/schollz/peerdiscovery" + "github.com/schollz/utils" ) // Send the file @@ -87,6 +88,9 @@ func (c *Croc) Receive(codephrase string) (err error) { log.Debug(errDiscover) } if len(discovered) > 0 { + if discovered[0].Address == utils.GetLocalIP() { + discovered[0].Address = "localhost" + } log.Debugf("discovered %s:%s", discovered[0].Address, discovered[0].Payload) // see if we can actually connect to it timeout := time.Duration(200 * time.Millisecond) @@ -152,9 +156,9 @@ func (c *Croc) sendReceive(address, websocketPort, tcpPort, fname, codephrase st } if isSender { - go sender.Send(address, tcpPort, isLocal, done, sock, fname, codephrase, c.UseCompression, c.UseEncryption) + go sender.Send(c.ForceSend, address, tcpPort, isLocal, done, sock, fname, codephrase, c.UseCompression, c.UseEncryption) } else { - go recipient.Receive(address, tcpPort, isLocal, done, sock, codephrase, c.NoRecipientPrompt, c.Stdout) + go recipient.Receive(c.ForceSend, address, tcpPort, isLocal, done, sock, codephrase, c.NoRecipientPrompt, c.Stdout) } for { diff --git a/src/recipient/recipient.go b/src/recipient/recipient.go index 927a8ce..aacfcdb 100644 --- a/src/recipient/recipient.go +++ b/src/recipient/recipient.go @@ -31,9 +31,9 @@ import ( var DebugLevel string // Receive is the async operation to receive a file -func Receive(serverAddress, serverTCP string, isLocal bool, done chan struct{}, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) { +func Receive(forceSend int, serverAddress, serverTCP string, isLocal bool, done chan struct{}, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) { logger.SetLogLevel(DebugLevel) - err := receive(serverAddress, serverTCP, isLocal, c, codephrase, noPrompt, useStdout) + err := receive(forceSend, serverAddress, serverTCP, isLocal, c, codephrase, noPrompt, useStdout) if err != nil { if !strings.HasPrefix(err.Error(), "websocket: close 100") { fmt.Fprintf(os.Stderr, "\n"+err.Error()) @@ -42,7 +42,7 @@ func Receive(serverAddress, serverTCP string, isLocal bool, done chan struct{}, done <- struct{}{} } -func receive(serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) (err error) { +func receive(forceSend int, serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, useStdout bool) (err error) { var fstats models.FileStats var sessionKey []byte var transferTime time.Duration @@ -50,6 +50,18 @@ func receive(serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, c var otherIP string var tcpConnection comm.Comm + useWebsockets := true + switch forceSend { + case 0: + if !isLocal { + useWebsockets = false + } + case 1: + useWebsockets = true + case 2: + useWebsockets = false + } + // start a spinner spin := spinner.New(spinner.CharSets[9], 100*time.Millisecond) spin.Writer = os.Stderr @@ -160,7 +172,7 @@ func receive(serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, c } // connect to TCP to receive file - if !isLocal && serverTCP != "" { + if !useWebsockets { log.Debugf("connecting to server") tcpConnection, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%x", sessionKey)), serverAddress+":"+serverTCP) if err != nil { @@ -188,7 +200,7 @@ func receive(serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, c var numBytes int var bs []byte for { - if isLocal || serverTCP == "" { + if useWebsockets { var messageType int // read from websockets messageType, message, err = c.ReadMessage() diff --git a/src/sender/sender.go b/src/sender/sender.go index ef80228..a7a01b1 100644 --- a/src/sender/sender.go +++ b/src/sender/sender.go @@ -30,10 +30,10 @@ import ( var DebugLevel string // Send is the async call to send data -func Send(serverAddress, serverTCP string, isLocal bool, done chan struct{}, c *websocket.Conn, fname string, codephrase string, useCompression bool, useEncryption bool) { +func Send(forceSend int, serverAddress, serverTCP string, isLocal bool, done chan struct{}, c *websocket.Conn, fname string, codephrase string, useCompression bool, useEncryption bool) { logger.SetLogLevel(DebugLevel) log.Debugf("sending %s", fname) - err := send(serverAddress, serverTCP, isLocal, c, fname, codephrase, useCompression, useEncryption) + err := send(forceSend, serverAddress, serverTCP, isLocal, c, fname, codephrase, useCompression, useEncryption) if err != nil { if !strings.HasPrefix(err.Error(), "websocket: close 100") { fmt.Fprintf(os.Stderr, "\n"+err.Error()) @@ -43,7 +43,7 @@ func Send(serverAddress, serverTCP string, isLocal bool, done chan struct{}, c * done <- struct{}{} } -func send(serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, fname string, codephrase string, useCompression bool, useEncryption bool) (err error) { +func send(forceSend int, serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, fname string, codephrase string, useCompression bool, useEncryption bool) (err error) { var f *os.File defer f.Close() // ignore the error if it wasn't opened :( var fstats models.FileStats @@ -52,6 +52,18 @@ func send(serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, fnam var startTransfer time.Time var tcpConnection comm.Comm + useWebsockets := true + switch forceSend { + case 0: + if !isLocal { + useWebsockets = false + } + case 1: + useWebsockets = true + case 2: + useWebsockets = false + } + fileReady := make(chan error) // normalize the file name @@ -195,7 +207,7 @@ func send(serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, fnam return errors.New("recipient refused file") } - if !isLocal && serverTCP != "" { + if !useWebsockets { // connection to TCP tcpConnection, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%x", sessionKey)), serverAddress+":"+serverTCP) if err != nil { @@ -208,7 +220,7 @@ func send(serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, fnam // send file, compure hash simultaneously startTransfer = time.Now() buffer := make([]byte, models.WEBSOCKET_BUFFER_SIZE/8) - if !isLocal && serverTCP != "" { + if !useWebsockets { buffer = make([]byte, models.TCP_BUFFER_SIZE/2) } bar := progressbar.NewOptions( @@ -236,12 +248,12 @@ func send(serverAddress, serverTCP string, isLocal bool, c *websocket.Conn, fnam return err } - if isLocal || serverTCP == "" { - // write data to websockets - err = c.WriteMessage(websocket.BinaryMessage, encBytes) - } else { + if !useWebsockets { // write data to tcp connection _, err = tcpConnection.Write(encBytes) + } else { + // write data to websockets + err = c.WriteMessage(websocket.BinaryMessage, encBytes) } if err != nil { err = errors.Wrap(err, "problem writing message")