diff --git a/unmarshal.go b/unmarshal.go index 10536f3..0301e99 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -243,19 +243,44 @@ func decRegister(e interface{}) (err error) { } decoderCache[t.Name()] = f case reflect.Map: - return errors.New("ssob: Unsupported type") - //f := func(e interface{}, in []byte) (n int, err error) { - //t := reflect.TypeOf(e) - //v := reflect.ValueOf(e) - //pos := 0 + f := func(e interface{}, in []byte) (n int, err error) { + t := reflect.TypeOf(e) + v := reflect.ValueOf(e) + pos := 0 - //if t.Kind() != reflect.Ptr || v.IsNil() { - //return 0, errors.New("ssob: Cannot unmarshal to nil pointer") - //} + if t.Kind() != reflect.Ptr || v.IsNil() { + return 0, errors.New("ssob: Cannot unmarshal to nil pointer") + } - //vi := reflect.Indirect(v) + l, r, err := UnmarshalInt32(in) + pos += r + if err != nil { + return pos, err + } - //} + vi := reflect.Indirect(v) + vt := reflect.TypeOf(vi.Interface()) + tk := vt.Key() + tv := vt.Elem() + for i := int32(0); i < l; i++ { + ek := reflect.New(tk) + ev := reflect.New(tv) + r, err := unmarshal(ek.Interface(), in[pos:]) + pos += r + if err != nil { + return pos, err + } + r, err = unmarshal(ev.Interface(), in[pos:]) + pos += r + if err != nil { + return pos, err + } + vi.SetMapIndex(reflect.Indirect(ek), reflect.Indirect(ev)) + } + + return pos, nil + } + decoderCache[string(t.Kind())] = f default: return errors.New("ssob: Unknown type " + string(v.Kind())) } diff --git a/unsafe_decoder.go b/unsafe_decoder.go new file mode 100644 index 0000000..f12c2ba --- /dev/null +++ b/unsafe_decoder.go @@ -0,0 +1,56 @@ +package ssob + +import ( + "encoding/binary" + "errors" + "io" + "reflect" + "sync" +) + +type UnsafeDecoder struct { + mutex sync.Mutex + r io.Reader +} + +func NewUnsafeDecoder(r io.Reader) *UnsafeDecoder { + dec := new(UnsafeDecoder) + dec.r = r + return dec +} + +func (dec *UnsafeDecoder) Decode(e interface{}) (err error) { + return dec.DecodeValue(reflect.ValueOf(e)) +} + +func (dec *UnsafeDecoder) DecodeValue(value reflect.Value) (err error) { + if value.Kind() == reflect.Invalid { + return errors.New("ssob: Cannot decode nil value") + } + + if value.Kind() == reflect.Ptr && value.IsNil() { + return errors.New("ssob: Cannot decode nil of type " + value.Type().String()) + } + + dec.mutex.Lock() + defer dec.mutex.Unlock() + + lb := make([]byte, 4) + err = binary.Read(dec.r, binary.BigEndian, lb) + if err != nil { + return err + } + l, n, err := UnsafeUnmarshalInt32(lb) + if err != nil || n != 4 { + return err + } + + bb := make([]byte, l) + err = binary.Read(dec.r, binary.BigEndian, bb) + if err != nil { + return err + } + + _, err = unsafeUnmarshal(value.Interface(), bb) + return err +} diff --git a/unsafe_unmarshal.go b/unsafe_unmarshal.go new file mode 100644 index 0000000..cf04438 --- /dev/null +++ b/unsafe_unmarshal.go @@ -0,0 +1,483 @@ +package ssob + +import ( + "encoding/binary" + "errors" + "reflect" + "unsafe" +) + +func UnsafeUnmarshalString(in []byte) (ret string, n int, err error) { + if len(in) < 4 { + return "", 0, errors.New("ssob: Invalid input to decode string") + } + l := int32(binary.BigEndian.Uint32(in)) + + if len(in[4:]) < int(l) { + return "", 0, errors.New("ssob: Invalid length of string") + } + return string(in[4 : l+4]), int(l) + 4, nil +} + +func UnsafeUnmarshalInt8(in []byte) (ret int8, n int, err error) { + if len(in) < 1 { + return 0, 0, errors.New("ssob: Invalid input to decode int8") + } + + return int8(in[0]), 1, nil +} + +func UnsafeUnmarshalUint8(in []byte) (ret uint8, n int, err error) { + if len(in) < 1 { + return 0, 0, errors.New("ssob: Invalid input to decode uint8") + } + + return uint8(in[0]), 1, nil +} + +func UnsafeUnmarshalInt16(in []byte) (ret int16, n int, err error) { + if len(in) < 2 { + return 0, 0, errors.New("ssob: Invalid input to decode int16") + } + + var out int16 + start := unsafe.Pointer(&out) + if littleEndian { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[0] + } else { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[0] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[1] + } + + return out, 2, nil +} + +func UnsafeUnmarshalUint16(in []byte) (ret uint16, n int, err error) { + if len(in) < 2 { + return 0, 0, errors.New("ssob: Invalid input to decode uint16") + } + + var out uint16 + start := unsafe.Pointer(&out) + if littleEndian { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[0] + } else { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[0] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[1] + } + + return out, 2, nil +} + +func UnsafeUnmarshalInt32(in []byte) (ret int32, n int, err error) { + if len(in) < 4 { + return 0, 0, errors.New("ssob: Invalid input to decode int32") + } + + var out int32 + start := unsafe.Pointer(&out) + if littleEndian { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[3] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[0] + } else { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[0] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[3] + } + + return out, 4, nil +} + +func UnsafeUnmarshalUint32(in []byte) (ret uint32, n int, err error) { + if len(in) < 4 { + return 0, 0, errors.New("ssob: Invalid input to decode uint32") + } + + var out uint32 + start := unsafe.Pointer(&out) + if littleEndian { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[3] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[0] + } else { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[0] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[3] + } + + return out, 4, nil +} + +func UnsafeUnmarshalFloat32(in []byte) (ret float32, n int, err error) { + if len(in) < 4 { + return 0, 0, errors.New("ssob: Invalid input to decode float32") + } + + var out float32 + start := unsafe.Pointer(&out) + if littleEndian { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[3] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[0] + } else { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[0] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[3] + } + + return out, 4, nil +} + +func UnsafeUnmarshalInt64(in []byte) (ret int64, n int, err error) { + if len(in) < 8 { + return 0, 0, errors.New("ssob: Invalid input to decode int64") + } + + var out int64 + start := unsafe.Pointer(&out) + if littleEndian { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[7] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[6] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[5] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[4] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(4))) = in[3] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(5))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(6))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(7))) = in[0] + } else { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[0] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[3] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(4))) = in[4] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(5))) = in[5] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(6))) = in[6] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(7))) = in[7] + } + + return out, 8, nil +} + +func UnsafeUnmarshalUint64(in []byte) (ret uint64, n int, err error) { + if len(in) < 8 { + return 0, 0, errors.New("ssob: Invalid input to decode uint64") + } + + var out uint64 + start := unsafe.Pointer(&out) + if littleEndian { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[7] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[6] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[5] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[4] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(4))) = in[3] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(5))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(6))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(7))) = in[0] + } else { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[0] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[3] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(4))) = in[4] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(5))) = in[5] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(6))) = in[6] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(7))) = in[7] + } + + return out, 8, nil +} + +func UnsafeUnmarshalFloat64(in []byte) (ret float64, n int, err error) { + if len(in) < 8 { + return 0, 0, errors.New("ssob: Invalid input to decode float64") + } + + var out float64 + start := unsafe.Pointer(&out) + if littleEndian { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[7] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[6] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[5] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[4] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(4))) = in[3] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(5))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(6))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(7))) = in[0] + } else { + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(0))) = in[0] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(1))) = in[1] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(2))) = in[2] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(3))) = in[3] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(4))) = in[4] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(5))) = in[5] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(6))) = in[6] + *(*byte)(unsafe.Pointer(uintptr(start) + sPtr*uintptr(7))) = in[7] + } + + return out, 8, nil +} + +var unsafeDecoderCache map[string]unmarshalFunc + +func init() { + unsafeDecoderCache = make(map[string]unmarshalFunc) + RegisterUnsafeDecoder("int8", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*int8); ok { + *i, n, err = UnsafeUnmarshalInt8(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected int8") + }) + RegisterUnsafeDecoder("uint8", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*uint8); ok { + *i, n, err = UnsafeUnmarshalUint8(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected uint8") + }) + RegisterUnsafeDecoder("int16", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*int16); ok { + *i, n, err = UnsafeUnmarshalInt16(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected int16") + }) + RegisterUnsafeDecoder("uint16", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*uint16); ok { + *i, n, err = UnsafeUnmarshalUint16(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected uint16") + }) + RegisterUnsafeDecoder("int32", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*int32); ok { + *i, n, err = UnsafeUnmarshalInt32(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected int32") + }) + RegisterUnsafeDecoder("uint32", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*uint32); ok { + *i, n, err = UnsafeUnmarshalUint32(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected uint32") + }) + RegisterUnsafeDecoder("float32", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*float32); ok { + *i, n, err = UnsafeUnmarshalFloat32(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected float32") + }) + RegisterUnsafeDecoder("int64", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*int64); ok { + *i, n, err = UnsafeUnmarshalInt64(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected int64") + }) + RegisterUnsafeDecoder("uint64", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*uint64); ok { + *i, n, err = UnsafeUnmarshalUint64(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected uint64") + }) + RegisterUnsafeDecoder("float64", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*float64); ok { + *i, n, err = UnsafeUnmarshalFloat64(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected float64") + }) + RegisterUnsafeDecoder("string", func(e interface{}, in []byte) (n int, err error) { + if i, ok := e.(*string); ok { + *i, n, err = UnsafeUnmarshalString(in) + return n, nil + } + return 0, errors.New("ssob: Incompatible type - expected string") + }) +} + +func decRegisterUnsafe(e interface{}) (err error) { + t := reflect.TypeOf(e) + v := reflect.ValueOf(e) + switch t.Kind() { + case reflect.Invalid: + return errors.New("ssob: Invalid type") + case reflect.Slice: + f := func(e interface{}, in []byte) (n int, err error) { + t := reflect.TypeOf(e) + v := reflect.ValueOf(e) + pos := 0 + + if t.Kind() != reflect.Ptr || v.IsNil() { + return 0, errors.New("ssob: Cannot unmarshal to nil pointer") + } + + l, r, err := UnsafeUnmarshalInt32(in) + pos += r + if err != nil { + return pos, err + } + + ti := v.Elem() + ti.Set(reflect.MakeSlice(reflect.TypeOf(ti.Interface()), 0, ti.Cap())) + for i := 0; i < int(l); i++ { + e := reflect.New(reflect.TypeOf(ti.Interface()).Elem()) + r, err := unsafeUnmarshal(e.Interface(), in[pos:]) + pos += r + if err != nil { + return pos, err + } + ti.Set(reflect.Append(ti, reflect.Indirect(e))) + } + return pos, nil + } + unsafeDecoderCache[string(t.Kind())] = f + case reflect.Struct: + f := func(e interface{}, in []byte) (n int, err error) { + t := reflect.TypeOf(e) + v := reflect.ValueOf(e) + pos := 0 + + if t.Kind() != reflect.Ptr || v.IsNil() { + return 0, errors.New("ssob: Cannot unmarshal to nil pointer") + } + + vi := reflect.Indirect(v) + l := vi.NumField() + for i := 0; i < l; i++ { + ni, err := unsafeUnmarshal(vi.Field(i).Addr().Interface(), in[pos:]) + pos += ni + if err != nil { + return pos, err + } + } + + return pos, nil + } + unsafeDecoderCache[t.Name()] = f + case reflect.Map: + f := func(e interface{}, in []byte) (n int, err error) { + t := reflect.TypeOf(e) + v := reflect.ValueOf(e) + pos := 0 + + if t.Kind() != reflect.Ptr || v.IsNil() { + return 0, errors.New("ssob: Cannot unmarshal to nil pointer") + } + + l, r, err := UnsafeUnmarshalInt32(in) + pos += r + if err != nil { + return pos, err + } + + vi := reflect.Indirect(v) + vt := reflect.TypeOf(vi.Interface()) + tk := vt.Key() + tv := vt.Elem() + for i := int32(0); i < l; i++ { + ek := reflect.New(tk) + ev := reflect.New(tv) + r, err := unsafeUnmarshal(ek.Interface(), in[pos:]) + pos += r + if err != nil { + return pos, err + } + r, err = unsafeUnmarshal(ev.Interface(), in[pos:]) + pos += r + if err != nil { + return pos, err + } + vi.SetMapIndex(reflect.Indirect(ek), reflect.Indirect(ev)) + } + + return pos, nil + } + unsafeDecoderCache[string(t.Kind())] = f + default: + return errors.New("ssob: Unknown type " + string(v.Kind())) + } + + return nil +} + +func unsafeUnmarshalBaseType(e interface{}, in []byte) (n int, err error) { + switch t := e.(type) { + case *int8: + return unsafeDecoderCache["int8"](t, in) + case *uint8: + return unsafeDecoderCache["uint8"](t, in) + case *int16: + return unsafeDecoderCache["int16"](t, in) + case *uint16: + return unsafeDecoderCache["uint16"](t, in) + case *int32: + return unsafeDecoderCache["int32"](t, in) + case *uint32: + return unsafeDecoderCache["uint32"](t, in) + case *int64: + return unsafeDecoderCache["int64"](t, in) + case *uint64: + return unsafeDecoderCache["uint64"](t, in) + case *float32: + return unsafeDecoderCache["float32"](t, in) + case *float64: + return unsafeDecoderCache["float64"](t, in) + case *string: + return unsafeDecoderCache["string"](t, in) + default: + return 0, errors.New("ssob: No base type") + } +} + +func unsafeUnmarshal(e interface{}, in []byte) (n int, err error) { + n, err = unsafeUnmarshalBaseType(e, in) + if err == nil { + return n, err + } + + var key string + t := reflect.TypeOf(e) + v := reflect.ValueOf(e) + if t.Kind() != reflect.Ptr || v.IsNil() { + return 0, errors.New("ssob: Need a pointer that is fully allocated for unmarshalling") + } + + p := reflect.Indirect(v) + if p.Kind() != reflect.Ptr { + if p.Kind() == reflect.Struct { + key = reflect.TypeOf(p.Interface()).Name() + } else { + key = string(p.Kind()) + } + + if f, ok := unsafeDecoderCache[key]; ok { + return f(v.Interface(), in) + } else { + err = decRegisterUnsafe(p.Interface()) + if err != nil { + return 0, err + } + } + } + + return unsafeUnmarshal(e, in) +} + +func RegisterUnsafeDecoder(name string, f func(e interface{}, in []byte) (n int, err error)) { + unsafeDecoderCache[name] = f +}