diff --git a/src/api.go b/src/api.go index f0bb473..35893ae 100644 --- a/src/api.go +++ b/src/api.go @@ -10,6 +10,10 @@ type Croc struct { UseCompression bool CurveType string AllowLocalDiscovery bool + + // private variables + // rs relay state is only for the relay + rs relayState } // Init will initialize the croc relay @@ -27,11 +31,15 @@ func Init() (c *Croc) { // Relay initiates a relay func (c *Croc) Relay() error { + c.rs.Lock() + c.rs.channel = make(map[string]*channelData) + c.rs.Unlock() + // start relay - go startRelay(c.TcpPorts) + go c.startRelay(c.TcpPorts) // start server - return startServer(c.TcpPorts, c.ServerPort) + return c.startServer(c.TcpPorts, c.ServerPort) } // Send will take an existing file or folder and send it through the croc relay diff --git a/src/models.go b/src/models.go index 0a14c22..0609a25 100644 --- a/src/models.go +++ b/src/models.go @@ -3,6 +3,7 @@ package croc import ( "crypto/elliptic" "net" + "sync" "time" ) @@ -16,6 +17,11 @@ var ( availableStates = []string{"curve", "h_k", "hh_k", "x", "y"} ) +type relayState struct { + channel map[string]*channelData + sync.RWMutex +} + type channelData struct { // Public // Name is the name of the channel diff --git a/src/relay.go b/src/relay.go index 7d4fe6e..526eec1 100644 --- a/src/relay.go +++ b/src/relay.go @@ -10,14 +10,14 @@ import ( "github.com/pkg/errors" ) -func startRelay(ports []string) { +func (c *Croc) startRelay(ports []string) { var wg sync.WaitGroup wg.Add(len(ports)) for _, port := range ports { go func(port string, wg *sync.WaitGroup) { defer wg.Done() log.Debugf("listening on port %s", port) - if err := listener(port); err != nil { + if err := c.listener(port); err != nil { log.Error(err) return } @@ -26,7 +26,7 @@ func startRelay(ports []string) { wg.Wait() } -func listener(port string) (err error) { +func (c *Croc) listener(port string) (err error) { server, err := net.Listen("tcp", "0.0.0.0:"+port) if err != nil { return errors.Wrap(err, "Error listening on :"+port) @@ -40,7 +40,7 @@ func listener(port string) (err error) { } log.Debugf("client %s connected", connection.RemoteAddr().String()) go func(port string, connection net.Conn) { - errCommunication := clientCommuncation(port, connection) + errCommunication := c.clientCommuncation(port, connection) if errCommunication != nil { log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error()) } @@ -48,7 +48,7 @@ func listener(port string) (err error) { } } -func clientCommuncation(port string, connection net.Conn) (err error) { +func (c *Croc) clientCommuncation(port string, connection net.Conn) (err error) { var con1, con2 net.Conn // get the channel and UUID from the client @@ -67,27 +67,27 @@ func clientCommuncation(port string, connection net.Conn) (err error) { log.Debugf("%s connected with channel %s and uuid %s", connection.RemoteAddr().String(), channel, uuid) // validate channel and UUID - rs.Lock() - if _, ok := rs.channel[channel]; !ok { - rs.Unlock() + c.rs.Lock() + if _, ok := c.rs.channel[channel]; !ok { + c.rs.Unlock() err = errors.Errorf("channel %s does not exist", channel) return } - if uuid != rs.channel[channel].uuids[0] && - uuid != rs.channel[channel].uuids[1] { - rs.Unlock() + if uuid != c.rs.channel[channel].uuids[0] && + uuid != c.rs.channel[channel].uuids[1] { + c.rs.Unlock() err = errors.Errorf("uuid '%s' is invalid", uuid) return } role := 0 - if uuid == rs.channel[channel].uuids[1] { + if uuid == c.rs.channel[channel].uuids[1] { role = 1 } - rs.channel[channel].connection[role] = connection + c.rs.channel[channel].connection[role] = connection - con1 = rs.channel[channel].connection[0] - con2 = rs.channel[channel].connection[1] - rs.Unlock() + con1 = c.rs.channel[channel].connection[0] + con2 = c.rs.channel[channel].connection[1] + c.rs.Unlock() if con1 != nil && con2 != nil { var wg sync.WaitGroup @@ -100,9 +100,9 @@ func clientCommuncation(port string, connection net.Conn) (err error) { // then set transfer ready go func(channel string, wg *sync.WaitGroup) { // set the channels to ready - rs.Lock() - rs.channel[channel].TransferReady = true - rs.Unlock() + c.rs.Lock() + c.rs.channel[channel].TransferReady = true + c.rs.Unlock() wg.Done() }(channel, &wg) wg.Wait() diff --git a/src/server.go b/src/server.go index 362f9eb..6efee9c 100644 --- a/src/server.go +++ b/src/server.go @@ -4,7 +4,6 @@ import ( "crypto/elliptic" "encoding/json" "fmt" - "sync" "time" log "github.com/cihub/seelog" @@ -13,34 +12,21 @@ import ( "github.com/pkg/errors" ) -type relayState struct { - channel map[string]*channelData - sync.RWMutex -} - -var rs relayState - -func init() { - rs.Lock() - rs.channel = make(map[string]*channelData) - rs.Unlock() -} - -func startServer(tcpPorts []string, port string) (err error) { +func (c *Croc) startServer(tcpPorts []string, port string) (err error) { // start cleanup on dangling channels - go channelCleanup() + go c.channelCleanup() // start server gin.SetMode(gin.ReleaseMode) r := gin.New() r.Use(middleWareHandler(), gin.Recovery()) - r.POST("/channel", func(c *gin.Context) { - r, err := func(c *gin.Context) (r response, err error) { - rs.Lock() - defer rs.Unlock() + r.POST("/channel", func(cg *gin.Context) { + r, err := func(cg *gin.Context) (r response, err error) { + c.rs.Lock() + defer c.rs.Unlock() r.Success = true var p payloadChannel - err = c.ShouldBindJSON(&p) + err = cg.ShouldBindJSON(&p) if err != nil { log.Errorf("failed on payload %+v", p) err = errors.Wrap(err, "problem parsing /channel") @@ -48,21 +34,21 @@ func startServer(tcpPorts []string, port string) (err error) { } // determine if channel is invalid - if _, ok := rs.channel[p.Channel]; !ok { + if _, ok := c.rs.channel[p.Channel]; !ok { err = errors.Errorf("channel '%s' does not exist", p.Channel) return } // determine if UUID is invalid for channel - if p.UUID != rs.channel[p.Channel].uuids[0] && - p.UUID != rs.channel[p.Channel].uuids[1] { + if p.UUID != c.rs.channel[p.Channel].uuids[0] && + p.UUID != c.rs.channel[p.Channel].uuids[1] { err = errors.Errorf("uuid '%s' is invalid", p.UUID) return } // check if the action is to close the channel if p.Close { - delete(rs.channel, p.Channel) + delete(c.rs.channel, p.Channel) r.Message = "deleted " + p.Channel return } @@ -74,34 +60,34 @@ func startServer(tcpPorts []string, port string) (err error) { // add a check that the value of key is not enormous // add only if it is a valid key - if _, ok := rs.channel[p.Channel].State[key]; ok { + if _, ok := c.rs.channel[p.Channel].State[key]; ok { assignedKeys = append(assignedKeys, key) - rs.channel[p.Channel].State[key] = p.State[key] + c.rs.channel[p.Channel].State[key] = p.State[key] } } // return the current state - r.Data = rs.channel[p.Channel] + r.Data = c.rs.channel[p.Channel] r.Message = fmt.Sprintf("assigned %d keys: %v", len(assignedKeys), assignedKeys) return - }(c) + }(cg) if err != nil { log.Debugf("bad /channel: %s", err.Error()) r.Message = err.Error() r.Success = false } bR, _ := json.Marshal(r) - c.Data(200, "application/json", bR) + cg.Data(200, "application/json", bR) }) - r.POST("/join", func(c *gin.Context) { - r, err := func(c *gin.Context) (r response, err error) { - rs.Lock() - defer rs.Unlock() + r.POST("/join", func(cg *gin.Context) { + r, err := func(cg *gin.Context) (r response, err error) { + c.rs.Lock() + defer c.rs.Unlock() r.Success = true var p payloadOpen - err = c.ShouldBindJSON(&p) + err = cg.ShouldBindJSON(&p) if err != nil { log.Errorf("failed on payload %+v", p) err = errors.Wrap(err, "problem parsing") @@ -120,57 +106,57 @@ func startServer(tcpPorts []string, port string) (err error) { // find an empty channel p.Channel = "chou" } - if _, ok := rs.channel[p.Channel]; ok { + if _, ok := c.rs.channel[p.Channel]; ok { // channel is not empty - if rs.channel[p.Channel].uuids[p.Role] != "" { + if c.rs.channel[p.Channel].uuids[p.Role] != "" { err = errors.Errorf("channel '%s' already occupied by role %d", p.Channel, p.Role) return } } r.Channel = p.Channel - if _, ok := rs.channel[r.Channel]; !ok { - rs.channel[r.Channel] = newChannelData(r.Channel) + if _, ok := c.rs.channel[r.Channel]; !ok { + c.rs.channel[r.Channel] = newChannelData(r.Channel) } // assign UUID for the role in the channel - rs.channel[r.Channel].uuids[p.Role] = uuid4.New().String() - r.UUID = rs.channel[r.Channel].uuids[p.Role] + c.rs.channel[r.Channel].uuids[p.Role] = uuid4.New().String() + r.UUID = c.rs.channel[r.Channel].uuids[p.Role] log.Debugf("(%s) %s has joined as role %d", r.Channel, r.UUID, p.Role) // if channel is not open, set initial parameters - if !rs.channel[r.Channel].isopen { - rs.channel[r.Channel].isopen = true - rs.channel[r.Channel].Ports = tcpPorts - rs.channel[r.Channel].startTime = time.Now() + if !c.rs.channel[r.Channel].isopen { + c.rs.channel[r.Channel].isopen = true + c.rs.channel[r.Channel].Ports = tcpPorts + c.rs.channel[r.Channel].startTime = time.Now() switch curve := p.Curve; curve { case "p224": - rs.channel[r.Channel].curve = elliptic.P224() + c.rs.channel[r.Channel].curve = elliptic.P224() case "p256": - rs.channel[r.Channel].curve = elliptic.P256() + c.rs.channel[r.Channel].curve = elliptic.P256() case "p384": - rs.channel[r.Channel].curve = elliptic.P384() + c.rs.channel[r.Channel].curve = elliptic.P384() case "p521": - rs.channel[r.Channel].curve = elliptic.P521() + c.rs.channel[r.Channel].curve = elliptic.P521() default: // TODO: // add SIEC p.Curve = "p256" - rs.channel[r.Channel].curve = elliptic.P256() + c.rs.channel[r.Channel].curve = elliptic.P256() } log.Debugf("(%s) using curve '%s'", r.Channel, p.Curve) - rs.channel[r.Channel].State["curve"] = []byte(p.Curve) + c.rs.channel[r.Channel].State["curve"] = []byte(p.Curve) } r.Message = fmt.Sprintf("assigned role %d in channel '%s'", p.Role, r.Channel) return - }(c) + }(cg) if err != nil { log.Debugf("bad /join: %s", err.Error()) r.Message = err.Error() r.Success = false } bR, _ := json.Marshal(r) - c.Data(200, "application/json", bR) + cg.Data(200, "application/json", bR) }) log.Infof("Running at http://0.0.0.0:" + port) err = r.Run(":" + port) @@ -178,32 +164,32 @@ func startServer(tcpPorts []string, port string) (err error) { } func middleWareHandler() gin.HandlerFunc { - return func(c *gin.Context) { + return func(cg *gin.Context) { t := time.Now() // Run next function - c.Next() + cg.Next() // Log request - log.Infof("%v %v %v %s", c.Request.RemoteAddr, c.Request.Method, c.Request.URL, time.Since(t)) + log.Infof("%v %v %v %s", cg.Request.RemoteAddr, cg.Request.Method, cg.Request.URL, time.Since(t)) } } -func channelCleanup() { +func (c *Croc) channelCleanup() { maximumWait := 10 * time.Minute for { - rs.Lock() - keys := make([]string, len(rs.channel)) + c.rs.Lock() + keys := make([]string, len(c.rs.channel)) i := 0 - for key := range rs.channel { + for key := range c.rs.channel { keys[i] = key i++ } for _, key := range keys { - if time.Since(rs.channel[key].startTime) > maximumWait { + if time.Since(c.rs.channel[key].startTime) > maximumWait { log.Debugf("channel %s has exceeded time, deleting", key) - delete(rs.channel, key) + delete(c.rs.channel, key) } } - rs.Unlock() + c.rs.Unlock() time.Sleep(1 * time.Minute) } }