package memcache import ( "context" "fmt" "net" "strconv" "time" pkgerr "github.com/pkg/errors" ) const ( // 1024*1024 - 1, set error??? _largeValue = 1000 * 1000 // 1MB ) // low level connection that implement memcache protocol provide basic operation. type protocolConn interface { Populate(ctx context.Context, cmd string, key string, flags uint32, expiration int32, cas uint64, data []byte) error Get(ctx context.Context, key string) (*Item, error) GetMulti(ctx context.Context, keys ...string) (map[string]*Item, error) Touch(ctx context.Context, key string, expire int32) error IncrDecr(ctx context.Context, cmd, key string, delta uint64) (uint64, error) Delete(ctx context.Context, key string) error Close() error Err() error } // DialOption specifies an option for dialing a Memcache server. type DialOption struct { f func(*dialOptions) } type dialOptions struct { readTimeout time.Duration writeTimeout time.Duration protocol string dial func(network, addr string) (net.Conn, error) } // DialReadTimeout specifies the timeout for reading a single command reply. func DialReadTimeout(d time.Duration) DialOption { return DialOption{func(do *dialOptions) { do.readTimeout = d }} } // DialWriteTimeout specifies the timeout for writing a single command. func DialWriteTimeout(d time.Duration) DialOption { return DialOption{func(do *dialOptions) { do.writeTimeout = d }} } // DialConnectTimeout specifies the timeout for connecting to the Memcache server. func DialConnectTimeout(d time.Duration) DialOption { return DialOption{func(do *dialOptions) { dialer := net.Dialer{Timeout: d} do.dial = dialer.Dial }} } // DialNetDial specifies a custom dial function for creating TCP // connections. If this option is left out, then net.Dial is // used. DialNetDial overrides DialConnectTimeout. func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption { return DialOption{func(do *dialOptions) { do.dial = dial }} } // Dial connects to the Memcache server at the given network and // address using the specified options. func Dial(network, address string, options ...DialOption) (Conn, error) { do := dialOptions{ dial: net.Dial, } for _, option := range options { option.f(&do) } netConn, err := do.dial(network, address) if err != nil { return nil, pkgerr.WithStack(err) } pconn, err := newASCIIConn(netConn, do.readTimeout, do.writeTimeout) return &conn{pconn: pconn, ed: newEncodeDecoder()}, nil } type conn struct { // low level connection. pconn protocolConn ed *encodeDecode } func (c *conn) Close() error { return c.pconn.Close() } func (c *conn) Err() error { return c.pconn.Err() } func (c *conn) AddContext(ctx context.Context, item *Item) error { return c.populate(ctx, "add", item) } func (c *conn) SetContext(ctx context.Context, item *Item) error { return c.populate(ctx, "set", item) } func (c *conn) ReplaceContext(ctx context.Context, item *Item) error { return c.populate(ctx, "replace", item) } func (c *conn) CompareAndSwapContext(ctx context.Context, item *Item) error { return c.populate(ctx, "cas", item) } func (c *conn) populate(ctx context.Context, cmd string, item *Item) error { if !legalKey(item.Key) { return ErrMalformedKey } data, err := c.ed.encode(item) if err != nil { return err } length := len(data) if length < _largeValue { return c.pconn.Populate(ctx, cmd, item.Key, item.Flags, item.Expiration, item.cas, data) } count := length/_largeValue + 1 if err = c.pconn.Populate(ctx, cmd, item.Key, item.Flags|flagLargeValue, item.Expiration, item.cas, []byte(strconv.Itoa(length))); err != nil { return err } var chunk []byte for i := 1; i <= count; i++ { if i == count { chunk = data[_largeValue*(count-1):] } else { chunk = data[_largeValue*(i-1) : _largeValue*i] } key := fmt.Sprintf("%s%d", item.Key, i) if err = c.pconn.Populate(ctx, cmd, key, item.Flags, item.Expiration, item.cas, chunk); err != nil { return err } } return nil } func (c *conn) GetContext(ctx context.Context, key string) (*Item, error) { if !legalKey(key) { return nil, ErrMalformedKey } result, err := c.pconn.Get(ctx, key) if err != nil { return nil, err } if result.Flags&flagLargeValue != flagLargeValue { return result, err } return c.getLargeItem(ctx, result) } func (c *conn) getLargeItem(ctx context.Context, result *Item) (*Item, error) { length, err := strconv.Atoi(string(result.Value)) if err != nil { return nil, err } count := length/_largeValue + 1 keys := make([]string, 0, count) for i := 1; i <= count; i++ { keys = append(keys, fmt.Sprintf("%s%d", result.Key, i)) } var results map[string]*Item if results, err = c.pconn.GetMulti(ctx, keys...); err != nil { return nil, err } if len(results) < count { return nil, ErrNotFound } result.Value = make([]byte, 0, length) for _, k := range keys { ti := results[k] if ti == nil || ti.Value == nil { return nil, ErrNotFound } result.Value = append(result.Value, ti.Value...) } result.Flags = result.Flags ^ flagLargeValue return result, nil } func (c *conn) GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error) { // TODO: move to protocolConn? for _, key := range keys { if !legalKey(key) { return nil, ErrMalformedKey } } results, err := c.pconn.GetMulti(ctx, keys...) if err != nil { return results, err } for k, v := range results { if v.Flags&flagLargeValue != flagLargeValue { continue } if v, err = c.getLargeItem(ctx, v); err != nil { return results, err } results[k] = v } return results, nil } func (c *conn) DeleteContext(ctx context.Context, key string) error { if !legalKey(key) { return ErrMalformedKey } return c.pconn.Delete(ctx, key) } func (c *conn) IncrementContext(ctx context.Context, key string, delta uint64) (uint64, error) { if !legalKey(key) { return 0, ErrMalformedKey } return c.pconn.IncrDecr(ctx, "incr", key, delta) } func (c *conn) DecrementContext(ctx context.Context, key string, delta uint64) (uint64, error) { if !legalKey(key) { return 0, ErrMalformedKey } return c.pconn.IncrDecr(ctx, "decr", key, delta) } func (c *conn) TouchContext(ctx context.Context, key string, seconds int32) error { if !legalKey(key) { return ErrMalformedKey } return c.pconn.Touch(ctx, key, seconds) } func (c *conn) Add(item *Item) error { return c.AddContext(context.TODO(), item) } func (c *conn) Set(item *Item) error { return c.SetContext(context.TODO(), item) } func (c *conn) Replace(item *Item) error { return c.ReplaceContext(context.TODO(), item) } func (c *conn) Get(key string) (*Item, error) { return c.GetContext(context.TODO(), key) } func (c *conn) GetMulti(keys []string) (map[string]*Item, error) { return c.GetMultiContext(context.TODO(), keys) } func (c *conn) Delete(key string) error { return c.DeleteContext(context.TODO(), key) } func (c *conn) Increment(key string, delta uint64) (newValue uint64, err error) { return c.IncrementContext(context.TODO(), key, delta) } func (c *conn) Decrement(key string, delta uint64) (newValue uint64, err error) { return c.DecrementContext(context.TODO(), key, delta) } func (c *conn) CompareAndSwap(item *Item) error { return c.CompareAndSwapContext(context.TODO(), item) } func (c *conn) Touch(key string, seconds int32) (err error) { return c.TouchContext(context.TODO(), key, seconds) } func (c *conn) Scan(item *Item, v interface{}) (err error) { return pkgerr.WithStack(c.ed.decode(item, v)) }