kratos/pkg/cache/memcache/conn.go

287 lines
7.4 KiB

package memcache
import (
"context"
"fmt"
"net"
"strconv"
"time"
pkgerr "github.com/pkg/errors"
)
const (
// _largeValue is the size threshold in bytes for the large-value path.
// conn.populate stores an encoded value shorter than this with a single
// command; anything else is split into chunks of this size, and
// conn.getLargeItem uses the same constant to work out how many chunks to read.
// NOTE(review): the original note here read "1024*1024 - 1, set error???",
// which suggests a 1024*1024-1 threshold was tried and produced errors on
// set — confirm before raising this value.
_largeValue = 1000 * 1000 // 1MB
)
// protocolConn is the low-level connection that implements the memcache
// protocol and provides the basic operations that conn builds on.
type protocolConn interface {
// Populate issues the storage command cmd for key with the given flags,
// expiration, cas value and payload. Within this file cmd is one of
// "add", "set", "replace" or "cas".
Populate(ctx context.Context, cmd string, key string, flags uint32, expiration int32, cas uint64, data []byte) error
// Get fetches the item stored under key.
Get(ctx context.Context, key string) (*Item, error)
// GetMulti fetches several keys in one call. Callers in this file look the
// returned items up by key and treat an absent entry as missing.
GetMulti(ctx context.Context, keys ...string) (map[string]*Item, error)
// Touch applies the expiration expire to key.
// NOTE(review): presumably in seconds, as conn.TouchContext names it — confirm.
Touch(ctx context.Context, key string, expire int32) error
// IncrDecr applies delta to the value under key; within this file cmd is
// "incr" or "decr". NOTE(review): the result is presumably the new value.
IncrDecr(ctx context.Context, cmd, key string, delta uint64) (uint64, error)
// Delete removes key.
Delete(ctx context.Context, key string) error
// Close closes the connection.
Close() error
// Err reports the connection's error state.
// NOTE(review): exact semantics are defined by the implementation, not visible here.
Err() error
}
// DialOption specifies an option for dialing a Memcache server.
// Values are created by the Dial* constructors and passed to Dial, which
// applies them in order.
type DialOption struct {
// f mutates the dial settings when Dial applies the option.
f func(*dialOptions)
}
// dialOptions collects the settings that DialOption values modify and that
// Dial reads when establishing a connection.
type dialOptions struct {
// readTimeout is passed to newASCIIConn by Dial; set via DialReadTimeout.
readTimeout time.Duration
// writeTimeout is passed to newASCIIConn by Dial; set via DialWriteTimeout.
writeTimeout time.Duration
// protocol is not read or written in the visible part of this file.
protocol string
// dial creates the network connection; Dial defaults it to net.Dial.
dial func(network, addr string) (net.Conn, error)
}
// DialReadTimeout returns an option that sets the read timeout, i.e. the
// timeout for reading a single command reply, to d.
func DialReadTimeout(d time.Duration) DialOption {
	apply := func(opts *dialOptions) {
		opts.readTimeout = d
	}
	return DialOption{f: apply}
}
// DialWriteTimeout returns an option that sets the write timeout, i.e. the
// timeout for writing a single command, to d.
func DialWriteTimeout(d time.Duration) DialOption {
	apply := func(opts *dialOptions) {
		opts.writeTimeout = d
	}
	return DialOption{f: apply}
}
// DialConnectTimeout returns an option that bounds the time spent connecting
// to the Memcache server. It replaces the dial function with the Dial method
// of a net.Dialer whose Timeout is d.
func DialConnectTimeout(d time.Duration) DialOption {
	return DialOption{f: func(opts *dialOptions) {
		opts.dial = (&net.Dialer{Timeout: d}).Dial
	}}
}
// DialNetDial returns an option that installs a custom dial function for
// creating the network connection. When no option sets a dial function,
// Dial uses net.Dial.
//
// DialNetDial and DialConnectTimeout both set the same dial function, so
// whichever of the two comes later in the option list passed to Dial wins.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
	apply := func(opts *dialOptions) {
		opts.dial = dial
	}
	return DialOption{f: apply}
}
// Dial connects to the Memcache server at the given network and address
// using the specified options.
//
// Options are applied in the order given, starting from defaults that dial
// with net.Dial and leave the read and write timeouts at zero. A zero-value
// DialOption is ignored.
//
// A dial failure is returned wrapped with a stack trace. If the protocol
// connection cannot be created on top of the network connection, the network
// connection is closed and that error is returned as-is, so a half-built
// Conn is never handed to the caller.
func Dial(network, address string, options ...DialOption) (Conn, error) {
	do := dialOptions{
		dial: net.Dial,
	}
	for _, option := range options {
		if option.f == nil {
			// A zero-value DialOption carries no setter; calling it would panic.
			continue
		}
		option.f(&do)
	}
	netConn, err := do.dial(network, address)
	if err != nil {
		return nil, pkgerr.WithStack(err)
	}
	pconn, err := newASCIIConn(netConn, do.readTimeout, do.writeTimeout)
	if err != nil {
		// Do not leak the socket; the setup error is the one worth reporting,
		// so the Close error is deliberately dropped.
		_ = netConn.Close()
		return nil, err
	}
	return &conn{pconn: pconn, ed: newEncodeDecoder()}, nil
}
// conn is the connection type returned by Dial. It pairs a protocol-level
// connection with the encoder/decoder used to serialize items.
type conn struct {
// low level connection.
pconn protocolConn
// ed encodes items before they are stored (see populate) and decodes them in Scan.
ed *encodeDecode
}
// Close closes the underlying protocol connection and returns its result.
func (c *conn) Close() error {
return c.pconn.Close()
}
// Err returns whatever the underlying protocol connection reports from Err.
func (c *conn) Err() error {
return c.pconn.Err()
}
// AddContext stores item using the "add" storage command.
func (c *conn) AddContext(ctx context.Context, item *Item) error {
return c.populate(ctx, "add", item)
}
// SetContext stores item using the "set" storage command.
func (c *conn) SetContext(ctx context.Context, item *Item) error {
return c.populate(ctx, "set", item)
}
// ReplaceContext stores item using the "replace" storage command.
func (c *conn) ReplaceContext(ctx context.Context, item *Item) error {
return c.populate(ctx, "replace", item)
}
// CompareAndSwapContext stores item using the "cas" storage command; the
// item's cas value is forwarded to the protocol connection by populate.
func (c *conn) CompareAndSwapContext(ctx context.Context, item *Item) error {
return c.populate(ctx, "cas", item)
}
// populate stores item under item.Key using the storage command cmd
// ("add", "set", "replace" or "cas" in this file).
//
// The item is encoded first. An encoded value shorter than _largeValue is
// written with a single command. Anything else takes the large-value path:
//   - the main key receives the decimal length of the encoded value as its
//     payload and carries flagLargeValue in its flags;
//   - the payload itself is written in _largeValue-sized chunks under the
//     keys item.Key+"1", item.Key+"2", ... (length/_largeValue + 1 chunks,
//     so the last chunk is empty when the length is an exact multiple).
//
// cmd and item.cas are applied to the main key only. Once that command has
// succeeded the chunks are written with "set": add/replace/cas conditions
// depend on per-key state, so applying them to chunk keys could fail after
// the main key was already stored and leave it pointing at missing or stale
// chunks.
//
// It returns ErrMalformedKey for an illegal key, the encoder's error, or the
// first error reported by the protocol connection.
func (c *conn) populate(ctx context.Context, cmd string, item *Item) error {
	if !legalKey(item.Key) {
		return ErrMalformedKey
	}
	data, err := c.ed.encode(item)
	if err != nil {
		return err
	}
	length := len(data)
	if length < _largeValue {
		return c.pconn.Populate(ctx, cmd, item.Key, item.Flags, item.Expiration, item.cas, data)
	}

	// The main key decides whether the operation is allowed at all.
	header := []byte(strconv.Itoa(length))
	if err = c.pconn.Populate(ctx, cmd, item.Key, item.Flags|flagLargeValue, item.Expiration, item.cas, header); err != nil {
		return err
	}

	count := length/_largeValue + 1
	for i := 1; i <= count; i++ {
		start := _largeValue * (i - 1)
		end := start + _largeValue
		if end > length {
			// Only the final chunk can be short (or empty).
			end = length
		}
		key := fmt.Sprintf("%s%d", item.Key, i)
		// Chunks are written unconditionally; no cas value applies to them.
		if err = c.pconn.Populate(ctx, "set", key, item.Flags, item.Expiration, 0, data[start:end]); err != nil {
			return err
		}
	}
	return nil
}
// GetContext fetches the item stored under key.
//
// It returns ErrMalformedKey for an illegal key and passes through any error
// from the protocol connection. When the fetched item carries flagLargeValue
// it is only a header, and the full value is assembled by getLargeItem.
func (c *conn) GetContext(ctx context.Context, key string) (*Item, error) {
	if !legalKey(key) {
		return nil, ErrMalformedKey
	}
	item, err := c.pconn.Get(ctx, key)
	switch {
	case err != nil:
		return nil, err
	case item.Flags&flagLargeValue == flagLargeValue:
		return c.getLargeItem(ctx, item)
	default:
		return item, nil
	}
}
// getLargeItem assembles the full value of an item stored on the large-value
// path. result is the header item: its Value holds the decimal length of the
// encoded value and its Key is the prefix of the chunk keys
// (result.Key+"1", result.Key+"2", ..., length/_largeValue + 1 chunks).
//
// All chunks are fetched with one GetMulti call and concatenated in order
// into result.Value; flagLargeValue is then cleared from result.Flags and
// result is returned.
//
// Errors:
//   - the strconv error when the header is not a decimal integer;
//   - an error when the header holds a negative length (it would otherwise
//     reach make with a negative capacity and panic);
//   - any error from the protocol connection;
//   - ErrNotFound when a chunk is missing or has a nil value.
func (c *conn) getLargeItem(ctx context.Context, result *Item) (*Item, error) {
	length, err := strconv.Atoi(string(result.Value))
	if err != nil {
		return nil, err
	}
	if length < 0 {
		return nil, fmt.Errorf("memcache: invalid large value length %d for key %q", length, result.Key)
	}
	count := length/_largeValue + 1
	keys := make([]string, 0, count)
	for i := 1; i <= count; i++ {
		keys = append(keys, fmt.Sprintf("%s%d", result.Key, i))
	}
	results, err := c.pconn.GetMulti(ctx, keys...)
	if err != nil {
		return nil, err
	}
	if len(results) < count {
		return nil, ErrNotFound
	}
	value := make([]byte, 0, length)
	for _, k := range keys {
		chunk := results[k]
		if chunk == nil || chunk.Value == nil {
			return nil, ErrNotFound
		}
		value = append(value, chunk.Value...)
	}
	result.Value = value
	// Clear (not toggle) the marker so the flag can never end up set here.
	result.Flags &^= flagLargeValue
	return result, nil
}
// GetMultiContext fetches several keys at once and returns the items found,
// keyed by their key.
//
// Every key is validated first; a single illegal key fails the whole call
// with ErrMalformedKey before anything is fetched. Items that carry
// flagLargeValue are headers only and are replaced by the fully assembled
// item from getLargeItem.
//
// A large item whose chunks are incomplete (getLargeItem reports ErrNotFound)
// is treated as a miss for that key alone: it is removed from the result
// instead of failing the whole batch and leaving raw header items in the
// map. Any other error aborts the call and is returned together with the
// results gathered so far.
func (c *conn) GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error) {
	// TODO: move to protocolConn?
	for _, key := range keys {
		if !legalKey(key) {
			return nil, ErrMalformedKey
		}
	}
	results, err := c.pconn.GetMulti(ctx, keys...)
	if err != nil {
		return results, err
	}
	// Replacing or deleting entries of a map while ranging over it is safe in Go.
	for k, v := range results {
		if v.Flags&flagLargeValue != flagLargeValue {
			continue
		}
		large, err := c.getLargeItem(ctx, v)
		if err == ErrNotFound {
			delete(results, k)
			continue
		}
		if err != nil {
			return results, err
		}
		results[k] = large
	}
	return results, nil
}
// DeleteContext removes the item stored under key, returning ErrMalformedKey
// when the key is illegal. Only the given key is deleted here; chunk keys
// written for large values are not touched.
func (c *conn) DeleteContext(ctx context.Context, key string) error {
	if legalKey(key) {
		return c.pconn.Delete(ctx, key)
	}
	return ErrMalformedKey
}
// IncrementContext issues the "incr" command for key with the given delta
// and returns the protocol connection's result. An illegal key yields
// 0 and ErrMalformedKey.
func (c *conn) IncrementContext(ctx context.Context, key string, delta uint64) (uint64, error) {
	if legalKey(key) {
		return c.pconn.IncrDecr(ctx, "incr", key, delta)
	}
	return 0, ErrMalformedKey
}
// DecrementContext issues the "decr" command for key with the given delta
// and returns the protocol connection's result. An illegal key yields
// 0 and ErrMalformedKey.
func (c *conn) DecrementContext(ctx context.Context, key string, delta uint64) (uint64, error) {
	if legalKey(key) {
		return c.pconn.IncrDecr(ctx, "decr", key, delta)
	}
	return 0, ErrMalformedKey
}
// TouchContext forwards a touch of key with the expiration seconds to the
// protocol connection, returning ErrMalformedKey when the key is illegal.
// Only the given key is touched here; chunk keys written for large values
// are not.
func (c *conn) TouchContext(ctx context.Context, key string, seconds int32) error {
	if legalKey(key) {
		return c.pconn.Touch(ctx, key, seconds)
	}
	return ErrMalformedKey
}
// Add is AddContext called with context.TODO().
func (c *conn) Add(item *Item) error {
return c.AddContext(context.TODO(), item)
}
// Set is SetContext called with context.TODO().
func (c *conn) Set(item *Item) error {
return c.SetContext(context.TODO(), item)
}
// Replace is ReplaceContext called with context.TODO().
func (c *conn) Replace(item *Item) error {
return c.ReplaceContext(context.TODO(), item)
}
// Get is GetContext called with context.TODO().
func (c *conn) Get(key string) (*Item, error) {
return c.GetContext(context.TODO(), key)
}
// GetMulti is GetMultiContext called with context.TODO().
func (c *conn) GetMulti(keys []string) (map[string]*Item, error) {
return c.GetMultiContext(context.TODO(), keys)
}
// Delete is DeleteContext called with context.TODO().
func (c *conn) Delete(key string) error {
return c.DeleteContext(context.TODO(), key)
}
// Increment is IncrementContext called with context.TODO().
func (c *conn) Increment(key string, delta uint64) (newValue uint64, err error) {
return c.IncrementContext(context.TODO(), key, delta)
}
// Decrement is DecrementContext called with context.TODO().
func (c *conn) Decrement(key string, delta uint64) (newValue uint64, err error) {
return c.DecrementContext(context.TODO(), key, delta)
}
// CompareAndSwap is CompareAndSwapContext called with context.TODO().
func (c *conn) CompareAndSwap(item *Item) error {
return c.CompareAndSwapContext(context.TODO(), item)
}
// Touch is TouchContext called with context.TODO().
func (c *conn) Touch(key string, seconds int32) (err error) {
return c.TouchContext(context.TODO(), key, seconds)
}
// Scan decodes item into v using the connection's encoder/decoder. The
// decoder's result is passed through pkgerr.WithStack before being returned.
func (c *conn) Scan(item *Item, v interface{}) (err error) {
return pkgerr.WithStack(c.ed.decode(item, v))
}