From 138728ad341f53a3afbe470b95f07758805b5284 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?TOP=E7=B3=AF=E7=B1=B3?= <1130395124@qq.com> Date: Mon, 27 Mar 2023 20:36:07 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A6=96=E6=AC=A1=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.ini | 16 ++++++ core/gpt.go | 131 ++++++++++++++++++++++++++++++++++++++++++++++++ core/request.go | 120 ++++++++++++++++++++++++++++++++++++++++++++ go.mod | 7 +++ go.sum | 18 +++++++ main.go | 21 ++++++++ utils/base64.go | 25 +++++++++ utils/config.go | 40 +++++++++++++++ utils/helper.go | 92 ++++++++++++++++++++++++++++++++++ 9 files changed, 470 insertions(+) create mode 100644 config.ini create mode 100644 core/gpt.go create mode 100644 core/request.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 utils/base64.go create mode 100644 utils/config.go create mode 100644 utils/helper.go diff --git a/config.ini b/config.ini new file mode 100644 index 0000000..76470f5 --- /dev/null +++ b/config.ini @@ -0,0 +1,16 @@ +[init] +; addr = http://0.0.0.0:8888 +; 密钥 +; private_key = +; SSL配置 +; cert_file = +; key_file = +; 超时,默认10 +; timeout = 10 + +[proxy] +; 代理服务器 +; addr = http://127.0.0.1:8080 + +[gpt] +; api_key = diff --git a/core/gpt.go b/core/gpt.go new file mode 100644 index 0000000..83f31ba --- /dev/null +++ b/core/gpt.go @@ -0,0 +1,131 @@ +package core + +import ( + "crypto/tls" + "encoding/json" + "io/ioutil" + "net/http" + "net/url" + "strings" + "time" +) + +// messages +var messages []Message + +// Message +// 单个消息 +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// Payload +// 请求数据 +type Payload struct { + Model string `json:"model"` + Messages []Message `json:"messages"` +} + +// AnswerItem +// 回答数据 +type AnswerItem struct { + FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message Message `json:"message"` +} + +// Answer +// 回答 +type Answer struct { + Id string `json:"id"` + Object int `json:"object"` + Created interface{} `json:"created"` + Model string `json:"model"` + Usage interface{} `json:"usage"` + Choices []AnswerItem `json:"choices"` +} + +// GPT +type GPT struct { + ApiKey string + Proxy string + Timeout int +} + +// BuildPayload +// 构建请求数据 +// +// @receiver g +// @param question +// @return string +func (g *GPT) BuildPayload(question string) string { + messages = append(messages, Message{ + Role: "user", + Content: question, + }) + + payload := Payload{ + Model: "gpt-3.5-turbo", + Messages: messages, + } + + str, _ := json.Marshal(&payload) + return string(str) +} + +// HttpClientWithProxy +// 返回Http客户端对象 +// +// @receiver g +// @param proxy +// @return *http.Client +func (g *GPT) HttpClientWithProxy(proxy string) *http.Client { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + if proxy != "" { + proxyAddr, _ := url.Parse(proxy) + transport.Proxy = http.ProxyURL(proxyAddr) + } + + timeout := 10 + if g.Timeout > 0 { + timeout = g.Timeout + } + + return &http.Client{ + Timeout: time.Duration(timeout) * time.Second, + Transport: transport, + } +} + +// GetAnswer +// 发起请求 +// +// @receiver g +// @param question +// @return Answer +func (g *GPT) GetAnswer(question string) (Answer, error) { + payload := g.BuildPayload(question) + + api := "https://api.openai.com/v1/chat/completions" + req, _ := http.NewRequest("POST", api, strings.NewReader(payload)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+g.ApiKey) + + asr := Answer{} + + client := g.HttpClientWithProxy(g.Proxy) + resp, err := client.Do(req) + if err != nil { + return asr, err + } + body, _ := ioutil.ReadAll(resp.Body) + + json.Unmarshal(body, &asr) + + return asr, nil +} diff --git a/core/request.go b/core/request.go new file mode 100644 index 0000000..fbf0e18 --- /dev/null +++ b/core/request.go @@ -0,0 +1,120 @@ +package core + +import ( + "chatgpt/utils" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strconv" + "sync" +) + +// RequestData +type RequestData struct { + Words string `json:"words"` + Time int `json:"time"` + Sign string `json:"sign"` +} + +// Response +// 响应体 +type Response struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data string `json:"data"` + Sign string `json:"sign"` +} + +// ResponseData +// 响应数据 +type ResponseData map[string]interface{} + +// BuildResponse +// 构建响应体 +// +// @param privateKey +// @param code +// @param msg +// @param data +// @return string +func BuildResponse(privateKey string, code int, msg string, data ResponseData) string { + resp := Response{ + Code: code, + Msg: msg, + } + dataJson, _ := json.Marshal(data) + resp.Data = utils.Base64{ + Content: []byte(dataJson), + }.Encode() + if privateKey != "" { + resp.Sign = utils.MakeSign(resp, privateKey) + } + + respJson, _ := json.Marshal(resp) + + return string(respJson) +} + +var wg sync.WaitGroup + +// proxyAddr +// 代理配置 +var proxyAddr string = utils.GetProxyServer() + +// apiKey +// GPT Api Key +var apiKey string = utils.Conf.Get("gpt", "api_key") + +// privateKey +// 签名密钥 +var privateKey string = utils.Conf.Get("init", "private_key") + +// TimeoutValue +// 超时配置 +var TimeoutValue string = utils.Conf.Get("init", "timeout") + +// Action +// 处理请求 +// +// @param w +// @param r +func Handler(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Access-Control-Allow-Origin", "*") + w.Header().Add("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE, OPTIONS") + w.Header().Add("Content-Type", "application/json") + Timeout, _ := strconv.Atoi(TimeoutValue) + if r.Method == "POST" { + var requestData RequestData + raw, _ := ioutil.ReadAll(r.Body) + json.Unmarshal(raw, &requestData) + + if privateKey != "" && requestData.Sign != utils.MakeSign(requestData, privateKey) { + fmt.Fprintln(w, BuildResponse(privateKey, 0, "sign error", ResponseData{})) + return + } + + question := utils.Base64{ + Content: []byte(requestData.Words), + }.Decode() + if len(question) > 0 { + gpt := &GPT{ + ApiKey: apiKey, + Proxy: proxyAddr, + Timeout: Timeout, + } + wg.Add(1) + go func() { + defer wg.Done() + if result, err := gpt.GetAnswer(question); err == nil { + fmt.Fprintln(w, BuildResponse(privateKey, 1, "success", ResponseData{ + "answer": result.Choices, + })) + } + }() + wg.Wait() + return + } + } + fmt.Fprintln(w, BuildResponse(privateKey, 0, "error", ResponseData{})) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..e25ecee --- /dev/null +++ b/go.mod @@ -0,0 +1,7 @@ +module chatgpt + +go 1.19 + +require github.com/go-ini/ini v1.67.0 + +require github.com/stretchr/testify v1.8.1 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9fc91af --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= +github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..1b553eb --- /dev/null +++ b/main.go @@ -0,0 +1,21 @@ +package main + +import ( + "chatgpt/core" + "chatgpt/utils" + "net/http" +) + +// main +func main() { + utils.Conf = &utils.Config{} + http.HandleFunc("/", core.Handler) + addressArr := utils.GetListenAddress() + if addressArr[0] == "https" { + certFile := utils.Conf.Get("init", "cert_file") + keyFile := utils.Conf.Get("init", "key_file") + http.ListenAndServeTLS(addressArr[1], certFile, keyFile, nil) + } else { + http.ListenAndServe(addressArr[1], nil) + } +} diff --git a/utils/base64.go b/utils/base64.go new file mode 100644 index 0000000..e5e34b4 --- /dev/null +++ b/utils/base64.go @@ -0,0 +1,25 @@ +package utils + +import "encoding/base64" + +// Base64 +type Base64 struct { + Content []byte +} + +// Encode +// 编码 +// @receiver q +// @return string +func (q Base64) Encode() string { + return base64.StdEncoding.EncodeToString(q.Content) +} + +// Decode +// 解码 +// @receiver q +// @return string +func (q Base64) Decode() string { + res, _ := base64.StdEncoding.DecodeString(string(q.Content)) + return string(res) +} diff --git a/utils/config.go b/utils/config.go new file mode 100644 index 0000000..a7a1404 --- /dev/null +++ b/utils/config.go @@ -0,0 +1,40 @@ +package utils + +import ( + "github.com/go-ini/ini" +) + +// Conf +// 全局对象 +var Conf *Config + +// instance +var instance *ini.File + +// Config +type Config struct{} + +// init +// 初始化 +// +// @receiver c +// @return *ini.File +func (c *Config) init() *ini.File { + if instance == nil { + f, _ := ini.Load("config.ini") + instance = f + } + + return instance +} + +// Get +// 获取配置 +// +// @receiver c +// @param section +// @param key +// @return string +func (c *Config) Get(section string, key string) string { + return c.init().Section(section).Key(key).String() +} diff --git a/utils/helper.go b/utils/helper.go new file mode 100644 index 0000000..6a164ea --- /dev/null +++ b/utils/helper.go @@ -0,0 +1,92 @@ +package utils + +import ( + "crypto/md5" + "fmt" + "reflect" + "sort" + "strings" +) + +// GetAddress +// 获取监听地址 +// +// @return []string +func GetListenAddress() []string { + address := Conf.Get("init", "addr") + addressArr := strings.Split(address, "://") + if addressArr[0] == "https" { + return []string{addressArr[0], addressArr[1]} + } + var addr string + if len(addressArr) > 1 { + addr = addressArr[1] + } else { + addr = addressArr[0] + } + return []string{"http", addr} +} + +// GetProxy +// 获取代理服务器 +// +// @return string +func GetProxyServer() string { + proxy := Conf.Get("proxy", "addr") + if proxy == "" { + return "" + } + proxyArr := strings.Split(proxy, "://") + if proxyArr[0] != "http" && proxyArr[0] != "https" { + return "http://" + proxyArr[0] + } + + return proxy +} + +// Md5 +// +// @param text +// @return string +func Md5(text string) string { + hash := md5.New() + hash.Write([]byte(text)) + byteData := hash.Sum(nil) + return fmt.Sprintf("%x", byteData) +} + +// MakeSign +// 生成签名 +// +// @param obj +// @param privateKey +// @return string +func MakeSign(obj interface{}, privateKey string) string { + signmap := map[string]interface{}{} + valueOfObj := reflect.ValueOf(obj) + for i := 0; i < valueOfObj.NumField(); i++ { + signmap[valueOfObj.Type().Field(i).Name] = valueOfObj.Field(i) + } + + //进行键排序 + keys := make([]string, len(signmap)) + j := 0 + for k := range signmap { + keys[j] = k + j++ + } + + sort.Strings(keys) + //获取 value 值,拼接成一行 + str := "" + for _, k := range keys { + if strings.ToLower(k) == "sign" { + continue + } + if v, ok := signmap[k]; ok { + str = fmt.Sprintf("%s%s=%v&", str, strings.ToLower(k), v) + } + } + + return Md5(str + "key=" + privateKey) +}