default
This commit is contained in:
38
api/admin/auth.go
Normal file
38
api/admin/auth.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
."main.go/type"
|
||||
"main.go/api"
|
||||
|
||||
)
|
||||
|
||||
|
||||
|
||||
// AuthRequired 用于管理员权限控制的中间件
|
||||
func AuthRequired()gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
session, err := api.Store.Get(c.Request, "SNCTFSESSID")
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "Get SNCTFSESSID error"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
user, ok := session.Values["user"].(User)
|
||||
if !ok {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "No session"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if user.Role != 1 {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "Unauthorized access!"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
|
||||
}
|
79
api/admin/challenge.go
Normal file
79
api/admin/challenge.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"main.go/api"
|
||||
db "main.go/database"
|
||||
. "main.go/type"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CrDb 专门用作给数据替换。
|
||||
type CrDb struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Score int `json:"score" binding:"required"`
|
||||
Flag string `json:"flag"` // 暂时一个题只能一个flag
|
||||
Description string `json:"description"`
|
||||
Attachment string `json:"attachment"`
|
||||
Category string `json:"category" binding:"required"`
|
||||
Tags string `json:"tags"`
|
||||
Hints string `json:"hints"`
|
||||
Visible int `json:"visible"`
|
||||
}
|
||||
|
||||
|
||||
//NewChallenge 新建一个题目
|
||||
func NewChallenge(c *gin.Context){
|
||||
var request ChallengeRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
fmt.Println(err)
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Request format wrong!"})
|
||||
return
|
||||
}
|
||||
challenge := &Challenge{
|
||||
Name: request.Name,
|
||||
Score: request.Score,
|
||||
Flag: request.Flag,
|
||||
Description: request.Description,
|
||||
Attachment: request.Attachment,
|
||||
Category: request.Category,
|
||||
Tags: request.Tags,
|
||||
Hints: request.Hints,
|
||||
Visible: request.Visible,
|
||||
}
|
||||
if err := addChallenge(challenge); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Add challenge failure!"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "msg": "Add challenge success!"})
|
||||
}
|
||||
|
||||
//addChallenge 添加题目内容
|
||||
func addChallenge(c *Challenge) error {
|
||||
api.Link()
|
||||
DB := db.DBsnctf
|
||||
// 使用逗号分隔字符串
|
||||
attachmentString := strings.Join(c.Attachment, ",")
|
||||
hintString := strings.Join(c.Hints, ",")
|
||||
crdb := &CrDb{
|
||||
Name: c.Name,
|
||||
Score: c.Score,
|
||||
Flag: c.Flag,
|
||||
Description: c.Description,
|
||||
Attachment: attachmentString,
|
||||
Category: c.Category,
|
||||
Tags: c.Tags,
|
||||
Hints: hintString,
|
||||
Visible: c.Visible,
|
||||
}
|
||||
//插入数据
|
||||
err := DB.Table("challenge").Create(crdb).Error
|
||||
//command := "INSERT INTO challenge (name,score,flag,description,attachment,category,tags,hints,visible) VALUES (?,?,?,?,?,?,?,?,?);"
|
||||
//err := DB.Debug().Raw(command,c.Name, c.Score, c.Flag, c.Description, attachmentString, c.Category, c.Tags, hintString, c.Visible).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
58
api/category.go
Normal file
58
api/category.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
db "main.go/database"
|
||||
|
||||
)
|
||||
|
||||
|
||||
// GetCategories 获取所有题目分类。
|
||||
func GetCategories(c *gin.Context) {
|
||||
var categories []string
|
||||
|
||||
if err := getAllCategories(&categories); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get categories failure!"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "data": categories})
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
//getAllCategories 操作数据库所有题目分类
|
||||
func getAllCategories(categories *[]string) error {
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
rows,err:= DB.Raw("select category from category").Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var category string
|
||||
if err := rows.Scan(&category); err != nil {
|
||||
return err
|
||||
}
|
||||
*categories = append(*categories, category)
|
||||
}
|
||||
return rows.Err()
|
||||
|
||||
}
|
||||
|
||||
// CheckCategory 检查类别是否正确
|
||||
func CheckCategory(c string) bool {
|
||||
var categories []string
|
||||
if err := getAllCategories(&categories); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, category := range categories {
|
||||
if category == c {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
84
api/challenge.go
Normal file
84
api/challenge.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
db "main.go/database"
|
||||
. "main.go/type"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//GetAllChallenges 获取全部题目
|
||||
|
||||
func GetAllChallenges(c * gin.Context){
|
||||
var challenges []ChallengeResponse
|
||||
|
||||
if err := getAllChallenges(c, &challenges); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get all challenges failure!"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "data": challenges})
|
||||
|
||||
}
|
||||
|
||||
//GetChallengesByCategory 获取某个分类下的题目
|
||||
|
||||
func GetChallengesByCategory(c *gin.Context){
|
||||
category := c.Param("category")
|
||||
if matched := CheckCategory(category); !matched {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Wrong category!"})
|
||||
return
|
||||
}
|
||||
var challenges []ChallengeResponse
|
||||
if err := getAllChallenges(c, &challenges); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get all challenges failure!"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"code": 200, "data": challenges})
|
||||
|
||||
}
|
||||
|
||||
|
||||
// getAllChallenges 操作数据库获取所有题目。
|
||||
func getAllChallenges(c *gin.Context, challenges *[]ChallengeResponse) error {
|
||||
var attachmentString, hints string
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
rows,err := DB.Raw("SELECT id, name, score, description, attachment, category, tags, hints FROM challenge WHERE visible=1;").Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next(){
|
||||
var challenge ChallengeResponse
|
||||
err = rows.Scan(&challenge.ID, &challenge.Name, &challenge.Score, &challenge.Description, &attachmentString, &challenge.Category, &challenge.Tags, &hints)
|
||||
fmt.Println(err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 解析为切片
|
||||
challenge.Attachment = strings.Split(attachmentString, ",")
|
||||
challenge.Hints = strings.Split(hints, ",")
|
||||
|
||||
solverCount, err := getSolverCount(challenge.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
challenge.SolverCount = solverCount
|
||||
session, err := Store.Get(c.Request, "SNCTFSESSID")
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "Get SNCTFSESSID error"})
|
||||
return err
|
||||
}
|
||||
user, ok := session.Values["user"].(User)
|
||||
if !ok {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "No session"})
|
||||
return errors.New("no session")
|
||||
}
|
||||
challenge.IsSolved = hasAlreadySolved(user.ID, challenge.ID)
|
||||
*challenges = append(*challenges, challenge)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
39
api/notice.go
Normal file
39
api/notice.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
. "main.go/type"
|
||||
db "main.go/database"
|
||||
)
|
||||
|
||||
// GetAllNotices 获取所有的公告
|
||||
func GetAllNotices(c *gin.Context) {
|
||||
var notices []Notice
|
||||
if err := getAllNotices(¬ices); err != nil {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "Get all notices failure!"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "data": notices})
|
||||
}
|
||||
|
||||
|
||||
func getAllNotices(notices *[]Notice) error {
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
command := "SELECT id, title, content, created_at FROM notice ORDER BY created_at ASC;"
|
||||
rows, err := DB.Raw(command).Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var b Notice
|
||||
err = rows.Scan(&b.ID, &b.Title, &b.Content, &b.CreatedAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*notices = append(*notices, b)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
78
api/score.go
Normal file
78
api/score.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
db "main.go/database"
|
||||
"main.go/tools"
|
||||
. "main.go/type"
|
||||
)
|
||||
|
||||
//GetAllScore 获取所有的积分,按照积分降序排列
|
||||
func GetAllScore(c *gin.Context){
|
||||
var s ScoreResponse
|
||||
var scores []ScoreResponse
|
||||
Link()
|
||||
DB :=db.DBsnctf
|
||||
//rows,err := DB.Raw("SELECT s.id, s.username, s.score FROM score AS s, user AS u WHERE u.hidden = 0 AND s.username = u.username ORDER BY s.score DESC;").Scan(&s).Rows()
|
||||
rows,err := DB.Debug().Raw("SELECT s.id, s.username, s.score FROM score AS s, user AS u WHERE u.hidden = 0 AND s.username = u.username ORDER BY s.score DESC;").Rows()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get all score error!"})
|
||||
return
|
||||
}
|
||||
// 循环读取数据
|
||||
for rows.Next() {
|
||||
rows.Scan(&s.ID, &s.Username, &s.Score)
|
||||
scores = append(scores, s)
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "data": scores})
|
||||
}
|
||||
|
||||
//GetScoreByUserId 获取用户分数
|
||||
func GetScoreByUserId(c *gin.Context) {
|
||||
var score int
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
//获取用户id
|
||||
id := c.Params.ByName("id")
|
||||
if id == "" {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Need id!"})
|
||||
return
|
||||
}
|
||||
//检查id是否合法
|
||||
if !tools.CheckID(id) {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "ID format error!"})
|
||||
return
|
||||
}
|
||||
//查询用户信息
|
||||
err := DB.Raw("SELECT s.score FROM score AS s, user AS u WHERE u.id = ? AND u.hidden = 0 AND u.username = s.username LIMIT 1;", id).Scan(&score).Error
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get info error!"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"code": 200, "data": score})
|
||||
}
|
||||
|
||||
|
||||
//GetSelfScoreAndRank 获取当前登录用户的分数和排名
|
||||
func GetSelfScoreAndRank(c *gin.Context){
|
||||
var scoreAndRank ScoreRankResponse
|
||||
DB := db.DBsnctf
|
||||
session, err := Store.Get(c.Request, "SNCTFSESSID")
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "Get SNCTFSESSID error"})
|
||||
return
|
||||
}
|
||||
user, ok := session.Values["user"].(User)
|
||||
if !ok {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "No session"})
|
||||
return
|
||||
}
|
||||
err = DB.Raw("SELECT score, (SELECT count(DISTINCT score) FROM score WHERE score>=s.score) AS rank FROM score AS s,user AS u WHERE u.id = ? AND u.username = s.username ORDER BY score DESC LIMIT 1;",user.ID).Scan(&scoreAndRank).Error
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "Get info error!"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"code": 200, "data": scoreAndRank})
|
||||
|
||||
}
|
27
api/session.go
Normal file
27
api/session.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
. "main.go/type"
|
||||
)
|
||||
|
||||
|
||||
// sessions 存储于文件系统
|
||||
var Store *sessions.FilesystemStore
|
||||
|
||||
func init() {
|
||||
Store = sessions.NewFilesystemStore("./sessions", securecookie.GenerateRandomKey(32))
|
||||
|
||||
Store.Options = &sessions.Options{
|
||||
Domain: "",
|
||||
Path: "/",
|
||||
MaxAge: 24 * 60 * 60, // 1 day
|
||||
// SameSite: http.SameSiteNoneMode,
|
||||
Secure: false,
|
||||
HttpOnly: false,
|
||||
}
|
||||
|
||||
gob.Register(User{})
|
||||
}
|
153
api/solve.go
Normal file
153
api/solve.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
db "main.go/database"
|
||||
."main.go/type"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
//GetAllSolves 获取所有解题记录
|
||||
func GetAllSolves(c *gin.Context){
|
||||
var solves []SolveResponse
|
||||
|
||||
if err := getAllSolves(&solves); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get all solves failure!"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "data": solves})
|
||||
}
|
||||
//GetSolvesByCid 获取某个用户的所有解题记录
|
||||
func GetSolvesByCid(c *gin.Context){
|
||||
var solves []SolveResponse
|
||||
|
||||
cid, err := strconv.ParseInt(c.Param("cid"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get solves failure!"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := getSolvesByCid(&solves, int(cid)); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get solves failure!"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "data": solves})
|
||||
}
|
||||
|
||||
// GetSolvesByUid 根据用户id获取正确的flag提交记录。
|
||||
func GetSolvesByUid(c *gin.Context) {
|
||||
uid, err := strconv.ParseInt(c.Param("uid"), 10, 64)
|
||||
if err != nil {
|
||||
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Wrong uid!"})
|
||||
return
|
||||
}
|
||||
|
||||
if uid == 1 {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Not allowed!"})
|
||||
return
|
||||
}
|
||||
|
||||
var solves []SolveResponse
|
||||
if err := getSolvesByUid(&solves, int(uid)); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get specified solves failure!"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "data": solves})
|
||||
}
|
||||
|
||||
// GetSelfSolves 获取当前用户的所有解题记录
|
||||
func GetSelfSolves(c *gin.Context){
|
||||
var solves []SolveResponse
|
||||
|
||||
session, err := Store.Get(c.Request, "SNCTFSESSID")
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "Get SNCTFSESSID error"})
|
||||
return
|
||||
}
|
||||
user, ok := session.Values["user"].(User)
|
||||
if !ok {
|
||||
c.JSON(200, gin.H{"code": 400, "msg": "No session"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := getSolvesByUid(&solves, user.ID); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get self solves failure!"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "data": solves})
|
||||
}
|
||||
|
||||
// getSolverCount 操作数据库获取指定id题目的解出人数。
|
||||
func getSolverCount(id int) (count int, err error) {
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
//SELECT COUNT(*) FROM solve WHERE cid = ?;
|
||||
err = DB.Table("solve").Select("COUNT(*)").Where("cid = ?", id).Scan(&count).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// getAllSolves 操作数据库获取所有正确的提交记录,按提交时间从早到晚排序。
|
||||
func getAllSolves(solves *[]SolveResponse) error {
|
||||
DB :=db.DBsnctf
|
||||
rows,err := DB.Raw("SELECT s.id, s.uid, s.cid, u.username, c.name, s.submitted_at, c.score FROM solve AS s, user AS u, challenge AS c WHERE u.hidden=0 AND s.uid=u.id AND s.cid=c.id ORDER BY s.submitted_at ASC;").Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var s SolveResponse
|
||||
err = rows.Scan(&s.ID, &s.Uid, &s.Cid, &s.Username, &s.ChallengeName, &s.SubmittedAt, &s.Score)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*solves = append(*solves, s)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// getSolvesByUid 操作数据库根据用户id获取正确的flag提交记录,按提交时间从早到晚排序。
|
||||
func getSolvesByUid(solves *[]SolveResponse, uid int) error {
|
||||
DB :=db.DBsnctf
|
||||
rows,err := DB.Debug().Raw("SELECT s.id, s.uid, s.cid, u.username, c.name, s.submitted_at, c.score FROM solve AS s, user AS u, challenge AS c WHERE u.hidden=0 AND s.uid=? AND u.id=s.uid AND c.id=s.cid ORDER BY s.submitted_at ASC;",uid).Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var s SolveResponse
|
||||
err = rows.Scan(&s.ID, &s.Uid, &s.Cid, &s.Username, &s.ChallengeName, &s.SubmittedAt, &s.Score)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*solves = append(*solves, s)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
// getSolvesByCid 操作数据库根据题目id获取正确的提交记录,按提交时间从早到晚排序。
|
||||
func getSolvesByCid(solves *[]SolveResponse, cid int) error {
|
||||
DB :=db.DBsnctf
|
||||
rows,err := DB.Raw("SELECT s.id, s.uid, s.cid, u.username, c.name, s.submitted_at, c.score FROM solve AS s, user AS u, challenge AS c WHERE u.hidden=0 AND s.cid=? AND u.id=s.uid AND c.id=s.cid ORDER BY s.submitted_at ASC;",cid).Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var s SolveResponse
|
||||
err = rows.Scan(&s.ID, &s.Uid, &s.Cid, &s.Username, &s.ChallengeName, &s.SubmittedAt, &s.Score)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*solves = append(*solves, s)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
211
api/submissions.go
Normal file
211
api/submissions.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
db "main.go/database"
|
||||
"main.go/tools"
|
||||
. "main.go/type"
|
||||
"math"
|
||||
)
|
||||
|
||||
// hasAlreadySolved 检查某道题是否已经被某用户解出。
|
||||
func hasAlreadySolved(uid int, cid int) (exists bool) {
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
err := DB.Raw("SELECT EXISTS(SELECT 1 FROM solve WHERE uid=? AND cid=?)",uid,cid).Scan(&exists).Error
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return exists
|
||||
}
|
||||
// isChallengeExisted 检查数据库中是否存在某个题目。
|
||||
func isChallengeExisted(id int) (exists bool) {
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
command := "SELECT EXISTS(SELECT 1 FROM challenge WHERE id = ?);"
|
||||
err := DB.Raw(command,id).Scan(&exists).Error
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return exists
|
||||
}
|
||||
|
||||
//addSubmission 添加一个提交记录。
|
||||
func addSubmission(s *Submission) error {
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
err := DB.Table("submission").Create(&s).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// getFlag 根据题目id获取该题的flag
|
||||
func getFlag(id int) (flag string, err error) {
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
err = DB.Table("challenge").Select("flag").Where("id = ?",id).Find(&flag).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return flag, nil
|
||||
}
|
||||
|
||||
// addSolve 操作数据库加入一条正确的flag提交记录。
|
||||
func addSolve(s *Solve) error {
|
||||
DB := db.DBsnctf
|
||||
err := DB.Table("solve").Create(&s).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// addUserScore 操作数据库为指定用户增加某题的分数。
|
||||
func addUserScore(username string, cid int) error {
|
||||
var newScore int
|
||||
DB := db.DBsnctf
|
||||
err := DB.Table("challenge").Select("score").Where("id = ?", cid).Find(&newScore).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
command := "UPDATE score SET score=score+? WHERE username=?"
|
||||
err = DB.Exec(command, newScore, username).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// updateUserScores 操作数据库更新解出用户的分数。
|
||||
func updateUserScores(reducedScore, cid int) error {
|
||||
|
||||
DB:=db.DBsnctf
|
||||
command := "UPDATE score SET score=score-? WHERE EXISTS(SELECT 1 FROM user,solve WHERE user.id=solve.uid AND score.username=user.username AND solve.cid=?);"
|
||||
err := DB.Exec(command, reducedScore, cid).Error
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// editChallengeScore 操作数据库修改指定题目增的动态分数。
|
||||
func editChallengeScore(cid int) (reducedScore int, err error) {
|
||||
DB := db.DBsnctf
|
||||
var currentScore int
|
||||
err = DB.Table("challenge").Select("score").Where("id = ?", cid).Find(¤tScore).Error
|
||||
command := "SELECT score FROM challenge WHERE id=? LIMIT 1;"
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
solverCount, err := getSolverCount(cid)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// According to https://github.com/o-o-overflow/scoring-playground
|
||||
newScore := int(100 + (1000-100)/(1.0+float64(solverCount)*0.04*math.Log(float64(solverCount))))
|
||||
reducedScore = currentScore - newScore
|
||||
|
||||
command = "UPDATE challenge SET score=? WHERE id=?;"
|
||||
err = DB.Exec(command, newScore, cid).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return reducedScore, nil
|
||||
}
|
||||
|
||||
// SubmitFlag 提交flag
|
||||
func SubmitFlag(c *gin.Context) {
|
||||
var request SubmissionRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Request format wrong!"})
|
||||
return
|
||||
}
|
||||
|
||||
session,err := Store.Get(c.Request, "SNCTFSESSID")
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "GET SNCTFSESSID error!"})
|
||||
return
|
||||
}
|
||||
user,ok := session.Values["user"].(User)
|
||||
if !ok {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "GET user error!"})
|
||||
return
|
||||
}
|
||||
// 检查题目是否存在
|
||||
if !isChallengeExisted(request.Cid) {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Challenge not existed!"})
|
||||
return
|
||||
}
|
||||
// Submission记录
|
||||
solvedTime := tools.Timestamp()
|
||||
submission := &Submission{
|
||||
UserID: user.ID,
|
||||
ChallengeID: request.Cid,
|
||||
Flag: request.Flag,
|
||||
IP: c.ClientIP(),
|
||||
Time: solvedTime,
|
||||
}
|
||||
err = addSubmission(submission)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Record submission failure!"})
|
||||
return
|
||||
}
|
||||
// 是否已经解出该题
|
||||
if hasAlreadySolved(user.ID, request.Cid) {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Already solved!"})
|
||||
return
|
||||
}
|
||||
// 获取flag
|
||||
flag, err := getFlag(request.Cid)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get flag failure!"})
|
||||
return
|
||||
}
|
||||
// 检查flag是否正确
|
||||
if flag != request.Flag {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Wrong flag!"})
|
||||
return
|
||||
}else {
|
||||
// Solve记录
|
||||
solve := &Solve{
|
||||
UserID: user.ID,
|
||||
ChallengeID: request.Cid,
|
||||
Time: solvedTime,
|
||||
}
|
||||
err = addSolve(solve)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Record solve failure!"})
|
||||
return
|
||||
}
|
||||
//加分
|
||||
err = addUserScore(user.Username, request.Cid)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Add score failure!"})
|
||||
return
|
||||
}
|
||||
// 题目动态分数
|
||||
|
||||
reducedScore, err := editChallengeScore(request.Cid)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Edit challenge score failure!"})
|
||||
return
|
||||
}
|
||||
//更新所有用户分数
|
||||
err = updateUserScores(reducedScore, request.Cid)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Update user scores failure!"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"code": 200, "msg": "Correct flag!"})
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
}
|
296
api/user.go
Normal file
296
api/user.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "gorm.io/driver/sqlite"
|
||||
db "main.go/database"
|
||||
"main.go/tools"
|
||||
. "main.go/type"
|
||||
"regexp"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
//连接数据库
|
||||
func Link() {
|
||||
err := db.Inimysql()
|
||||
if err != nil{
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
func Login(c * gin.Context) {
|
||||
var request LoginRequest
|
||||
var user User
|
||||
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
DB.AutoMigrate(&User{})
|
||||
//用ShouldBindJSON解析绑定传入的Json数据。
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Bind json error!"})
|
||||
return
|
||||
}
|
||||
//code 0用户名不存在
|
||||
//code 1用户名或密码错误
|
||||
|
||||
err := DB.Take(&user,"username = ?",request.Username).Error
|
||||
if err != nil {
|
||||
c.JSON(200,gin.H{
|
||||
"Error":true,
|
||||
"Msg":"登陆失败",
|
||||
"Code": 0,
|
||||
})
|
||||
return
|
||||
}
|
||||
//判断md5值与数据库内容是否相同
|
||||
if user.Password != tools.MD5(request.Password){
|
||||
c.JSON(200,gin.H{
|
||||
"Error":true,
|
||||
"Msg":"用户名或密码错误",
|
||||
"Code": 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
//身份认证结束
|
||||
//session
|
||||
// 设置session
|
||||
session, _ := Store.Get(c.Request, "SNCTFSESSID")
|
||||
user.Password = ""
|
||||
// 根据remember值设置session有效期
|
||||
if request.Remember {
|
||||
session.Options.MaxAge = 7 * 24 * 60 * 60 // 7 days
|
||||
} else {
|
||||
session.Options.MaxAge = 24 * 60 * 60 // 1 day
|
||||
}
|
||||
//保存session
|
||||
session.Values["user"] = user
|
||||
|
||||
err = session.Save(c.Request, c.Writer)
|
||||
fmt.Println(err)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Save SNCTFSESSID error"})
|
||||
return
|
||||
}
|
||||
//登录成功
|
||||
c.JSON(200, gin.H{"code": 200, "username": user.Username, "role": user.Role, "msg": "Login success!"})
|
||||
|
||||
}
|
||||
|
||||
|
||||
func Register(c *gin.Context) {
|
||||
var request RegisterRequest
|
||||
var user User
|
||||
var score ScoreResponse
|
||||
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
err := DB.AutoMigrate(&User{})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
//用ShouldBindJSON解析绑定传入的Json数据。
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Bind json error!"})
|
||||
return
|
||||
}
|
||||
//限制传入用户名为中文、数字、大小写字母下划线和横杠,1到10位
|
||||
if !checkUsername(request.Username) {
|
||||
c.JSON(200, gin.H{
|
||||
"Error": true,
|
||||
"Msg": "用户名不符合规范",
|
||||
"Code": 2,
|
||||
})
|
||||
return
|
||||
}
|
||||
//限制密码长度6到20位
|
||||
if !checkPassword(request.Password) {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Password format error!"})
|
||||
return
|
||||
}
|
||||
//限制传入邮箱符合格式
|
||||
if !checkEmail(request.Email) {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Email format error!"})
|
||||
return
|
||||
}
|
||||
//判断用户名是否已被使用
|
||||
if isNameExisted(user, request.Username) {
|
||||
c.JSON(200, gin.H{"code": 1000, "msg": "Username has already been used!"})
|
||||
return
|
||||
}
|
||||
//判断邮箱是否已被使用
|
||||
if isEmailExisted(user, request.Email) {
|
||||
c.JSON(200, gin.H{"code": 1001, "msg": "Email has already been used!"})
|
||||
return
|
||||
}
|
||||
user.Token = tools.Token()
|
||||
user.Username = request.Username
|
||||
user.Password = tools.MD5(request.Password)
|
||||
user.Email = request.Email
|
||||
user.Created = tools.Timestamp()
|
||||
score.Score = 0
|
||||
score.Username =request.Username
|
||||
//创建数据
|
||||
err1 := DB.Table("user").Create(&user).Error
|
||||
err2 := DB.Table("score").Create(&score).Error
|
||||
if err1 != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Register error!"})
|
||||
return
|
||||
}
|
||||
if err2 != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Register error!"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": "注册成功",
|
||||
})
|
||||
|
||||
|
||||
}
|
||||
|
||||
// Logout 实现用户注销登陆
|
||||
func Logout(c *gin.Context) {
|
||||
var user User
|
||||
|
||||
session, err := Store.Get(c.Request, "SNCTFSESSID")
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get CTFGOSESSID error"})
|
||||
return
|
||||
}
|
||||
user, ok := session.Values["user"].(User)
|
||||
if !ok {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "No session"})
|
||||
return
|
||||
}
|
||||
|
||||
session.Options.MaxAge = -1
|
||||
err = session.Save(c.Request, c.Writer)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Save CTFGOSESSID error"})
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Sprintf("[%s] logout success!", user.Username)
|
||||
c.JSON(200, gin.H{"code": 200, "msg": "Logout success!"})
|
||||
}
|
||||
//GetInfoByUserId 获取用户信息
|
||||
func GetInfoByUserId(c *gin.Context) {
|
||||
var info PublicInfoResponse
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
//获取用户id
|
||||
id := c.Params.ByName("id")
|
||||
if id == "" {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Need id!"})
|
||||
return
|
||||
}
|
||||
//检查id是否合法
|
||||
|
||||
|
||||
if !tools.CheckID(id) {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "ID format error!"})
|
||||
return
|
||||
}
|
||||
|
||||
err := DB.Debug().Raw("SELECT username,affiliation,country,team_id FROM user WHERE id = ? LIMIT 1", id).Scan(&info).Error
|
||||
//err := DB.Where("id = ?", id).First(user).Error
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get info error!"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"code": 200, "data": info})
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
// checkUsername 验证用户名是否符合中文数字字母下划线横杠,长度1到10位,返回true或false
|
||||
func checkUsername(username string) bool {
|
||||
if !(utf8.RuneCountInString(username) > 0) || !(utf8.RuneCountInString(username) < 11) {
|
||||
return false
|
||||
}
|
||||
pattern := `^[-\w\p{Han}]+$`
|
||||
reg := regexp.MustCompile(pattern)
|
||||
return reg.MatchString(username)
|
||||
}
|
||||
|
||||
// checkEmail 验证是否符合邮箱格式,返回true或false
|
||||
func checkEmail(email string) bool {
|
||||
pattern := `^\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*$`
|
||||
reg := regexp.MustCompile(pattern)
|
||||
return reg.MatchString(email)
|
||||
}
|
||||
|
||||
|
||||
// checkPassword 验证密码是否符合长度6到20位,返回true或false
|
||||
func checkPassword(password string) bool {
|
||||
if !(utf8.RuneCountInString(password) > 5) || !(utf8.RuneCountInString(password) < 21) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
// isNameExisted 判断用户名是否已经被占用,被占用返回true,未被占用则返回false
|
||||
func isNameExisted(user User, username string) bool {
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
err := DB.First(&user,"Username = ?",username).Error
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
}
|
||||
// isNameExisted 判断邮箱是否已经被占用,被占用返回true,未被占用则返回false
|
||||
func isEmailExisted(user User, email string) bool {
|
||||
Link()
|
||||
DB := db.DBsnctf
|
||||
|
||||
err := DB.First(&user,"Email = ?",email).Error
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Session 获取当前用户session信息
|
||||
func Session(c *gin.Context) {
|
||||
var user User
|
||||
|
||||
// 默认在此之前已经通过了中间件的session权限验证
|
||||
session, _ := Store.Get(c.Request, "SNCTFSESSID")
|
||||
user = session.Values["user"].(User)
|
||||
|
||||
c.JSON(200, gin.H{"code": 200, "data": user})
|
||||
}
|
||||
|
||||
|
||||
|
||||
//下面是身份认证用 AUTH部分
|
||||
// AuthRequired 用于普通用户权限控制的中间件
|
||||
func AuthRequired()gin.HandlerFunc{
|
||||
return func(c *gin.Context) {
|
||||
session, err := Store.Get(c.Request, "SNCTFSESSID")
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Get SNCTFSESSID error"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
user, ok := session.Values["user"].(User)
|
||||
if !ok {
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "No session"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if user.Role != 0&&user.Role!=1{
|
||||
c.JSON(400, gin.H{"code": 400, "msg": "Permission denied"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user