WIP: Update struct decode and encode #2

Draft
Sirherobrine23 wants to merge 13 commits from struct-serealize into main
18 changed files with 487 additions and 553 deletions
Showing only changes of commit 94b98563ee - Show all commits

15
.github/workflows/test.yaml vendored Normal file
View File

@ -0,0 +1,15 @@
name: Test
on:
pull_request:
jobs:
Test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
check-latest: true
- name: Go test
run: go test ./...

View File

@ -84,12 +84,11 @@ type toWr struct {
func (t toWr) Write(w []byte) (int, error) { func (t toWr) Write(w []byte) (int, error) {
err := structcode.NewEncode(t.tun.Conn, proto.Request{ err := structcode.NewEncode(t.tun.Conn, proto.Request{
DataTX: &proto.ClientData{ DataTX: &proto.ClientData{
Data: w,
Client: proto.Client{ Client: proto.Client{
Client: t.To, Client: t.To,
Proto: t.Proto, Proto: t.Proto,
}, },
Size: uint64(len(w)),
Data: w[:],
}, },
}) })
if err == nil { if err == nil {

View File

@ -5,19 +5,55 @@ import (
"encoding/binary" "encoding/binary"
"io" "io"
"reflect" "reflect"
"time"
) )
func readBuff(r io.Reader) ([]byte, error) {
size := uint32(0)
if err := binary.Read(r, binary.BigEndian, &size); err != nil {
return nil, err
}
buff := make([]byte, size)
_, err := r.Read(buff)
return buff, err
}
func decodeTypeof(r io.Reader, reflectValue reflect.Value) (bool, error) {
var data any
npoint := reflect.New(reflectValue.Type())
typeof := npoint.Type()
switch {
default:
return false, nil
case typeof.Implements(typeofBinUnmarshal), typeof.ConvertibleTo(typeofBinUnmarshal):
buff, err := readBuff(r)
if err != nil {
return true, err
}
t := npoint.Interface()
if err := t.(encoding.BinaryUnmarshaler).UnmarshalBinary(buff); err != nil {
return true, err
}
data = t
case typeof.Implements(typeofTextUnmarshal), typeof.ConvertibleTo(typeofTextUnmarshal):
buff, err := readBuff(r)
if err != nil {
return true, err
}
t := npoint.Interface()
if err := t.(encoding.TextUnmarshaler).UnmarshalText(buff); err != nil {
return true, err
}
data = t
}
reflectValue.Set(reflect.ValueOf(data).Elem())
return true, nil
}
func decodeRecursive(r io.Reader, reflectValue reflect.Value) error { func decodeRecursive(r io.Reader, reflectValue reflect.Value) error {
switch reflectValue.Type().Kind() { switch reflectValue.Type().Kind() {
case reflect.Interface:
case reflect.String: case reflect.String:
size := int64(0) buff, err := readBuff(r)
if err := binary.Read(r, binary.BigEndian, &size); err != nil { if err != nil {
return err
}
buff := make([]byte, size)
if _, err := r.Read(buff); err != nil {
return err return err
} }
reflectValue.SetString(string(buff)) reflectValue.SetString(string(buff))
@ -27,55 +63,28 @@ func decodeRecursive(r io.Reader, reflectValue reflect.Value) error {
return err return err
} }
reflectValue.Set(reflect.ValueOf(data).Elem()) reflectValue.Set(reflect.ValueOf(data).Elem())
case reflect.Interface:
case reflect.Struct: case reflect.Struct:
if reflectValue.Type().ConvertibleTo(typeofTimer) || reflectValue.Type().Implements(typeofBinUnmarshal) || reflectValue.Type().Implements(typeofTextUnmarshal) { if ok, err := decodeTypeof(r, reflectValue); ok {
size := int64(0) return err
if err := binary.Read(r, binary.BigEndian, &size); err != nil {
return err
}
buff := make([]byte, size)
if _, err := r.Read(buff); err != nil {
return err
} else if reflectValue.Type().ConvertibleTo(typeofTimer) {
ttime := reflectValue.Interface().(time.Time)
if err := ttime.UnmarshalBinary(buff); err != nil {
return err
}
reflectValue.Set(reflect.ValueOf(ttime))
return nil
} else if reflectValue.Type().Implements(typeofBinUnmarshal) {
data := reflectValue.Interface().(encoding.BinaryUnmarshaler)
if err := data.UnmarshalBinary(buff); err != nil {
return err
}
reflectValue.Set(reflect.ValueOf(data))
return nil
}
data := reflectValue.Interface().(encoding.TextUnmarshaler)
if err := data.UnmarshalText(buff); err != nil {
return err
}
reflectValue.Set(reflect.ValueOf(data))
return nil
} }
typeof := reflectValue.Type() typeof := reflectValue.Type()
for fieldIndex := range typeof.NumField() { for fieldIndex := range typeof.NumField() {
if typeof.Field(fieldIndex).Tag.Get(selectorTagName) == "-" || !typeof.Field(fieldIndex).IsExported() { fieldType := typeof.Field(fieldIndex)
if fieldType.Tag.Get(selectorTagName) == "-" || !fieldType.IsExported() {
continue continue
} else if err := decodeRecursive(r, reflectValue.Field(fieldIndex)); err != nil { } else if err := decodeRecursive(r, reflectValue.Field(fieldIndex)); err != nil {
return err return err
} }
} }
case reflect.Pointer: case reflect.Pointer:
read := int8(0) var read bool
if err := binary.Read(r, binary.BigEndian, &read); err != nil { if err := binary.Read(r, binary.BigEndian, &read); err != nil {
return err return err
} else if read == 0 { } else if read {
return nil reflectValue.Set(reflect.New(reflectValue.Type().Elem()))
return decodeRecursive(r, reflectValue.Elem())
} }
reflectValue.Set(reflect.New(reflectValue.Type().Elem()))
return decodeRecursive(r, reflectValue.Elem())
case reflect.Array: case reflect.Array:
for arrIndex := range reflectValue.Len() { for arrIndex := range reflectValue.Len() {
if err := decodeRecursive(r, reflectValue.Index(arrIndex)); err != nil { if err := decodeRecursive(r, reflectValue.Index(arrIndex)); err != nil {
@ -83,24 +92,25 @@ func decodeRecursive(r io.Reader, reflectValue reflect.Value) error {
} }
} }
case reflect.Slice: case reflect.Slice:
size := int64(0) if reflectValue.Type().ConvertibleTo(typeofBytes) {
if err := binary.Read(r, binary.BigEndian, &size); err != nil { buff, err := readBuff(r)
return err if err != nil {
} else if reflectValue.Type().Elem().Kind() == typeofByte.Kind() {
buff := make([]byte, size)
if _, err = r.Read(buff); err != nil {
return err return err
} }
reflectValue.SetBytes(buff) reflectValue.SetBytes(buff)
} else { return nil
typeof := reflectValue.Type().Elem() }
for range size { size := int64(0)
newData := reflect.New(typeof) if err := binary.Read(r, binary.BigEndian, &size); err != nil {
if err := decodeRecursive(r, newData); err != nil { return err
return err }
} typeof := reflectValue.Type().Elem()
reflectValue.Set(reflect.AppendSlice(reflectValue, newData.Elem())) for range size {
newData := reflect.New(typeof)
if err := decodeRecursive(r, newData); err != nil {
return err
} }
reflectValue.Set(reflect.AppendSlice(reflectValue, newData.Elem()))
} }
} }
return nil return nil

View File

@ -7,37 +7,46 @@ import (
"reflect" "reflect"
) )
func writeBuff(w io.Writer, buff []byte) error {
if err := binary.Write(w, binary.BigEndian, uint32(len(buff))); err != nil {
return err
}
_, err := w.Write(buff)
return err
}
func encodeTypeof(w io.Writer, reflectValue reflect.Value) (bool, error) {
var err error = nil
var data []byte
switch {
default:
return false, nil
case reflectValue.Type().Implements(typeofBinMarshal), reflectValue.Type().ConvertibleTo(typeofBinMarshal):
data, err = reflectValue.Interface().(encoding.BinaryMarshaler).MarshalBinary()
case reflectValue.Type().Implements(typeofTextMarshal), reflectValue.Type().ConvertibleTo(typeofTextMarshal):
data, err = reflectValue.Interface().(encoding.TextMarshaler).MarshalText()
}
if err == nil {
err = writeBuff(w, data)
}
return true, err
}
func encodeRecursive(w io.Writer, reflectValue reflect.Value) error { func encodeRecursive(w io.Writer, reflectValue reflect.Value) error {
switch reflectValue.Type().Kind() { switch reflectValue.Type().Kind() {
case reflect.Interface:
case reflect.String: case reflect.String:
str := reflectValue.String() return writeBuff(w, []byte(reflectValue.String()))
if err := binary.Write(w, binary.BigEndian, int64(len(str))); err != nil {
return err
} else if _, err := w.Write([]byte(str)); err != nil {
return err
}
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return binary.Write(w, binary.BigEndian, reflectValue.Interface()) return binary.Write(w, binary.BigEndian, reflectValue.Interface())
case reflect.Interface:
case reflect.Struct: case reflect.Struct:
if reflectValue.Type().Implements(typeofBinMarshal) || reflectValue.Type().Implements(typeofTextMarshal) { if ok, err := encodeTypeof(w, reflectValue); ok {
var err error
var data []byte
if reflectValue.Type().Implements(typeofBinMarshal) {
data, err = reflectValue.Interface().(encoding.BinaryMarshaler).MarshalBinary()
} else {
data, err = reflectValue.Interface().(encoding.TextMarshaler).MarshalText()
}
if err == nil {
if err = binary.Write(w, binary.BigEndian, int64(len(data))); err == nil {
_, err = w.Write(data)
}
}
return err return err
} }
typeof := reflectValue.Type() typeof := reflectValue.Type()
for fieldIndex := range typeof.NumField() { for fieldIndex := range typeof.NumField() {
if typeof.Field(fieldIndex).Tag.Get(selectorTagName) == "-" || !typeof.Field(fieldIndex).IsExported() { fieldType := typeof.Field(fieldIndex)
if fieldType.Tag.Get(selectorTagName) == "-" || !fieldType.IsExported() {
continue continue
} else if err := encodeRecursive(w, reflectValue.Field(fieldIndex)); err != nil { } else if err := encodeRecursive(w, reflectValue.Field(fieldIndex)); err != nil {
return err return err
@ -45,8 +54,8 @@ func encodeRecursive(w io.Writer, reflectValue reflect.Value) error {
} }
case reflect.Pointer: case reflect.Pointer:
if reflectValue.IsNil() || reflectValue.IsZero() { if reflectValue.IsNil() || reflectValue.IsZero() {
return binary.Write(w, binary.BigEndian, int8(0)) return binary.Write(w, binary.BigEndian, false)
} else if err := binary.Write(w, binary.BigEndian, int8(1)); err != nil { } else if err := binary.Write(w, binary.BigEndian, true); err != nil {
return err return err
} }
return encodeRecursive(w, reflectValue.Elem()) return encodeRecursive(w, reflectValue.Elem())
@ -57,9 +66,11 @@ func encodeRecursive(w io.Writer, reflectValue reflect.Value) error {
} }
} }
case reflect.Slice: case reflect.Slice:
if err := binary.Write(w, binary.BigEndian, int64(reflectValue.Len())); err != nil { if reflectValue.Type().ConvertibleTo(typeofBytes) {
return writeBuff(w, reflectValue.Bytes())
} else if err := binary.Write(w, binary.BigEndian, int64(reflectValue.Len())); err != nil {
return err return err
} else if reflectValue.Type().Elem().Kind() == typeofByte.Kind() { } else if reflectValue.Type().Elem().Kind() == typeofBytes.Elem().Kind() {
_, err = w.Write(reflectValue.Bytes()) _, err = w.Write(reflectValue.Bytes())
return err return err
} }

View File

@ -21,20 +21,17 @@ import (
"fmt" "fmt"
"io" "io"
"reflect" "reflect"
"time"
) )
const selectorTagName = "ser" const selectorTagName = "ser"
var ( var (
typeofTimer = reflect.TypeFor[time.Time]() typeofBytes = reflect.TypeFor[[]byte]()
typeofBytes = reflect.TypeOf([]byte{})
typeofByte = typeofBytes.Elem()
typeofBinMarshal = reflect.TypeFor[encoding.BinaryMarshaler]()
typeofTextMarshal = reflect.TypeFor[encoding.TextMarshaler]()
typeofBinUnmarshal = reflect.TypeFor[encoding.BinaryUnmarshaler]()
typeofTextUnmarshal = reflect.TypeFor[encoding.TextUnmarshaler]() typeofTextUnmarshal = reflect.TypeFor[encoding.TextUnmarshaler]()
typeofBinUnmarshal = reflect.TypeFor[encoding.BinaryUnmarshaler]()
typeofTextMarshal = reflect.TypeFor[encoding.TextMarshaler]()
typeofBinMarshal = reflect.TypeFor[encoding.BinaryMarshaler]()
) )
func NewEncode(w io.Writer, target any) error { func NewEncode(w io.Writer, target any) error {

View File

@ -2,51 +2,109 @@ package structcode
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/hex"
"io"
"net/netip"
"sync"
"testing" "testing"
"time" "time"
"sirherobrine23.org/Minecraft-Server/go-pproxit/proto"
) )
type structTest struct { func TestSerelelize(t *testing.T) {
Text string t.Run("Response", func(t *testing.T) {
Int8 int8 var err error
Int16 int16 var encodeRes, decodeRes proto.Response
Int32 int32 encodeRes.BadRequest = true
Int64 int64 encodeRes.NotListened = true
Date time.Time encodeRes.SendAuth = true
Bytes []byte encodeRes.AgentInfo = &proto.AgentInfo{
ArrayBytes [4]byte Protocol: 1,
UDPPort: 2555,
TCPPort: 3000,
AddrPort: netip.MustParseAddrPort("[::]:10000"),
}
Pointer *structTest var waiter sync.WaitGroup
Pointer2 any waiter.Add(2)
} r, w := io.Pipe()
go func() {
func TestDeSerelelize(t *testing.T) { defer waiter.Done()
var decodeTest structTest if err = NewDecode(r, &decodeRes); err != nil {
var testData structTest t.Error(err)
testData.Text = "Golang is best" return
testData.Int8 = 2 }
testData.Int16 = 200 }()
testData.Int32 = 1_000_000 go func() {
testData.Int64 = 1024 * 12 * 12 defer waiter.Done()
testData.Bytes = []byte("google maintener go") if err = NewEncode(w, encodeRes); err != nil {
testData.ArrayBytes = [4]byte{0, 1, 1, 1} t.Error(err)
testData.Date = time.Now() return
testData.Pointer = &structTest{ }
Text: "Golang", }()
} waiter.Wait()
if err != nil {
buff := new(bytes.Buffer) return
if err := NewEncode(buff, testData); err != nil { } else if decodeRes.BadRequest != encodeRes.BadRequest {
t.Error(err) t.Errorf("invalid decode/encode, Current values to BadRequest, Decode %v, Encode %v", decodeRes.BadRequest, encodeRes.BadRequest)
return return
} } else if decodeRes.NotListened != encodeRes.NotListened {
t.Log(buff.Bytes()) t.Errorf("invalid decode/encode, Current values to NotListened, Decode %v, Encode %v", decodeRes.NotListened, encodeRes.NotListened)
if err := NewDecode(buff, &decodeTest); err != nil { return
t.Error(err) } else if decodeRes.SendAuth != encodeRes.SendAuth {
return t.Errorf("invalid decode/encode, Current values to SendAuth, Decode %v, Encode %v", decodeRes.SendAuth, encodeRes.SendAuth)
} return
d, _ := json.MarshalIndent(decodeTest, "", " ") } else if decodeRes.AgentInfo == nil {
t.Log(string(d)) t.Errorf("invalid decode, Current values to AgentInfo, Decode %+v, Encode %+v", decodeRes.AgentInfo, encodeRes.AgentInfo)
return
} else if decodeRes.AgentInfo.Protocol != encodeRes.AgentInfo.Protocol {
t.Errorf("invalid decode/encode, Current values to AgentInfo.Protocol, Decode %d, Encode %d", decodeRes.AgentInfo.Protocol, encodeRes.AgentInfo.Protocol)
return
} else if decodeRes.AgentInfo.TCPPort != encodeRes.AgentInfo.TCPPort {
t.Errorf("invalid decode/encode, Current values to AgentInfo.TCPPort, Decode %d, Encode %d", decodeRes.AgentInfo.TCPPort, encodeRes.AgentInfo.TCPPort)
} else if decodeRes.AgentInfo.UDPPort != encodeRes.AgentInfo.UDPPort {
t.Errorf("invalid decode/encode, Current values to AgentInfo.UDPPort, Decode %d, Encode %d", decodeRes.AgentInfo.UDPPort, encodeRes.AgentInfo.UDPPort)
return
} else if decodeRes.AgentInfo.AddrPort.Compare(encodeRes.AgentInfo.AddrPort) != 0 {
t.Errorf("invalid decode/encode, Current values to AgentInfo.AddrPort, Decode %s, Encode %s", decodeRes.AgentInfo.AddrPort, encodeRes.AgentInfo.AddrPort)
return
}
})
t.Run("Request", func(t *testing.T) {
var err error
var encodeRequest, decodeRequest proto.Request
encodeRequest.AgentAuth = &[]byte{0, 0, 1, 1, 1, 1, 1, 0, 255}
encodeRequest.Ping = new(time.Time)
*encodeRequest.Ping = time.Now()
var waiter sync.WaitGroup
waiter.Add(2)
r, w := io.Pipe()
go func() {
defer waiter.Done()
if err = NewEncode(w, encodeRequest); err != nil {
t.Error(err)
return
}
}()
go func() {
defer waiter.Done()
if err = NewDecode(r, &decodeRequest); err != nil {
t.Error(err)
return
}
}()
waiter.Wait()
if err != nil {
return
} else if decodeRequest.Ping.Unix() != encodeRequest.Ping.Unix() {
t.Errorf("cannot decode/encode Ping date, Decode %d, Encode: %d", decodeRequest.Ping.Unix(), encodeRequest.Ping.Unix())
return
} else if !bytes.Equal(*decodeRequest.AgentAuth, *encodeRequest.AgentAuth) {
t.Errorf("cannot decode/encode auth data, Decode %q, Encode: %q", hex.EncodeToString(*decodeRequest.AgentAuth), hex.EncodeToString(*encodeRequest.AgentAuth))
return
}
})
} }

View File

@ -26,6 +26,5 @@ type Client struct {
type ClientData struct { type ClientData struct {
Client Client // Client Destination Client Client // Client Destination
Size uint64 // Data size
Data []byte `json:"-"` // Bytes to send Data []byte `json:"-"` // Bytes to send
} }

View File

@ -12,9 +12,7 @@ import (
"sirherobrine23.org/Minecraft-Server/go-pproxit/proto" "sirherobrine23.org/Minecraft-Server/go-pproxit/proto"
) )
var ( var ErrAuthAgentFail error = errors.New("cannot authenticate agent") // Send unathorized client and close new accepts from current port
ErrAuthAgentFail error = errors.New("cannot authenticate agent") // Send unathorized client and close new accepts from current port
)
type ServerCall interface { type ServerCall interface {
// Authenticate agents // Authenticate agents

View File

@ -1,9 +1,9 @@
package server package server
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"net/netip" "net/netip"
"os" "os"
@ -74,12 +74,11 @@ func (t toWr) Write(w []byte) (int, error) {
go t.tun.TunInfo.Callbacks.RegisterRX(t.To, len(w), t.Proto) go t.tun.TunInfo.Callbacks.RegisterRX(t.To, len(w), t.Proto)
err := t.tun.send(proto.Response{ err := t.tun.send(proto.Response{
DataRX: &proto.ClientData{ DataRX: &proto.ClientData{
Data: w,
Client: proto.Client{ Client: proto.Client{
Proto: t.Proto, Proto: t.Proto,
Client: t.To, Client: t.To,
}, },
Size: uint64(len(w)),
Data: w[:],
}, },
}) })
if err == nil { if err == nil {
@ -120,12 +119,14 @@ func (tun *Tunnel) Setup() {
}) })
for { for {
log.Printf("waiting request from %s", tun.RootConn.RemoteAddr().String())
var req proto.Request var req proto.Request
if err := structcode.NewDecode(tun.RootConn, &req); err != nil { if err := structcode.NewDecode(tun.RootConn, &req); err != nil {
fmt.Fprintln(os.Stderr, err.Error()) fmt.Fprintln(os.Stderr, err.Error())
return return
} }
d, _ := json.MarshalIndent(req, "", " ")
fmt.Println(string(d))
if req.AgentAuth != nil { if req.AgentAuth != nil {
go tun.send(proto.Response{ go tun.send(proto.Response{
AgentInfo: &proto.AgentInfo{ AgentInfo: &proto.AgentInfo{
@ -151,7 +152,7 @@ func (tun *Tunnel) Setup() {
} }
} }
} else if data := req.DataTX; req.DataTX != nil { } else if data := req.DataTX; req.DataTX != nil {
go tun.TunInfo.Callbacks.RegisterTX(data.Client.Client, int(data.Size), data.Client.Proto) go tun.TunInfo.Callbacks.RegisterTX(data.Client.Client, len(data.Data), data.Client.Proto)
if data.Client.Proto == proto.ProtoTCP { if data.Client.Proto == proto.ProtoTCP {
if cl, ok := tun.TCPClients[data.Client.Client.String()]; ok { if cl, ok := tun.TCPClients[data.Client.Client.String()]; ok {
go cl.Write(data.Data) // Process in backgroud go cl.Write(data.Data) // Process in backgroud