Golang Gin 接口签名（sign）校验

gin.Default().Group("").Use(xxx.CheckSign())
{
//注册需要签名的路由
gin.Default().Group("testrouter").POST("dotest", func(ctx *gin.Context) {
response.Result(401, gin.H{}, "hello", ctx)
})
}

sign.go：生成签名与校验签名（验签）

package xxx

import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"fmt"

	"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
	"github.com/gin-gonic/gin"
)

// Messages sent back to the client through response.Result when signature
// verification fails.
var (
	// SignWrong is reported when the submitted sign does not match the computed one.
	SignWrong = "sign wrong"
	// SignNull is reported when the request carries no "sign" parameter.
	SignNull = "sign null"
	// ApiKeyNull is reported when the request carries no "api_key" parameter.
	// It must differ from SignNull so clients can tell which parameter is missing.
	ApiKeyNull = "api_key null"
	// UserWrong is reported for an invalid username (not referenced in the code shown here).
	UserWrong = "username wrong"
)

// CheckSign returns Gin middleware that rejects requests whose "sign"
// parameter does not match the signature computed on the server.
//
// "api_key" and "sign" are read with Request.FormValue (query string or form
// body). The expected signature is SignEncode(params + apiSecret), where params
// is produced by RequestHttp.RequestParams("sign"): the request parameters
// without "sign", sorted by key and joined as "k1=v1&k2=v2".
//
// On any failure it calls response.Result with code 401 and a message
// (ApiKeyNull, SignNull or SignWrong), then aborts the handler chain.
func CheckSign() gin.HandlerFunc {
	return func(c *gin.Context) {
		apiKey := c.Request.FormValue("api_key")
		if apiKey == "" {
			response.Result(401, gin.H{}, ApiKeyNull, c)
			c.Abort()
			return
		}
		// TODO: look up the secret that belongs to apiKey. With an empty secret the
		// signature is only a hash of the public parameters, so anyone can forge it.
		apiSecret := ""

		signReq := c.Request.FormValue("sign")
		if signReq == "" {
			response.Result(401, gin.H{}, SignNull, c)
			c.Abort()
			return
		}

		reqData := &RequestHttp{
			Ctx:    c,
			Params: map[string]interface{}{},
		}
		// The signed string contains apiSecret, so it is never logged.
		signReal := SignEncode(reqData.RequestParams("sign") + apiSecret)

		// Constant-time comparison: response timing must not reveal how many
		// leading bytes of a guessed signature were correct.
		if subtle.ConstantTimeCompare([]byte(signReq), []byte(signReal)) != 1 {
			fmt.Printf("check sign failed, api_key=%q\n", apiKey)
			response.Result(401, gin.H{}, SignWrong, c)
			c.Abort()
			return
		}
		c.Next()
	}
}

// SignEncode produces the signature for message: the SHA-256 digest of its
// bytes, written as 64 lowercase hexadecimal characters.
func SignEncode(message string) string {
	digest := sha256.Sum256([]byte(message))
	return hex.EncodeToString(digest[:])
}

// GetSHA256HashCode hashes message with SHA-256 and returns the 32-byte
// digest encoded as a lowercase hexadecimal string (64 characters).
func GetSHA256HashCode(message string) string {
	sum := sha256.Sum256([]byte(message))
	return hex.EncodeToString(sum[:])
}

get_request.go：获取请求参数，按 key 的 ASCII 升序排序后拼接成 a=1&b=2 形式的字符串

package xxx

import (
"fmt"
"sort"
"strings"
"sync"

"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
)

// RequestHttp gathers the parameters of one Gin request so they can be
// serialized into the string that CheckSign signs.
// It contains a sync.Mutex, so a RequestHttp must not be copied after first use.
type RequestHttp struct {
// Ctx is the Gin request whose parameters are collected.
Ctx *gin.Context
// Params maps parameter name to value; CheckSign initializes it to an empty map.
Params map[string]interface{}
// lock is held by Add and Delete while they modify Params.
lock sync.Mutex
}

//@function: JoinParamsStr
//@description: 请求参数转换字符串
//@param: message string
//@return: a=1&b=2 string
func (r *RequestHttp) JoinParamsStr() string {
//先按key 排序 升序 ASCII 升序
keys := make([]string, 0, len(r.Params))
for k := range r.Params {
keys = append(keys, k)
}
sort.Strings(keys)
var params []string
if len(r.Params) > 0 {
for _, k := range keys {
params = append(params, fmt.Sprintf("%s=%v", k, r.Params[k]))
}
}
return strings.Join(params, "&")
}

//@function: RequestParams
//@description: 获取参数集合
//@param: exclude string 排除key
//@return: hashCode string
func (r *RequestHttp) RequestParams(exclude string) string {
ctx := r.Ctx
bindParams := map[string]interface{}{}
if ctx.Request.Method == "POST" {
contextType := ctx.Request.Header.Get("Content-Type")
if contextType == "application/json" {
err := ctx.ShouldBindBodyWith(&bindParams, binding.JSON)
if err != nil { //报错
fmt.Printf("nyx_request_mid_error %v,err: %v \n", bindParams, err)
return ""
}
if len(bindParams) > 0 {
for k, v := range bindParams {
r.Add(k, v)
}
}
} else {
_ = ctx.Request.ParseMultipartForm(32 << 20)
if len(ctx.Request.PostForm) > 0 {
for k, v := range ctx.Request.PostForm {
r.Add(k, v[0])
}
}
}
} else {
var tmpParams = make(map[string]string)
err2 := ctx.ShouldBind(&tmpParams)
if err2 != nil {
fmt.Printf("nyx_request_mid_error %v,err: %v \n", bindParams, err2)
return ""
}
for k, v := range tmpParams {
r.Add(k, v)
}
}
r.Delete(exclude)
return r.JoinParamsStr()
}

// Add stores value under key in r.Params, replacing any previous value.
// Params is allocated on first use, so Add is safe even when it was never
// initialized (writing to a nil map would otherwise panic).
func (r *RequestHttp) Add(key string, value interface{}) {
	r.lock.Lock()
	defer r.lock.Unlock()
	if r.Params == nil {
		r.Params = make(map[string]interface{})
	}
	r.Params[key] = value
}

// Delete removes key from r.Params while holding the lock. Removing a key
// that is absent (or deleting from a nil map) does nothing.
func (r *RequestHttp) Delete(key string) {
	r.lock.Lock()
	defer r.lock.Unlock()
	// The built-in delete is already a no-op for missing keys.
	delete(r.Params, key)
}