首次提交

This commit is contained in:
TOP糯米 2023-03-27 20:36:07 +08:00
commit 138728ad34
9 changed files with 470 additions and 0 deletions

16
config.ini Normal file
View File

@ -0,0 +1,16 @@
[init]
; addr = http://0.0.0.0:8888
; 密钥
; private_key =
; SSL配置
; cert_file =
; key_file =
; 超时默认10
; timeout = 10
[proxy]
; 代理服务器
; addr = http://127.0.0.1:8080
[gpt]
; api_key =

131
core/gpt.go Normal file
View File

@ -0,0 +1,131 @@
package core
import (
"crypto/tls"
"encoding/json"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
)
// messages
var messages []Message
// Message
// 单个消息
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
// Payload
// 请求数据
type Payload struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
}
// AnswerItem
// 回答数据
type AnswerItem struct {
FinishReason string `json:"finish_reason"`
Index int `json:"index"`
Message Message `json:"message"`
}
// Answer
// 回答
type Answer struct {
Id string `json:"id"`
Object int `json:"object"`
Created interface{} `json:"created"`
Model string `json:"model"`
Usage interface{} `json:"usage"`
Choices []AnswerItem `json:"choices"`
}
// GPT
type GPT struct {
ApiKey string
Proxy string
Timeout int
}
// BuildPayload
// 构建请求数据
//
// @receiver g
// @param question
// @return string
func (g *GPT) BuildPayload(question string) string {
messages = append(messages, Message{
Role: "user",
Content: question,
})
payload := Payload{
Model: "gpt-3.5-turbo",
Messages: messages,
}
str, _ := json.Marshal(&payload)
return string(str)
}
// HttpClientWithProxy
// 返回Http客户端对象
//
// @receiver g
// @param proxy
// @return *http.Client
func (g *GPT) HttpClientWithProxy(proxy string) *http.Client {
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
if proxy != "" {
proxyAddr, _ := url.Parse(proxy)
transport.Proxy = http.ProxyURL(proxyAddr)
}
timeout := 10
if g.Timeout > 0 {
timeout = g.Timeout
}
return &http.Client{
Timeout: time.Duration(timeout) * time.Second,
Transport: transport,
}
}
// GetAnswer
// 发起请求
//
// @receiver g
// @param question
// @return Answer
func (g *GPT) GetAnswer(question string) (Answer, error) {
payload := g.BuildPayload(question)
api := "https://api.openai.com/v1/chat/completions"
req, _ := http.NewRequest("POST", api, strings.NewReader(payload))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+g.ApiKey)
asr := Answer{}
client := g.HttpClientWithProxy(g.Proxy)
resp, err := client.Do(req)
if err != nil {
return asr, err
}
body, _ := ioutil.ReadAll(resp.Body)
json.Unmarshal(body, &asr)
return asr, nil
}

120
core/request.go Normal file
View File

@ -0,0 +1,120 @@
package core
import (
"chatgpt/utils"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strconv"
"sync"
)
// RequestData
type RequestData struct {
Words string `json:"words"`
Time int `json:"time"`
Sign string `json:"sign"`
}
// Response
// 响应体
type Response struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data string `json:"data"`
Sign string `json:"sign"`
}
// ResponseData
// 响应数据
type ResponseData map[string]interface{}
// BuildResponse
// 构建响应体
//
// @param privateKey
// @param code
// @param msg
// @param data
// @return string
func BuildResponse(privateKey string, code int, msg string, data ResponseData) string {
resp := Response{
Code: code,
Msg: msg,
}
dataJson, _ := json.Marshal(data)
resp.Data = utils.Base64{
Content: []byte(dataJson),
}.Encode()
if privateKey != "" {
resp.Sign = utils.MakeSign(resp, privateKey)
}
respJson, _ := json.Marshal(resp)
return string(respJson)
}
var wg sync.WaitGroup
// proxyAddr
// 代理配置
var proxyAddr string = utils.GetProxyServer()
// apiKey
// GPT Api Key
var apiKey string = utils.Conf.Get("gpt", "api_key")
// privateKey
// 签名密钥
var privateKey string = utils.Conf.Get("init", "private_key")
// TimeoutValue
// 超时配置
var TimeoutValue string = utils.Conf.Get("init", "timeout")
// Action
// 处理请求
//
// @param w
// @param r
func Handler(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Methods", "POST, GET, PUT, DELETE, OPTIONS")
w.Header().Add("Content-Type", "application/json")
Timeout, _ := strconv.Atoi(TimeoutValue)
if r.Method == "POST" {
var requestData RequestData
raw, _ := ioutil.ReadAll(r.Body)
json.Unmarshal(raw, &requestData)
if privateKey != "" && requestData.Sign != utils.MakeSign(requestData, privateKey) {
fmt.Fprintln(w, BuildResponse(privateKey, 0, "sign error", ResponseData{}))
return
}
question := utils.Base64{
Content: []byte(requestData.Words),
}.Decode()
if len(question) > 0 {
gpt := &GPT{
ApiKey: apiKey,
Proxy: proxyAddr,
Timeout: Timeout,
}
wg.Add(1)
go func() {
defer wg.Done()
if result, err := gpt.GetAnswer(question); err == nil {
fmt.Fprintln(w, BuildResponse(privateKey, 1, "success", ResponseData{
"answer": result.Choices,
}))
}
}()
wg.Wait()
return
}
}
fmt.Fprintln(w, BuildResponse(privateKey, 0, "error", ResponseData{}))
}

7
go.mod Normal file
View File

@ -0,0 +1,7 @@
module chatgpt
go 1.19
require github.com/go-ini/ini v1.67.0
require github.com/stretchr/testify v1.8.1 // indirect

18
go.sum Normal file
View File

@ -0,0 +1,18 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

21
main.go Normal file
View File

@ -0,0 +1,21 @@
package main
import (
"chatgpt/core"
"chatgpt/utils"
"net/http"
)
// main
func main() {
utils.Conf = &utils.Config{}
http.HandleFunc("/", core.Handler)
addressArr := utils.GetListenAddress()
if addressArr[0] == "https" {
certFile := utils.Conf.Get("init", "cert_file")
keyFile := utils.Conf.Get("init", "key_file")
http.ListenAndServeTLS(addressArr[1], certFile, keyFile, nil)
} else {
http.ListenAndServe(addressArr[1], nil)
}
}

25
utils/base64.go Normal file
View File

@ -0,0 +1,25 @@
package utils
import "encoding/base64"
// Base64
type Base64 struct {
Content []byte
}
// Encode
// 编码
// @receiver q
// @return string
func (q Base64) Encode() string {
return base64.StdEncoding.EncodeToString(q.Content)
}
// Decode
// 解码
// @receiver q
// @return string
func (q Base64) Decode() string {
res, _ := base64.StdEncoding.DecodeString(string(q.Content))
return string(res)
}

40
utils/config.go Normal file
View File

@ -0,0 +1,40 @@
package utils
import (
"github.com/go-ini/ini"
)
// Conf
// 全局对象
var Conf *Config
// instance
var instance *ini.File
// Config
type Config struct{}
// init
// 初始化
//
// @receiver c
// @return *ini.File
func (c *Config) init() *ini.File {
if instance == nil {
f, _ := ini.Load("config.ini")
instance = f
}
return instance
}
// Get
// 获取配置
//
// @receiver c
// @param section
// @param key
// @return string
func (c *Config) Get(section string, key string) string {
return c.init().Section(section).Key(key).String()
}

92
utils/helper.go Normal file
View File

@ -0,0 +1,92 @@
package utils
import (
"crypto/md5"
"fmt"
"reflect"
"sort"
"strings"
)
// GetAddress
// 获取监听地址
//
// @return []string
func GetListenAddress() []string {
address := Conf.Get("init", "addr")
addressArr := strings.Split(address, "://")
if addressArr[0] == "https" {
return []string{addressArr[0], addressArr[1]}
}
var addr string
if len(addressArr) > 1 {
addr = addressArr[1]
} else {
addr = addressArr[0]
}
return []string{"http", addr}
}
// GetProxy
// 获取代理服务器
//
// @return string
func GetProxyServer() string {
proxy := Conf.Get("proxy", "addr")
if proxy == "" {
return ""
}
proxyArr := strings.Split(proxy, "://")
if proxyArr[0] != "http" && proxyArr[0] != "https" {
return "http://" + proxyArr[0]
}
return proxy
}
// Md5
//
// @param text
// @return string
func Md5(text string) string {
hash := md5.New()
hash.Write([]byte(text))
byteData := hash.Sum(nil)
return fmt.Sprintf("%x", byteData)
}
// MakeSign
// 生成签名
//
// @param obj
// @param privateKey
// @return string
func MakeSign(obj interface{}, privateKey string) string {
signmap := map[string]interface{}{}
valueOfObj := reflect.ValueOf(obj)
for i := 0; i < valueOfObj.NumField(); i++ {
signmap[valueOfObj.Type().Field(i).Name] = valueOfObj.Field(i)
}
//进行键排序
keys := make([]string, len(signmap))
j := 0
for k := range signmap {
keys[j] = k
j++
}
sort.Strings(keys)
//获取 value 值,拼接成一行
str := ""
for _, k := range keys {
if strings.ToLower(k) == "sign" {
continue
}
if v, ok := signmap[k]; ok {
str = fmt.Sprintf("%s%s=%v&", str, strings.ToLower(k), v)
}
}
return Md5(str + "key=" + privateKey)
}