348 lines
8.7 KiB
Go
348 lines
8.7 KiB
Go
package api
|
||
|
||
import (
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
_ "gorm.io/driver/sqlite"
|
||
db "main.go/database"
|
||
"main.go/tools"
|
||
. "main.go/type"
|
||
"regexp"
|
||
"unicode/utf8"
|
||
)
|
||
|
||
//连接数据库
|
||
func Link() {
|
||
err := db.Inimysql()
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
}
|
||
|
||
func Login(c *gin.Context) {
|
||
var request LoginRequest
|
||
var user User
|
||
|
||
Link()
|
||
DB := db.DBsnctf
|
||
DB.AutoMigrate(&User{})
|
||
//用ShouldBindJSON解析绑定传入的Json数据。
|
||
if err := c.ShouldBindJSON(&request); err != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Bind json error!"})
|
||
return
|
||
}
|
||
//code 0用户名不存在
|
||
//code 1用户名或密码错误
|
||
|
||
err := DB.Take(&user, "username = ?", request.Username).Error
|
||
|
||
if err != nil {
|
||
c.JSON(200, gin.H{
|
||
"Error": true,
|
||
"Msg": "登陆失败",
|
||
"Code": 0,
|
||
})
|
||
return
|
||
}
|
||
//判断md5值与数据库内容是否相同
|
||
if user.Password != tools.MD5(request.Password) {
|
||
c.JSON(200, gin.H{
|
||
"Error": true,
|
||
"Msg": "用户名或密码错误",
|
||
"Code": 1,
|
||
})
|
||
return
|
||
}
|
||
//身份认证结束
|
||
//session
|
||
// 设置session
|
||
session, _ := Store.Get(c.Request, "SNCTFSESSID")
|
||
user.Password = ""
|
||
// 根据remember值设置session有效期
|
||
if request.Remember {
|
||
session.Options.MaxAge = 7 * 24 * 60 * 60 // 7 days
|
||
} else {
|
||
session.Options.MaxAge = 24 * 60 * 60 // 1 day
|
||
}
|
||
//保存session
|
||
session.Values["user"] = user
|
||
|
||
err = session.Save(c.Request, c.Writer)
|
||
fmt.Println(err)
|
||
if err != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Save SNCTFSESSID error"})
|
||
return
|
||
}
|
||
//登录成功
|
||
c.JSON(200, gin.H{"code": 200, "username": user.Username, "role": user.Role, "msg": "Login success!"})
|
||
|
||
}
|
||
|
||
func Register(c *gin.Context) {
|
||
var request RegisterRequest
|
||
var user User
|
||
var score ScoreResponse
|
||
|
||
Link()
|
||
DB := db.DBsnctf
|
||
err := DB.AutoMigrate(&User{})
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
//用ShouldBindJSON解析绑定传入的Json数据。
|
||
if err := c.ShouldBindJSON(&request); err != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Bind json error!"})
|
||
return
|
||
}
|
||
//限制传入用户名为中文、数字、大小写字母下划线和横杠,1到10位
|
||
if !checkUsername(request.Username) {
|
||
c.JSON(200, gin.H{
|
||
"Error": true,
|
||
"Msg": "用户名不符合规范",
|
||
"Code": 2,
|
||
})
|
||
return
|
||
}
|
||
//限制密码长度6到20位
|
||
if !checkPassword(request.Password) {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Password format error!"})
|
||
return
|
||
}
|
||
//限制传入邮箱符合格式
|
||
if !checkEmail(request.Email) {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Email format error!"})
|
||
return
|
||
}
|
||
//判断用户名是否已被使用
|
||
if isNameExisted(user, request.Username) {
|
||
c.JSON(200, gin.H{"code": 1000, "msg": "Username has already been used!"})
|
||
return
|
||
}
|
||
//判断邮箱是否已被使用
|
||
if isEmailExisted(user, request.Email) {
|
||
c.JSON(200, gin.H{"code": 1001, "msg": "Email has already been used!"})
|
||
return
|
||
}
|
||
user.Token = tools.Token()
|
||
user.Username = request.Username
|
||
user.Password = tools.MD5(request.Password)
|
||
user.Email = request.Email
|
||
user.Created = tools.Timestamp()
|
||
score.Score = 0
|
||
score.Username = request.Username
|
||
//创建数据
|
||
err1 := DB.Table("user").Create(&user).Error
|
||
err2 := DB.Table("score").Create(&score).Error
|
||
if err1 != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Register error!"})
|
||
return
|
||
}
|
||
if err2 != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Register error!"})
|
||
return
|
||
}
|
||
|
||
c.JSON(200, gin.H{
|
||
"code": 200,
|
||
"success": "注册成功",
|
||
})
|
||
|
||
}
|
||
|
||
// Logout 实现用户注销登陆
|
||
func Logout(c *gin.Context) {
|
||
var user User
|
||
|
||
session, err := Store.Get(c.Request, "SNCTFSESSID")
|
||
if err != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Get CTFGOSESSID error"})
|
||
return
|
||
}
|
||
user, ok := session.Values["user"].(User)
|
||
if !ok {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "No session"})
|
||
return
|
||
}
|
||
|
||
session.Options.MaxAge = -1
|
||
err = session.Save(c.Request, c.Writer)
|
||
if err != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Save CTFGOSESSID error"})
|
||
return
|
||
}
|
||
|
||
fmt.Sprintf("[%s] logout success!", user.Username)
|
||
c.JSON(200, gin.H{"code": 200, "msg": "Logout success!"})
|
||
}
|
||
|
||
//GetInfoByUserId 获取用户信息
|
||
func GetInfoByUserId(c *gin.Context) {
|
||
var info PublicInfoResponse
|
||
Link()
|
||
DB := db.DBsnctf
|
||
//获取用户id
|
||
id := c.Params.ByName("id")
|
||
if id == "" {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Need id!"})
|
||
return
|
||
}
|
||
//检查id是否合法
|
||
|
||
if !tools.CheckID(id) {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "ID format error!"})
|
||
return
|
||
}
|
||
|
||
err := DB.Debug().Raw("SELECT username,affiliation,country,team_id,website FROM user WHERE id = ? LIMIT 1", id).Scan(&info).Error
|
||
//err := DB.Where("id = ?", id).First(user).Error
|
||
if err != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Get info error!"})
|
||
return
|
||
}
|
||
c.JSON(200, gin.H{"code": 200, "data": info})
|
||
|
||
}
|
||
|
||
// UpdateUserInfo 更新用户信息
|
||
func UpdateUserInfo(c *gin.Context) {
|
||
var user User
|
||
var request UpdateUserInfoRequest
|
||
if err := c.ShouldBindJSON(&request); err != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Request format wrong!"})
|
||
return
|
||
}
|
||
Link()
|
||
DB := db.DBsnctf
|
||
session, err := Store.Get(c.Request, "SNCTFSESSID")
|
||
if err != nil {
|
||
c.JSON(200, gin.H{"code": 400, "msg": "Get SNCTFSESSID error"})
|
||
}
|
||
user, ok := session.Values["user"].(User)
|
||
if !ok {
|
||
c.JSON(200, gin.H{"code": 400, "msg": "No session"})
|
||
}
|
||
//获取用户id
|
||
userid := user.ID
|
||
//获取传入数据
|
||
username := request.Name
|
||
affiliation := request.Affiliation
|
||
country := request.Country
|
||
website := request.Website
|
||
email := request.Email
|
||
//数据库更新数据
|
||
err = DB.Model(&user).Where("id = ?", userid).Update("username", username).Update("affiliation", affiliation).Update("country", country).Update("website", website).Update("email", email).Error
|
||
if err != nil {
|
||
c.JSON(200, gin.H{"code": 400, "msg": "Update info error!"})
|
||
return
|
||
}
|
||
//更新session
|
||
session.Values["user"] = user
|
||
c.JSON(200, gin.H{"code": 200, "msg": "Update info success!"})
|
||
|
||
}
|
||
|
||
// GetAllUserInfo 获取所有用户信息
|
||
func GetAllUserInfo(c *gin.Context) {
|
||
var info PublicAllInfoResponse
|
||
var alla []PublicAllInfoResponse
|
||
Link()
|
||
DB := db.DBsnctf
|
||
rows, err := DB.Debug().Select([]string{"id", "username", "affiliation", "country", "website"}).Table("user").Where("hidden = ?", 0).Rows()
|
||
|
||
if err != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Get info error!"})
|
||
return
|
||
}
|
||
for rows.Next() {
|
||
rows.Scan(&info.Id, &info.Username, &info.Affiliation, &info.Country, &info.Website)
|
||
alla = append(alla, info)
|
||
}
|
||
c.JSON(200, gin.H{"code": 200, "data": alla})
|
||
}
|
||
|
||
// checkUsernameRegexp accepts ASCII letters, digits and '_' (\w in Go's RE2),
// '-', and Han (Chinese) characters. It is compiled once at package init
// rather than on every call.
var checkUsernameRegexp = regexp.MustCompile(`^[-\w\p{Han}]+$`)

// checkUsername reports whether username is 1 to 10 characters (runes, not
// bytes) long and consists only of Chinese characters, ASCII letters, digits,
// '_' and '-'.
func checkUsername(username string) bool {
	if n := utf8.RuneCountInString(username); n < 1 || n > 10 {
		return false
	}
	return checkUsernameRegexp.MatchString(username)
}
|
||
|
||
// checkEmailRegexp matches a conservative e-mail shape. The local part is word
// characters, optionally joined by single '-', '+' or '.' separators. After
// the '@' comes a domain of word characters joined by '-' or '.', which must
// contain at least one '.'. \w is ASCII-only ([0-9A-Za-z_]) in Go's RE2.
// It is compiled once at package init rather than on every call.
var checkEmailRegexp = regexp.MustCompile(`^\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*$`)

// checkEmail reports whether email matches checkEmailRegexp.
func checkEmail(email string) bool {
	return checkEmailRegexp.MatchString(email)
}
|
||
|
||
// checkPassword reports whether password is between 6 and 20 characters long,
// inclusive. Length is counted in runes, not bytes, so multi-byte characters
// count once each.
func checkPassword(password string) bool {
	length := utf8.RuneCountInString(password)
	return length >= 6 && length <= 20
}
|
||
|
||
// isNameExisted 判断用户名是否已经被占用,被占用返回true,未被占用则返回false
|
||
func isNameExisted(user User, username string) bool {
|
||
Link()
|
||
DB := db.DBsnctf
|
||
err := DB.First(&user, "Username = ?", username).Error
|
||
if err != nil {
|
||
return false
|
||
}
|
||
return true
|
||
|
||
}
|
||
|
||
// isNameExisted 判断邮箱是否已经被占用,被占用返回true,未被占用则返回false
|
||
func isEmailExisted(user User, email string) bool {
|
||
Link()
|
||
DB := db.DBsnctf
|
||
|
||
err := DB.First(&user, "Email = ?", email).Error
|
||
if err != nil {
|
||
return false
|
||
}
|
||
return true
|
||
}
|
||
|
||
// Session 获取当前用户session信息
|
||
func Session(c *gin.Context) {
|
||
var user User
|
||
|
||
// 默认在此之前已经通过了中间件的session权限验证
|
||
session, _ := Store.Get(c.Request, "SNCTFSESSID")
|
||
user = session.Values["user"].(User)
|
||
|
||
c.JSON(200, gin.H{"code": 200, "data": user})
|
||
}
|
||
|
||
//下面是身份认证用 AUTH部分
|
||
// AuthRequired 用于普通用户权限控制的中间件
|
||
func AuthRequired() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
session, err := Store.Get(c.Request, "SNCTFSESSID")
|
||
if err != nil {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Get SNCTFSESSID error"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
user, ok := session.Values["user"].(User)
|
||
if !ok {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "No session"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
if user.Role != 0 && user.Role != 1 {
|
||
c.JSON(400, gin.H{"code": 400, "msg": "Permission denied"})
|
||
c.Abort()
|
||
return
|
||
}
|
||
c.Next()
|
||
}
|
||
}
|