croc/src/server.go

185 lines
4.6 KiB
Go

package croc
import (
"crypto/elliptic"
"encoding/json"
"fmt"
"sync"
"time"
log "github.com/cihub/seelog"
"github.com/frankenbeanies/uuid4"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
)
// relayState is the relay's in-memory registry of channels.
//
// The embedded RWMutex must be held whenever channel (or any channelData
// reachable through it) is read or written.
type relayState struct {
	sync.RWMutex

	// channel maps a channel name to its data.
	channel map[string]*channelData
}
// rs is the package-wide relay state shared by all HTTP handlers.
var rs relayState

// init allocates the channel registry so handlers can write to it without
// hitting a nil map.
func init() {
	rs.Lock()
	defer rs.Unlock()
	rs.channel = map[string]*channelData{}
}
// startServer runs the relay's HTTP API on the given port and blocks until
// the server stops, returning the error reported by gin's Run.
//
// Two JSON endpoints are registered:
//   - POST /join assigns the caller a fresh UUID for a role (0 or 1) in a
//     channel, creating and opening the channel when necessary.
//   - POST /channel lets a joined party update the channel state, read it
//     back, or close the channel.
//
// tcpPorts is recorded on every channel at the moment it is opened.
// Request-level failures are reported with HTTP 200 and Success=false in the
// response body; only a response that cannot be encoded yields HTTP 500.
func startServer(tcpPorts []string, port string) (err error) {
	gin.SetMode(gin.ReleaseMode)
	router := gin.New()
	router.Use(middleWareHandler(), gin.Recovery())

	router.POST("/channel", func(c *gin.Context) {
		var p payloadChannel
		// Read the body before relayRespond takes the lock so a slow client
		// cannot stall every other request.
		bindErr := c.ShouldBindJSON(&p)
		relayRespond(c, "/channel", func() (response, error) {
			if bindErr != nil {
				log.Errorf("failed on payload %+v", p)
				return response{}, errors.Wrap(bindErr, "problem parsing /channel")
			}
			return relayUpdateChannel(p)
		})
	})

	router.POST("/join", func(c *gin.Context) {
		var p payloadOpen
		bindErr := c.ShouldBindJSON(&p)
		relayRespond(c, "/join", func() (response, error) {
			if bindErr != nil {
				log.Errorf("failed on payload %+v", p)
				return response{}, errors.Wrap(bindErr, "problem parsing")
			}
			return relayJoinChannel(p, tcpPorts)
		})
	})

	// Use a constant format string: port is caller-supplied and must not be
	// interpreted as formatting directives.
	log.Infof("Running at http://0.0.0.0:%s", port)
	return router.Run(":" + port)
}

// relayRespond runs handle while holding the relay lock, encodes its result
// and writes it to the client.
//
// The response is marshaled before the lock is released because it may point
// at live channel data that other requests mutate. If handle fails, the error
// is logged and sent as {Success: false, Message: <error>}.
func relayRespond(c *gin.Context, route string, handle func() (response, error)) {
	body, err := func() ([]byte, error) {
		rs.Lock()
		defer rs.Unlock()
		r, err := handle()
		if err != nil {
			log.Debugf("bad %s: %s", route, err.Error())
			r = response{Message: err.Error()}
		}
		return json.Marshal(r)
	}()
	if err != nil {
		log.Errorf("encoding %s response: %s", route, err.Error())
		c.AbortWithStatus(500)
		return
	}
	c.Data(200, "application/json", body)
}

// relayUpdateChannel applies a /channel request: it authenticates the caller,
// then either closes the channel or stores the submitted state values and
// returns the channel data. The caller must hold the rs lock.
func relayUpdateChannel(p payloadChannel) (response, error) {
	ch, ok := rs.channel[p.Channel]
	if !ok {
		return response{}, errors.Errorf("channel '%s' does not exist", p.Channel)
	}

	// A role slot that nobody has joined yet holds the empty string, so an
	// empty UUID must never be accepted as a credential.
	if p.UUID == "" || (p.UUID != ch.uuids[0] && p.UUID != ch.uuids[1]) {
		return response{}, errors.Errorf("uuid '%s' is invalid", p.UUID)
	}

	if p.Close {
		delete(rs.channel, p.Channel)
		return response{Success: true, Message: "deleted " + p.Channel}, nil
	}

	// Only keys that already exist in the channel state may be assigned.
	// TODO: add a check that the value of key is not enormous.
	assignedKeys := []string{}
	for key, value := range p.State {
		if _, known := ch.State[key]; known {
			assignedKeys = append(assignedKeys, key)
			ch.State[key] = value
		}
	}

	r := response{Success: true}
	r.Data = ch
	r.Message = fmt.Sprintf("assigned %d keys: %v", len(assignedKeys), assignedKeys)
	return r, nil
}

// relayJoinChannel applies a /join request: it validates the role, creates
// the channel if it does not exist, assigns a new UUID to the role and, when
// the channel is not yet open, selects its curve and records tcpPorts.
// The caller must hold the rs lock.
func relayJoinChannel(p payloadOpen, tcpPorts []string) (response, error) {
	// The role doubles as the index into the channel's uuids.
	if p.Role != 0 && p.Role != 1 {
		return response{}, errors.Errorf("no such role of %d", p.Role)
	}

	if p.Channel == "" {
		// TODO: find an empty channel
		p.Channel = "chou"
	}

	ch, exists := rs.channel[p.Channel]
	if exists && ch.uuids[p.Role] != "" {
		return response{}, errors.Errorf("channel '%s' already occupied by role %d", p.Channel, p.Role)
	}
	if !exists {
		ch = newChannelData(p.Channel)
		rs.channel[p.Channel] = ch
	}

	r := response{Success: true}
	r.Channel = p.Channel

	// Assign a UUID for the role in the channel.
	ch.uuids[p.Role] = uuid4.New().String()
	r.UUID = ch.uuids[p.Role]
	log.Debugf("(%s) %s has joined as role %d", r.Channel, r.UUID, p.Role)

	// The first party to join fixes the curve for the channel.
	if !ch.isopen {
		switch p.Curve {
		case "p224":
			ch.curve = elliptic.P224()
		case "p256":
			ch.curve = elliptic.P256()
		case "p384":
			ch.curve = elliptic.P384()
		case "p521":
			ch.curve = elliptic.P521()
		default:
			// TODO: add SIEC
			// Unknown or missing curve names fall back to P-256.
			p.Curve = "p256"
			ch.curve = elliptic.P256()
		}
		log.Debugf("(%s) using curve '%s'", r.Channel, p.Curve)
		ch.State["curve"] = []byte(p.Curve)
		ch.Ports = tcpPorts
		ch.isopen = true
	}

	r.Message = fmt.Sprintf("assigned role %d in channel '%s'", p.Role, r.Channel)
	return r, nil
}
// middleWareHandler returns gin middleware that logs one line per request:
// remote address, HTTP method, URL and the time taken by the remaining
// handlers in the chain.
func middleWareHandler() gin.HandlerFunc {
	return func(ctx *gin.Context) {
		started := time.Now()

		// Hand control to the rest of the chain; the log line is written
		// once it returns so the elapsed time covers the whole request.
		ctx.Next()

		req := ctx.Request
		elapsed := time.Since(started)
		log.Infof("%v %v %v %s", req.RemoteAddr, req.Method, req.URL, elapsed)
	}
}