You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
kratos/pkg/cache/memcache/ascii_conn.go

262 lines
6.8 KiB

package memcache
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
pkgerr "github.com/pkg/errors"
)
var (
crlf = []byte("\r\n")
space = []byte(" ")
replyOK = []byte("OK\r\n")
replyStored = []byte("STORED\r\n")
replyNotStored = []byte("NOT_STORED\r\n")
replyExists = []byte("EXISTS\r\n")
replyNotFound = []byte("NOT_FOUND\r\n")
replyDeleted = []byte("DELETED\r\n")
replyEnd = []byte("END\r\n")
replyTouched = []byte("TOUCHED\r\n")
replyClientErrorPrefix = []byte("CLIENT_ERROR ")
replyServerErrorPrefix = []byte("SERVER_ERROR ")
)
var _ protocolConn = &asiiConn{}
// asiiConn is the low-level implementation of Conn
type asiiConn struct {
err error
conn net.Conn
// Read & Write
readTimeout time.Duration
writeTimeout time.Duration
rw *bufio.ReadWriter
}
func replyToError(line []byte) error {
switch {
case bytes.Equal(line, replyStored):
return nil
case bytes.Equal(line, replyOK):
return nil
case bytes.Equal(line, replyDeleted):
return nil
case bytes.Equal(line, replyTouched):
return nil
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
case bytes.Equal(line, replyNotFound):
return ErrNotFound
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
}
return pkgerr.WithStack(protocolError(string(line)))
}
func (c *asiiConn) Populate(ctx context.Context, cmd string, key string, flags uint32, expiration int32, cas uint64, data []byte) error {
c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout))
// <command name> <key> <flags> <exptime> <bytes> [noreply]\r\n
var err error
if cmd == "cas" {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n", cmd, key, flags, expiration, len(data), cas)
} else {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n", cmd, key, flags, expiration, len(data))
}
if err != nil {
return c.fatal(err)
}
c.rw.Write(data)
c.rw.Write(crlf)
if err = c.rw.Flush(); err != nil {
return c.fatal(err)
}
c.conn.SetReadDeadline(shrinkDeadline(ctx, c.readTimeout))
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
return replyToError(line)
}
// newConn returns a new memcache connection for the given net connection.
func newASCIIConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) (protocolConn, error) {
if writeTimeout <= 0 || readTimeout <= 0 {
return nil, pkgerr.Errorf("readTimeout writeTimeout can't be zero")
}
c := &asiiConn{
conn: netConn,
rw: bufio.NewReadWriter(bufio.NewReader(netConn),
bufio.NewWriter(netConn)),
readTimeout: readTimeout,
writeTimeout: writeTimeout,
}
return c, nil
}
func (c *asiiConn) Close() error {
if c.err == nil {
c.err = pkgerr.New("memcache: closed")
}
return c.conn.Close()
}
func (c *asiiConn) fatal(err error) error {
if c.err == nil {
c.err = pkgerr.WithStack(err)
// Close connection to force errors on subsequent calls and to unblock
// other reader or writer.
c.conn.Close()
}
return c.err
}
func (c *asiiConn) Err() error {
return c.err
}
func (c *asiiConn) Get(ctx context.Context, key string) (result *Item, err error) {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
if err = c.parseGetReply(func(it *Item) {
result = it
}); err != nil {
return
}
if result == nil {
return nil, ErrNotFound
}
return
}
func (c *asiiConn) GetMulti(ctx context.Context, keys ...string) (map[string]*Item, error) {
var err error
c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout))
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
results := make(map[string]*Item, len(keys))
if err = c.parseGetReply(func(it *Item) {
results[it.Key] = it
}); err != nil {
return nil, err
}
return results, nil
}
func (c *asiiConn) parseGetReply(f func(*Item)) error {
c.conn.SetReadDeadline(shrinkDeadline(context.TODO(), c.readTimeout))
for {
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
if bytes.Equal(line, replyEnd) {
return nil
}
if bytes.HasPrefix(line, replyServerErrorPrefix) {
errMsg := line[len(replyServerErrorPrefix):]
return c.fatal(protocolError(errMsg))
}
it := new(Item)
size, err := scanGetReply(line, it)
if err != nil {
return c.fatal(err)
}
it.Value = make([]byte, size+2)
if _, err = io.ReadFull(c.rw, it.Value); err != nil {
return c.fatal(err)
}
if !bytes.HasSuffix(it.Value, crlf) {
return c.fatal(protocolError("corrupt get reply, no except CRLF"))
}
it.Value = it.Value[:size]
f(it)
}
}
func scanGetReply(line []byte, item *Item) (size int, err error) {
pattern := "VALUE %s %d %d %d\r\n"
dest := []interface{}{&item.Key, &item.Flags, &size, &item.cas}
if bytes.Count(line, space) == 3 {
pattern = "VALUE %s %d %d\r\n"
dest = dest[:3]
}
n, err := fmt.Sscanf(string(line), pattern, dest...)
if err != nil || n != len(dest) {
return -1, fmt.Errorf("memcache: unexpected line in get response: %q", line)
}
return size, nil
}
func (c *asiiConn) Touch(ctx context.Context, key string, expire int32) error {
line, err := c.writeReadLine("touch %s %d\r\n", key, expire)
if err != nil {
return err
}
return replyToError(line)
}
func (c *asiiConn) IncrDecr(ctx context.Context, cmd, key string, delta uint64) (uint64, error) {
line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta)
if err != nil {
return 0, err
}
switch {
case bytes.Equal(line, replyNotFound):
return 0, ErrNotFound
case bytes.HasPrefix(line, replyClientErrorPrefix):
errMsg := line[len(replyClientErrorPrefix):]
return 0, pkgerr.WithStack(protocolError(errMsg))
}
val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64)
if err != nil {
return 0, err
}
return val, nil
}
func (c *asiiConn) Delete(ctx context.Context, key string) error {
line, err := c.writeReadLine("delete %s\r\n", key)
if err != nil {
return err
}
return replyToError(line)
}
func (c *asiiConn) writeReadLine(format string, args ...interface{}) ([]byte, error) {
c.conn.SetWriteDeadline(shrinkDeadline(context.TODO(), c.writeTimeout))
_, err := fmt.Fprintf(c.rw, format, args...)
if err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
c.conn.SetReadDeadline(shrinkDeadline(context.TODO(), c.readTimeout))
line, err := c.rw.ReadSlice('\n')
if err != nil {
return line, c.fatal(pkgerr.WithStack(err))
}
return line, nil
}