From 6da74e0adcbc9169b9f03b10d00e90a9de1433c4 Mon Sep 17 00:00:00 2001 From: Lilleman Date: Tue, 5 Jan 2021 16:23:18 +0100 Subject: [PATCH] Added endpoint to auth with username and password --- src/db/accounts.go | 7 +++++-- src/handlers/get.go | 2 +- src/handlers/post.go | 30 +++++++++++++++++++++++++++++- src/main.go | 1 + 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/db/accounts.go b/src/db/accounts.go index ef643fe..9160f74 100644 --- a/src/db/accounts.go +++ b/src/db/accounts.go @@ -58,7 +58,7 @@ func (d Db) AccountCreate(input AccountCreateInput) (CreatedAccount, error) { } // AccountGet fetches an account from the database -func (d Db) AccountGet(accountID string, APIKey string) (Account, error) { +func (d Db) AccountGet(accountID string, APIKey string, Name string) (Account, error) { logContext := log.WithFields(log.Fields{ "accountID": accountID, "APIKey": len(APIKey), @@ -75,6 +75,9 @@ func (d Db) AccountGet(accountID string, APIKey string) (Account, error) { } else if APIKey != "" { accountSQL = accountSQL + "\"apiKey\" = $1" searchParam = APIKey + } else if Name != "" { + accountSQL = accountSQL + "name = $1" + searchParam = Name } accountErr := d.DbPool.QueryRow(context.Background(), accountSQL, searchParam).Scan(&account.ID, &account.Created, &account.Name, &account.Password) @@ -114,7 +117,7 @@ func (d Db) AccountGet(accountID string, APIKey string) (Account, error) { func (d Db) RenewalTokenGet(accountID string) (string, error) { logContext := log.WithFields(log.Fields{"accountID": accountID}) - logContext.Debug("Createing new renewal token") + logContext.Debug("Creating new renewal token") newToken := utils.RandString(60) diff --git a/src/handlers/get.go b/src/handlers/get.go index 88174ca..44fc0bb 100644 --- a/src/handlers/get.go +++ b/src/handlers/get.go @@ -19,7 +19,7 @@ func (h Handlers) AccountGet(c *fiber.Ctx) error { return c.Status(403).JSON([]ResJSONError{{Error: authErr.Error()}}) } - account, accountErr := h.Db.AccountGet(accountID, "") + account, accountErr := h.Db.AccountGet(accountID, "", "") if accountErr != nil { return c.Status(500).JSON([]ResJSONError{{Error: accountErr.Error()}}) } diff --git a/src/handlers/post.go b/src/handlers/post.go index 0768cbd..884cff6 100644 --- a/src/handlers/post.go +++ b/src/handlers/post.go @@ -75,7 +75,7 @@ func (h Handlers) AccountAuthAPIKey(c *fiber.Ctx) error { inputAPIKey := string(c.Request().Body()) inputAPIKey = inputAPIKey[1 : len(inputAPIKey)-1] - resolvedAccount, accountErr := h.Db.AccountGet("", inputAPIKey) + resolvedAccount, accountErr := h.Db.AccountGet("", inputAPIKey, "") if accountErr != nil { if accountErr.Error() == "no rows in result set" { return c.Status(403).JSON([]ResJSONError{{Error: "Invalid credentials"}}) @@ -86,3 +86,31 @@ func (h Handlers) AccountAuthAPIKey(c *fiber.Ctx) error { return h.returnTokens(resolvedAccount, c) } + +// AccountAuthPassword auths a name/password pair +func (h Handlers) AccountAuthPassword(c *fiber.Ctx) error { + type AuthInput struct { + Name string `json:"name"` + Password string `json:"password"` + } + + authInput := new(AuthInput) + if err := c.BodyParser(authInput); err != nil { + return c.Status(400).JSON([]ResJSONError{{Error: err.Error()}}) + } + + resolvedAccount, err := h.Db.AccountGet("", "", authInput.Name) + if err != nil { + if err.Error() == "No account found" { + return c.Status(403).JSON([]ResJSONError{{Error: "Invalid name or password"}}) + } + + return c.Status(500).JSON([]ResJSONError{{Error: err.Error()}}) + } + + if utils.CheckPasswordHash(authInput.Password, resolvedAccount.Password) == false { + return c.Status(403).JSON([]ResJSONError{{Error: "Invalid name or password"}}) + } + + return h.returnTokens(resolvedAccount, c) +} diff --git a/src/main.go b/src/main.go index efe6cf1..e30f3d0 100644 --- a/src/main.go +++ b/src/main.go @@ -76,6 +76,7 @@ func main() { app.Get("/account/:accountID", handlers.AccountGet) app.Post("/account", handlers.AccountCreate) app.Post("/auth/api-key", handlers.AccountAuthAPIKey) + app.Post("/auth/password", handlers.AccountAuthPassword) log.WithFields(log.Fields{"WEB_BIND_HOST": os.Getenv("WEB_BIND_HOST")}).Info("Trying to start web server")