Files
SNCTF/api/user.go
jiayuqi7813 0cd63a7111 修复bug
2022-07-16 02:16:31 +08:00

348 lines
8.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package api
import (
"fmt"
"github.com/gin-gonic/gin"
_ "gorm.io/driver/sqlite"
db "main.go/database"
"main.go/tools"
. "main.go/type"
"regexp"
"unicode/utf8"
)
// Link initializes the database connection by calling db.Inimysql.
// It panics if initialization reports an error.
func Link() {
	if err := db.Inimysql(); err != nil {
		panic(err)
	}
}
// Login authenticates a user from a JSON request body and, on success, stores
// the user (with the password hash blanked) in the "SNCTFSESSID" session.
//
// Responses:
//   - 400 {"code": 400} when the body cannot be bound, the schema migration
//     fails, or the session cannot be saved.
//   - 200 {"Error": true, "Code": 0} when no user row matches the username.
//   - 200 {"Error": true, "Code": 1} when the password hash does not match.
//   - 200 {"code": 200, "username", "role"} on success.
func Login(c *gin.Context) {
	var request LoginRequest
	var user User

	Link()
	DB := db.DBsnctf
	if err := DB.AutoMigrate(&User{}); err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Database error!"})
		return
	}

	// Bind the incoming JSON payload.
	if err := c.ShouldBindJSON(&request); err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Bind json error!"})
		return
	}

	// Code 0: the lookup by username failed.
	if err := DB.Take(&user, "username = ?", request.Username).Error; err != nil {
		c.JSON(200, gin.H{
			"Error": true,
			"Msg":   "登陆失败",
			"Code":  0,
		})
		return
	}

	// Code 1: the stored hash differs from the hash of the supplied password.
	if user.Password != tools.MD5(request.Password) {
		c.JSON(200, gin.H{
			"Error": true,
			"Msg":   "用户名或密码错误",
			"Code":  1,
		})
		return
	}

	// Authentication succeeded; set up the session.
	// NOTE(review): the Store.Get error is ignored, as in the original code;
	// presumably so that a stale cookie does not block a fresh login. Confirm
	// that Store.Get always returns a usable session alongside the error.
	session, _ := Store.Get(c.Request, "SNCTFSESSID")

	// Never keep the password hash in the session.
	user.Password = ""

	// The session lifetime depends on the "remember" flag.
	if request.Remember {
		session.Options.MaxAge = 7 * 24 * 60 * 60 // 7 days
	} else {
		session.Options.MaxAge = 24 * 60 * 60 // 1 day
	}

	session.Values["user"] = user
	if err := session.Save(c.Request, c.Writer); err != nil {
		// Only report real failures instead of printing "<nil>" on every login.
		fmt.Println(err)
		c.JSON(400, gin.H{"code": 400, "msg": "Save SNCTFSESSID error"})
		return
	}

	c.JSON(200, gin.H{"code": 200, "username": user.Username, "role": user.Role, "msg": "Login success!"})
}
// Register creates a new user account and its initial score row from a JSON
// request body.
//
// The request is validated in this order: username format, password length,
// email format, username uniqueness, email uniqueness. The user row and the
// score row are inserted inside a single transaction so that a failure of
// either insert leaves neither row behind.
//
// Responses:
//   - 400 {"code": 400} on bind errors, invalid password/email format, or any
//     database failure.
//   - 200 {"Error": true, "Code": 2} when the username format is invalid.
//   - 200 {"code": 1000} / {"code": 1001} when the username / email is taken.
//   - 200 {"code": 200} on success.
func Register(c *gin.Context) {
	var request RegisterRequest
	var user User
	var score ScoreResponse

	Link()
	DB := db.DBsnctf
	if err := DB.AutoMigrate(&User{}); err != nil {
		// Always answer the client; previously this path returned silently.
		c.JSON(400, gin.H{"code": 400, "msg": "Register error!"})
		return
	}

	// Bind the incoming JSON payload.
	if err := c.ShouldBindJSON(&request); err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Bind json error!"})
		return
	}

	// Username: Han characters, digits, ASCII letters, '_' and '-', 1 to 10 runes.
	if !checkUsername(request.Username) {
		c.JSON(200, gin.H{
			"Error": true,
			"Msg":   "用户名不符合规范",
			"Code":  2,
		})
		return
	}
	// Password: 6 to 20 runes.
	if !checkPassword(request.Password) {
		c.JSON(400, gin.H{"code": 400, "msg": "Password format error!"})
		return
	}
	// Email must match the expected address format.
	if !checkEmail(request.Email) {
		c.JSON(400, gin.H{"code": 400, "msg": "Email format error!"})
		return
	}
	// Reject usernames that are already registered.
	if isNameExisted(user, request.Username) {
		c.JSON(200, gin.H{"code": 1000, "msg": "Username has already been used!"})
		return
	}
	// Reject email addresses that are already registered.
	if isEmailExisted(user, request.Email) {
		c.JSON(200, gin.H{"code": 1001, "msg": "Email has already been used!"})
		return
	}

	user.Token = tools.Token()
	user.Username = request.Username
	user.Password = tools.MD5(request.Password)
	user.Email = request.Email
	user.Created = tools.Timestamp()
	score.Score = 0
	score.Username = request.Username

	// Insert both rows atomically: the score row must not be created when the
	// user insert fails, and a user must not be left without a score row.
	tx := DB.Begin()
	if err := tx.Error; err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Register error!"})
		return
	}
	if err := tx.Table("user").Create(&user).Error; err != nil {
		tx.Rollback()
		c.JSON(400, gin.H{"code": 400, "msg": "Register error!"})
		return
	}
	if err := tx.Table("score").Create(&score).Error; err != nil {
		tx.Rollback()
		c.JSON(400, gin.H{"code": 400, "msg": "Register error!"})
		return
	}
	if err := tx.Commit().Error; err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Register error!"})
		return
	}

	c.JSON(200, gin.H{
		"code":    200,
		"success": "注册成功",
	})
}
// Logout ends the current user's session.
//
// It loads the "SNCTFSESSID" session, requires it to hold a User value, sets
// the session's MaxAge to -1 and saves it, then logs the logout to stdout.
//
// Responses:
//   - 400 {"code": 400} when the session cannot be loaded, holds no User, or
//     cannot be saved.
//   - 200 {"code": 200} on success.
func Logout(c *gin.Context) {
	session, err := Store.Get(c.Request, "SNCTFSESSID")
	if err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Get SNCTFSESSID error"})
		return
	}

	user, ok := session.Values["user"].(User)
	if !ok {
		c.JSON(400, gin.H{"code": 400, "msg": "No session"})
		return
	}

	// A negative MaxAge is presumably how the store expires the cookie
	// (gorilla/sessions semantics) — confirm against the Store implementation.
	session.Options.MaxAge = -1
	if err := session.Save(c.Request, c.Writer); err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Save SNCTFSESSID error"})
		return
	}

	// Actually emit the log line; the previous fmt.Sprintf result was discarded.
	fmt.Printf("[%s] logout success!\n", user.Username)
	c.JSON(200, gin.H{"code": 200, "msg": "Logout success!"})
}
// GetInfoByUserId returns the public profile columns (username, affiliation,
// country, team_id, website) of the user whose id is given by the "id" route
// parameter.
//
// It answers 400 when the id is missing, fails tools.CheckID, or the query
// reports an error; otherwise 200 with the scanned row under "data".
func GetInfoByUserId(c *gin.Context) {
	Link()
	conn := db.DBsnctf

	// Validate the route parameter before touching the database.
	id := c.Params.ByName("id")
	switch {
	case id == "":
		c.JSON(400, gin.H{"code": 400, "msg": "Need id!"})
		return
	case !tools.CheckID(id):
		c.JSON(400, gin.H{"code": 400, "msg": "ID format error!"})
		return
	}

	var info PublicInfoResponse
	const query = "SELECT username,affiliation,country,team_id,website FROM user WHERE id = ? LIMIT 1"
	if err := conn.Debug().Raw(query, id).Scan(&info).Error; err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Get info error!"})
		return
	}

	c.JSON(200, gin.H{"code": 200, "data": info})
}
// UpdateUserInfo updates the profile of the user stored in the current
// "SNCTFSESSID" session with the fields of the JSON request body (username,
// affiliation, country, website, email), then saves the refreshed user back
// into the session.
//
// Responses:
//   - 400 {"code": 400} when the request body cannot be bound.
//   - 200 {"code": 400} when the session cannot be loaded, holds no User, the
//     database update fails, or the session cannot be saved.
//   - 200 {"code": 200} on success.
func UpdateUserInfo(c *gin.Context) {
	var request UpdateUserInfoRequest
	if err := c.ShouldBindJSON(&request); err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Request format wrong!"})
		return
	}

	Link()
	DB := db.DBsnctf

	session, err := Store.Get(c.Request, "SNCTFSESSID")
	if err != nil {
		c.JSON(200, gin.H{"code": 400, "msg": "Get SNCTFSESSID error"})
		// Stop here: continuing would write a second response and run the
		// update against a zero-value user.
		return
	}
	user, ok := session.Values["user"].(User)
	if !ok {
		c.JSON(200, gin.H{"code": 400, "msg": "No session"})
		return
	}

	// Apply all columns in a single UPDATE statement. A map is used so that
	// empty strings are written too, as the original per-column updates did.
	changes := map[string]interface{}{
		"username":    request.Name,
		"affiliation": request.Affiliation,
		"country":     request.Country,
		"website":     request.Website,
		"email":       request.Email,
	}
	if err := DB.Model(&user).Where("id = ?", user.ID).Updates(changes).Error; err != nil {
		c.JSON(200, gin.H{"code": 400, "msg": "Update info error!"})
		return
	}

	// Persist the session; assigning to session.Values alone does not save it.
	// NOTE(review): this relies on gorm writing the updated values back into
	// user via Model(&user) — confirm.
	session.Values["user"] = user
	if err := session.Save(c.Request, c.Writer); err != nil {
		c.JSON(200, gin.H{"code": 400, "msg": "Save SNCTFSESSID error"})
		return
	}

	c.JSON(200, gin.H{"code": 200, "msg": "Update info success!"})
}
// GetAllUserInfo returns id, username, affiliation, country and website for
// every row of the user table whose hidden column equals 0.
//
// It answers 400 when the query, a row scan, or the row iteration fails, and
// 200 with the list under "data" otherwise. The list is never nil, so an
// empty result is serialized as [] rather than null.
func GetAllUserInfo(c *gin.Context) {
	Link()
	DB := db.DBsnctf

	rows, err := DB.Debug().Select([]string{"id", "username", "affiliation", "country", "website"}).Table("user").Where("hidden = ?", 0).Rows()
	if err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Get info error!"})
		return
	}
	// Release the underlying connection once the handler returns.
	defer rows.Close()

	all := make([]PublicAllInfoResponse, 0)
	for rows.Next() {
		// A fresh value per row keeps a failed scan from leaking stale fields.
		var info PublicAllInfoResponse
		if err := rows.Scan(&info.Id, &info.Username, &info.Affiliation, &info.Country, &info.Website); err != nil {
			c.JSON(400, gin.H{"code": 400, "msg": "Get info error!"})
			return
		}
		all = append(all, info)
	}
	// rows.Next returns false on both exhaustion and error; tell them apart.
	if err := rows.Err(); err != nil {
		c.JSON(400, gin.H{"code": 400, "msg": "Get info error!"})
		return
	}

	c.JSON(200, gin.H{"code": 200, "data": all})
}
// usernameRegexp matches strings made only of '-', ASCII word characters
// (letters, digits, '_') and Han characters. It is compiled once at package
// initialization instead of on every call.
var usernameRegexp = regexp.MustCompile(`^[-\w\p{Han}]+$`)

// checkUsername reports whether username is 1 to 10 runes long and consists
// only of Han characters, ASCII letters, digits, underscores and hyphens.
func checkUsername(username string) bool {
	// Length is measured in runes so a Han character counts as one.
	n := utf8.RuneCountInString(username)
	if n < 1 || n > 10 {
		return false
	}
	return usernameRegexp.MatchString(username)
}
// emailRegexp matches addresses of the form local@domain.tld, where each part
// is made of ASCII word characters separated by single '-', '.' (and, in the
// local part, '+') characters. It is compiled once at package initialization
// instead of on every call.
var emailRegexp = regexp.MustCompile(`^\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*$`)

// checkEmail reports whether email matches the expected address format.
func checkEmail(email string) bool {
	return emailRegexp.MatchString(email)
}
// checkPassword reports whether password is between 6 and 20 runes long,
// inclusive. Length is counted in runes, not bytes.
func checkPassword(password string) bool {
	runes := utf8.RuneCountInString(password)
	return runes >= 6 && runes <= 20
}
// isNameExisted reports whether a user row with the given username can be
// loaded. It returns true when the lookup succeeds and false when the lookup
// reports any error (including lookup failures other than "not found").
func isNameExisted(user User, username string) bool {
	Link()
	lookupErr := db.DBsnctf.First(&user, "Username = ?", username).Error
	return lookupErr == nil
}
// isEmailExisted reports whether a user row with the given email can be
// loaded. It returns true when the lookup succeeds and false when the lookup
// reports any error (including lookup failures other than "not found").
func isEmailExisted(user User, email string) bool {
	Link()
	lookupErr := db.DBsnctf.First(&user, "Email = ?", email).Error
	return lookupErr == nil
}
// Session returns the User stored in the current "SNCTFSESSID" session.
//
// The handler is expected to run behind the session-checking middleware, but
// it still verifies that the session holds a User so that a missing or
// malformed session yields a 400 response instead of a panic.
func Session(c *gin.Context) {
	// NOTE(review): the Store.Get error is ignored, as in the original code;
	// the type check below covers the case of an empty session.
	session, _ := Store.Get(c.Request, "SNCTFSESSID")

	user, ok := session.Values["user"].(User)
	if !ok {
		c.JSON(400, gin.H{"code": 400, "msg": "No session"})
		return
	}

	c.JSON(200, gin.H{"code": 200, "data": user})
}
// The following section is the authentication (AUTH) part.

// AuthRequired returns a middleware that lets a request through only when the
// "SNCTFSESSID" session can be loaded, holds a User, and that user's Role is
// 0 or 1. Otherwise it answers 400 with an explanatory message and aborts the
// handler chain.
func AuthRequired() gin.HandlerFunc {
	// reject writes the error response and stops further handlers.
	reject := func(c *gin.Context, msg string) {
		c.JSON(400, gin.H{"code": 400, "msg": msg})
		c.Abort()
	}

	return func(c *gin.Context) {
		session, err := Store.Get(c.Request, "SNCTFSESSID")
		if err != nil {
			reject(c, "Get SNCTFSESSID error")
			return
		}

		user, ok := session.Values["user"].(User)
		if !ok {
			reject(c, "No session")
			return
		}

		switch user.Role {
		case 0, 1:
			c.Next()
		default:
			reject(c, "Permission denied")
		}
	}
}