修改代码结构

This commit is contained in:
TOP糯米 2023-03-28 13:58:18 +08:00
parent 34d4612b51
commit 8dffe0d5b4
8 changed files with 121 additions and 118 deletions

View File

@ -1,7 +1,7 @@
package core package core
import ( import (
"chatgpt/util" "chatgpt/utils"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect" "reflect"
@ -107,7 +107,7 @@ func BuildResponse(privateKey string, code int, msg string, data ResponseData) s
Msg: msg, Msg: msg,
} }
dataJson, _ := json.Marshal(data) dataJson, _ := json.Marshal(data)
resp.Data = util.Base64{ resp.Data = utils.Base64{
Content: []byte(dataJson), Content: []byte(dataJson),
}.Encode() }.Encode()
if privateKey != "" { if privateKey != "" {
@ -147,5 +147,5 @@ func MakeSign(obj interface{}, privateKey string) string {
str = fmt.Sprintf("%s%s=%v&", str, currentKey, currentValue) str = fmt.Sprintf("%s%s=%v&", str, currentKey, currentValue)
} }
return util.Md5(str + "key=" + privateKey) return utils.Md5(str + "key=" + privateKey)
} }

View File

@ -27,13 +27,8 @@ func (g *GPT) HttpClientWithProxy(proxy string) *http.Client {
transport.Proxy = http.ProxyURL(proxyAddr) transport.Proxy = http.ProxyURL(proxyAddr)
} }
timeout := 10
if g.Timeout > 0 {
timeout = g.Timeout
}
return &http.Client{ return &http.Client{
Timeout: time.Duration(timeout) * time.Second, Timeout: time.Duration(g.Timeout) * time.Second,
Transport: transport, Transport: transport,
} }
} }

View File

@ -1,7 +1,7 @@
package core package core
import ( import (
"chatgpt/util" "chatgpt/utils"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -15,19 +15,19 @@ var wg sync.WaitGroup
// proxyAddr // proxyAddr
// 代理配置 // 代理配置
var proxyAddr string = util.GetProxyServer() var proxyAddr string = utils.GetProxyServer()
// apiKey // apiKey
// GPT Api Key // GPT Api Key
var apiKey string = util.Conf.Get("gpt", "api_key") var apiKey string = utils.GetConfig("gpt", "api_key", "")
// privateKey // privateKey
// 签名密钥 // 签名密钥
var privateKey string = util.Conf.Get("init", "private_key") var privateKey string = utils.GetConfig("init", "private_key", "")
// TimeoutValue // TimeoutValue
// 超时配置 // 超时配置
var TimeoutValue string = util.Conf.Get("init", "timeout") var TimeoutValue string = utils.GetConfig("init", "timeout", "10")
// Action // Action
// 处理请求 // 处理请求
@ -45,11 +45,11 @@ func Handler(w http.ResponseWriter, r *http.Request) {
json.Unmarshal(raw, &request) json.Unmarshal(raw, &request)
if privateKey != "" && request.Sign != MakeSign(request, privateKey) { if privateKey != "" && request.Sign != MakeSign(request, privateKey) {
fmt.Fprintln(w, BuildResponse(privateKey, 0, "sign error", ResponseData{})) fmt.Fprintln(w, BuildResponse(privateKey, 0, "签名错误", ResponseData{}))
return return
} }
question := util.Base64{ question := utils.Base64{
Content: []byte(request.Words), Content: []byte(request.Words),
}.Decode() }.Decode()
if len(question) > 0 { if len(question) > 0 {
@ -58,18 +58,21 @@ func Handler(w http.ResponseWriter, r *http.Request) {
Proxy: proxyAddr, Proxy: proxyAddr,
Timeout: Timeout, Timeout: Timeout,
} }
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
if result, err := gpt.GetAnswer(question); err == nil { if result, err := gpt.GetAnswer(question); err == nil {
fmt.Fprintln(w, BuildResponse(privateKey, 1, "success", ResponseData{ fmt.Fprintln(w, BuildResponse(privateKey, 1, "请求成功", ResponseData{
"answer": result.Choices, "answer": result.Choices,
})) }))
} else {
fmt.Fprintln(w, BuildResponse(privateKey, 0, "网络错误", ResponseData{}))
} }
}() }()
wg.Wait() wg.Wait()
return return
} }
} }
fmt.Fprintln(w, BuildResponse(privateKey, 0, "error", ResponseData{})) fmt.Fprintln(w, BuildResponse(privateKey, 0, "未知错误", ResponseData{}))
} }

View File

@ -2,18 +2,17 @@ package main
import ( import (
"chatgpt/core" "chatgpt/core"
"chatgpt/util" "chatgpt/utils"
"net/http" "net/http"
) )
// main // main
func main() { func main() {
util.Conf = &util.Config{}
http.HandleFunc("/", core.Handler) http.HandleFunc("/", core.Handler)
addressArr := util.GetListenAddress() addressArr := utils.GetListenAddress()
if addressArr[0] == "https" { if addressArr[0] == "https" {
certFile := util.Conf.Get("init", "cert_file") certFile := utils.GetConfig("init", "cert_file", "")
keyFile := util.Conf.Get("init", "key_file") keyFile := utils.GetConfig("init", "key_file", "")
http.ListenAndServeTLS(addressArr[1], certFile, keyFile, nil) http.ListenAndServeTLS(addressArr[1], certFile, keyFile, nil)
} else { } else {
http.ListenAndServe(addressArr[1], nil) http.ListenAndServe(addressArr[1], nil)

View File

@ -1,40 +0,0 @@
package util
import (
"github.com/go-ini/ini"
)
// Conf
// 全局对象
var Conf *Config
// instance
var instance *ini.File
// Config
type Config struct{}
// init
// 初始化
//
// @receiver c
// @return *ini.File
func (c *Config) init() *ini.File {
if instance == nil {
f, _ := ini.Load("config.ini")
instance = f
}
return instance
}
// Get
// 获取配置
//
// @receiver c
// @param section
// @param key
// @return string
func (c *Config) Get(section string, key string) string {
return c.init().Section(section).Key(key).String()
}

View File

@ -1,54 +0,0 @@
package util
import (
"crypto/md5"
"fmt"
"strings"
)
// GetAddress
// 获取监听地址
//
// @return []string
func GetListenAddress() []string {
address := Conf.Get("init", "addr")
addressArr := strings.Split(address, "://")
if addressArr[0] == "https" {
return []string{addressArr[0], addressArr[1]}
}
var addr string
if len(addressArr) > 1 {
addr = addressArr[1]
} else {
addr = addressArr[0]
}
return []string{"http", addr}
}
// GetProxy
// 获取代理服务器
//
// @return string
func GetProxyServer() string {
proxy := Conf.Get("proxy", "addr")
if proxy == "" {
return ""
}
proxyArr := strings.Split(proxy, "://")
if proxyArr[0] != "http" && proxyArr[0] != "https" {
return "http://" + proxyArr[0]
}
return proxy
}
// Md5
//
// @param text
// @return string
func Md5(text string) string {
hash := md5.New()
hash.Write([]byte(text))
byteData := hash.Sum(nil)
return fmt.Sprintf("%x", byteData)
}

View File

@ -1,4 +1,4 @@
package util package utils
import "encoding/base64" import "encoding/base64"

100
utils/helper.go Normal file
View File

@ -0,0 +1,100 @@
package utils
import (
"crypto/md5"
"fmt"
"net"
"net/http"
"strings"
"github.com/go-ini/ini"
)
// instance
var instance *ini.File
// GetConfig
// 获取配置
//
// @param section
// @param key
// @return string
func GetConfig(section string, key string, defaultValue string) string {
if instance == nil {
f, _ := ini.Load("config.ini")
instance = f
}
v := instance.Section(section).Key(key).String()
if v == "" && defaultValue != "" {
return defaultValue
}
return v
}
// GetAddress
// 获取监听地址
//
// @return []string
func GetListenAddress() []string {
address := GetConfig("init", "addr", "")
addressArr := strings.Split(address, "://")
if addressArr[0] == "https" {
return []string{addressArr[0], addressArr[1]}
}
var addr string
if len(addressArr) > 1 {
addr = addressArr[1]
} else {
addr = addressArr[0]
}
return []string{"http", addr}
}
// GetProxy
// 获取代理服务器
//
// @return string
func GetProxyServer() string {
proxy := GetConfig("proxy", "addr", "")
if proxy == "" {
return ""
}
proxyArr := strings.Split(proxy, "://")
if proxyArr[0] != "http" && proxyArr[0] != "https" {
return "http://" + proxyArr[0]
}
return proxy
}
// Md5
//
// @param text
// @return string
func Md5(text string) string {
hash := md5.New()
hash.Write([]byte(text))
byteData := hash.Sum(nil)
return fmt.Sprintf("%x", byteData)
}
// GetClientIP
// 获取客户端IP
//
// @param r
// @return string
func GetClientIP(r *http.Request) string {
xForwardedFor := r.Header.Get("X-Forwarded-For")
ip := strings.TrimSpace(strings.Split(xForwardedFor, ",")[0])
if ip != "" {
return ip
}
ip = strings.TrimSpace(r.Header.Get("X-Real-Ip"))
if ip != "" {
return ip
}
if ip, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)); err == nil {
return ip
}
return ""
}