From 706cf2f2daf6b0f9a11f12822f78d72f7d8bd19b Mon Sep 17 00:00:00 2001 From: Rahul Rudragoudar Date: Mon, 12 Apr 2021 23:14:42 +0530 Subject: [PATCH] Improve error handling/messages Signed-off-by: Rahul Rudragoudar --- src/main/scala/lc/core/captcha.scala | 97 +++++++++++---------------- src/main/scala/lc/server/Server.scala | 38 ++++++++--- 2 files changed, 69 insertions(+), 66 deletions(-) diff --git a/src/main/scala/lc/core/captcha.scala b/src/main/scala/lc/core/captcha.scala index 15d32a9..34cda2f 100644 --- a/src/main/scala/lc/core/captcha.scala +++ b/src/main/scala/lc/core/captcha.scala @@ -5,29 +5,22 @@ import java.util.UUID import java.io.ByteArrayInputStream import lc.database.Statements import lc.core.CaptchaProviders -import java.sql.Blob - object Captcha { - def getCaptcha(id: Id): Array[Byte] = { - var image: Array[Byte] = null - var blob: Blob = null - try { - val imagePstmt = Statements.tlStmts.get.imagePstmt - imagePstmt.setString(1, id.id) - val rs: ResultSet = imagePstmt.executeQuery() - if (rs.next()) { - blob = rs.getBlob("image") - if (blob != null) { - image = blob.getBytes(1, blob.length().toInt) - } + def getCaptcha(id: Id): ChallengeResult = { + val imagePstmt = Statements.tlStmts.get.imagePstmt + imagePstmt.setString(1, id.id) + val rs: ResultSet = imagePstmt.executeQuery() + if (rs.next()) { + val blob = rs.getBlob("image") + if (blob != null) { + Image(blob.getBytes(1, blob.length().toInt)) + } else { + Error(ErrorMessageEnum.IMG_MISSING.toString) } - image - } catch { - case e: Exception => - println(e) - image + } else { + Error(ErrorMessageEnum.IMG_NOT_FOUND.toString) } } @@ -57,49 +50,41 @@ object Captcha { val allowedLevels = Config.allowedLevels val allowedMedia = Config.allowedMedia - private def validateParam(param: Parameters): Boolean = { - if ( - allowedLevels.contains(param.level) && - allowedMedia.contains(param.media) && - allowedInputType.contains(param.input_type) - ) - return true - else - return false + private def validateParam(param: Parameters): Array[String] = { + var invalid_params = Array[String]() + if (!allowedLevels.contains(param.level)) invalid_params :+= "level" + if (!allowedMedia.contains(param.media)) invalid_params :+= "media" + if (!allowedInputType.contains(param.input_type)) invalid_params :+= "input_type" + + invalid_params } - def getChallenge(param: Parameters): ChallengeResult = { + def getChallenge(param: Parameters): ChallengeResult = { try { val validParam = validateParam(param) - if (validParam) { - val tokenPstmt = Statements.tlStmts.get.tokenPstmt - tokenPstmt.setString(1, param.level) - tokenPstmt.setString(2, param.media) - tokenPstmt.setString(3, param.input_type) - val rs = tokenPstmt.executeQuery() - val tokenOpt = if (rs.next()) { - Some(rs.getInt("token")) - } else { - None - } - val updateAttemptedPstmt = Statements.tlStmts.get.updateAttemptedPstmt - val token = tokenOpt.getOrElse(generateChallenge(param)) - val result = if (token != -1) { - val uuid = getUUID(token) - updateAttemptedPstmt.setString(1, uuid) - updateAttemptedPstmt.executeUpdate() - Id(uuid) - } else { - Error(ErrorMessageEnum.NO_CAPTCHA.toString) - } - result + if (!validParam.isEmpty) + return Error(ErrorMessageEnum.INVALID_PARAM.toString + " => " + validParam.mkString(", ")) + val tokenPstmt = Statements.tlStmts.get.tokenPstmt + tokenPstmt.setString(1, param.level) + tokenPstmt.setString(2, param.media) + tokenPstmt.setString(3, param.input_type) + val rs = tokenPstmt.executeQuery() + val tokenOpt = if (rs.next()) { + Some(rs.getInt("token")) } else { - Error(ErrorMessageEnum.INVALID_PARAM.toString) + None } + val updateAttemptedPstmt = Statements.tlStmts.get.updateAttemptedPstmt + val token = tokenOpt.getOrElse(generateChallenge(param)) + val uuid = getUUID(token) + updateAttemptedPstmt.setString(1, uuid) + updateAttemptedPstmt.executeUpdate() + Id(uuid) } catch { - case e: Exception => - println(e) - Error(ErrorMessageEnum.SMW.toString) + case exception: NoSuchElementException => { + println(exception.getStackTrace) + Error(ErrorMessageEnum.NO_CAPTCHA.toString) + } } } @@ -112,7 +97,7 @@ object Captcha { uuid } - def checkAnswer(answer: Answer): Result = { + def checkAnswer(answer: Answer): ChallengeResult = { val selectPstmt = Statements.tlStmts.get.selectPstmt selectPstmt.setInt(1, Config.captchaExpiryTimeLimit) selectPstmt.setString(2, answer.id) diff --git a/src/main/scala/lc/server/Server.scala b/src/main/scala/lc/server/Server.scala index e1d048a..73dff4b 100644 --- a/src/main/scala/lc/server/Server.scala +++ b/src/main/scala/lc/server/Server.scala @@ -5,7 +5,7 @@ import org.json4s.jackson.JsonMethods.parse import org.json4s.jackson.Serialization.write import lc.core.Captcha import lc.core.ErrorMessageEnum -import lc.core.{Parameters, Id, Answer, Response} +import lc.core.{Parameters, Id, Answer, Response, Error, ChallengeResult, Image} import org.json4s.JsonAST.JValue import com.sun.net.httpserver.{HttpServer, HttpExchange} import java.net.InetSocketAddress @@ -22,15 +22,20 @@ class Server(port: Int) { parse(string) } - private def getPathParameter(ex: HttpExchange): String = { + private def getPathParameter(ex: HttpExchange): Either[String, Error] = { try { val uri = ex.getRequestURI.toString - val param = uri.split("\\?")(1) - param.split("=")(1) + val pathParam = uri.split("\\?")(1) + val param = pathParam.split("=") + if (param(0) == "id") { + Left(param(1)) + } else { + Right(Error(ErrorMessageEnum.INVALID_PARAM.toString + "=> id")) + } } catch { case exception: ArrayIndexOutOfBoundsException => { println(exception.getStackTrace) - throw new Exception(ErrorMessageEnum.INVALID_PARAM.toString) + Right(Error(ErrorMessageEnum.INVALID_PARAM.toString + "=> id")) } } } @@ -54,6 +59,14 @@ class Server(port: Int) { Response(405, write(message).getBytes) } + private def getResponse(response: ChallengeResult): Response = { + response match { + case Image(image) => Response(200, image) + case Error(_) => Response(500, write(response).getBytes) + case _ => Response(200, write(response).getBytes) + } + } + private def makeApiWorker(path: String, f: (String, HttpExchange) => Response): Unit = { server.createContext( path, @@ -84,7 +97,7 @@ class Server(port: Int) { val json = getRequestJson(ex) val param = json.extract[Parameters] val id = Captcha.getChallenge(param) - Response(200, write(id).getBytes) + getResponse(id) } else { getBadRequestError() } @@ -96,9 +109,14 @@ class Server(port: Int) { (method: String, ex: HttpExchange) => { if (method == "GET") { val param = getPathParameter(ex) - val id = Id(param) - val image = Captcha.getCaptcha(id) - Response(200, image) + val result = param match { + case Left(value) => { + val id = Id(value) + Captcha.getCaptcha(id) + } + case Right(value) => value + } + getResponse(result) } else { getBadRequestError() } @@ -112,7 +130,7 @@ class Server(port: Int) { val json = getRequestJson(ex) val answer = json.extract[Answer] val result = Captcha.checkAnswer(answer) - Response(200, write(result).getBytes) + getResponse(result) } else { getBadRequestError() }