package memcache

import (
	"bufio"
	"bytes"
	"context"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"time"

	pkgerr "github.com/pkg/errors"
)

// Byte sequences used when composing requests and recognising reply lines of
// the memcached ASCII protocol.
var (
	crlf                   = []byte("\r\n")
	space                  = []byte(" ")
	replyOK                = []byte("OK\r\n")
	replyStored            = []byte("STORED\r\n")
	replyNotStored         = []byte("NOT_STORED\r\n")
	replyExists            = []byte("EXISTS\r\n")
	replyNotFound          = []byte("NOT_FOUND\r\n")
	replyDeleted           = []byte("DELETED\r\n")
	replyEnd               = []byte("END\r\n")
	replyTouched           = []byte("TOUCHED\r\n")
	replyClientErrorPrefix = []byte("CLIENT_ERROR ")
	replyServerErrorPrefix = []byte("SERVER_ERROR ")
)

// Compile-time check that *asiiConn satisfies protocolConn.
var _ protocolConn = &asiiConn{}

// asiiConn is the low-level implementation of Conn that speaks the memcached
// ASCII (text) protocol over a single net.Conn.
type asiiConn struct {
	// err is the first fatal error seen on this connection (or a "closed"
	// marker set by Close); once set it is what fatal and Err return.
	err  error
	conn net.Conn
	// Read & Write
	readTimeout  time.Duration
	writeTimeout time.Duration
	rw           *bufio.ReadWriter
}

// replyToError maps a single-line server reply to an error value.
//
// STORED, OK, DELETED and TOUCHED yield nil; NOT_STORED, EXISTS and NOT_FOUND
// yield ErrNotStored, ErrCASConflict and ErrNotFound respectively. Any other
// line is returned as a protocolError carrying the raw line.
func replyToError(line []byte) error {
	switch {
	case bytes.Equal(line, replyStored),
		bytes.Equal(line, replyOK),
		bytes.Equal(line, replyDeleted),
		bytes.Equal(line, replyTouched):
		return nil
	case bytes.Equal(line, replyNotStored):
		return ErrNotStored
	case bytes.Equal(line, replyExists):
		return ErrCASConflict
	case bytes.Equal(line, replyNotFound):
		return ErrNotFound
	}
	return pkgerr.WithStack(protocolError(string(line)))
}

// Populate sends a storage command (cmd) for key with the given flags,
// expiration and data, then reads the one-line reply and converts it with
// replyToError. When cmd is "cas" the cas value is appended to the command
// header. Any I/O failure is routed through fatal, which closes the
// connection.
func (c *asiiConn) Populate(ctx context.Context, cmd string, key string, flags uint32, expiration int32, cas uint64, data []byte) error {
	var err error
	c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout))
	// <command name> <key> <flags> <exptime> <bytes> [<cas unique>]\r\n
	if cmd == "cas" {
		_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n", cmd, key, flags, expiration, len(data), cas)
	} else {
		_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n", cmd, key, flags, expiration, len(data))
	}
	if err != nil {
		return c.fatal(err)
	}
	// bufio.Writer errors are sticky and would also surface at Flush; checking
	// here just makes the failure point explicit.
	if _, err = c.rw.Write(data); err != nil {
		return c.fatal(err)
	}
	if _, err = c.rw.Write(crlf); err != nil {
		return c.fatal(err)
	}
	if err = c.rw.Flush(); err != nil {
		return c.fatal(err)
	}
	c.conn.SetReadDeadline(shrinkDeadline(ctx, c.readTimeout))
	line, err := c.rw.ReadSlice('\n')
	if err != nil {
		return c.fatal(err)
	}
	return replyToError(line)
}

// newASCIIConn returns a new memcache connection for the given net connection.
// Both timeouts must be strictly positive, otherwise an error is returned.
func newASCIIConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) (protocolConn, error) {
	if writeTimeout <= 0 || readTimeout <= 0 {
		return nil, pkgerr.Errorf("readTimeout writeTimeout can't be zero")
	}
	c := &asiiConn{
		conn:         netConn,
		rw:           bufio.NewReadWriter(bufio.NewReader(netConn), bufio.NewWriter(netConn)),
		readTimeout:  readTimeout,
		writeTimeout: writeTimeout,
	}
	return c, nil
}

// Close marks the connection as closed (unless an error is already recorded)
// and closes the underlying net.Conn, returning its Close error.
func (c *asiiConn) Close() error {
	if c.err == nil {
		c.err = pkgerr.New("memcache: closed")
	}
	return c.conn.Close()
}

// fatal records err (with a stack trace) as the connection's error if none is
// recorded yet, closes the underlying connection, and returns the recorded
// error. If an error was already recorded, that earlier error is returned.
func (c *asiiConn) fatal(err error) error {
	if c.err == nil {
		c.err = pkgerr.WithStack(err)
		// Close connection to force errors on subsequent calls and to unblock
		// other reader or writer.
		c.conn.Close()
	}
	return c.err
}

// Err returns the error recorded on the connection, or nil if there is none.
func (c *asiiConn) Err() error {
	return c.err
}

// Get issues a "gets" for a single key and returns the item from the reply.
// ErrNotFound is returned when the reply contains no item.
func (c *asiiConn) Get(ctx context.Context, key string) (result *Item, err error) {
	c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout))
	if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil {
		return nil, c.fatal(err)
	}
	if err = c.rw.Flush(); err != nil {
		return nil, c.fatal(err)
	}
	if err = c.parseGetReply(ctx, func(it *Item) {
		result = it
	}); err != nil {
		return nil, err
	}
	if result == nil {
		return nil, ErrNotFound
	}
	return result, nil
}

// GetMulti issues one "gets" for all keys and returns the items found, indexed
// by key. Keys absent from the reply are simply missing from the map.
func (c *asiiConn) GetMulti(ctx context.Context, keys ...string) (map[string]*Item, error) {
	// A "gets" with no key is a malformed request; the resulting error reply
	// would be treated as fatal and close the connection. There is nothing to
	// fetch, so answer locally.
	if len(keys) == 0 {
		return map[string]*Item{}, nil
	}
	var err error
	c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout))
	if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
		return nil, c.fatal(err)
	}
	if err = c.rw.Flush(); err != nil {
		return nil, c.fatal(err)
	}
	results := make(map[string]*Item, len(keys))
	if err = c.parseGetReply(ctx, func(it *Item) {
		results[it.Key] = it
	}); err != nil {
		return nil, err
	}
	return results, nil
}

// parseGetReply reads the reply to a get/gets command, invoking f once per
// VALUE block until the END line is seen. Read errors, SERVER_ERROR lines,
// unparsable headers and data blocks not terminated by CRLF are all treated as
// fatal for the connection.
func (c *asiiConn) parseGetReply(ctx context.Context, f func(*Item)) error {
	c.conn.SetReadDeadline(shrinkDeadline(ctx, c.readTimeout))
	for {
		line, err := c.rw.ReadSlice('\n')
		if err != nil {
			return c.fatal(err)
		}
		if bytes.Equal(line, replyEnd) {
			return nil
		}
		if bytes.HasPrefix(line, replyServerErrorPrefix) {
			errMsg := line[len(replyServerErrorPrefix):]
			return c.fatal(protocolError(errMsg))
		}
		it := new(Item)
		size, err := scanGetReply(line, it)
		if err != nil {
			return c.fatal(err)
		}
		// Read the payload together with its trailing CRLF so the terminator
		// can be verified before it is sliced off.
		it.Value = make([]byte, size+2)
		if _, err = io.ReadFull(c.rw, it.Value); err != nil {
			return c.fatal(err)
		}
		if !bytes.HasSuffix(it.Value, crlf) {
			return c.fatal(protocolError("corrupt get reply, no except CRLF"))
		}
		it.Value = it.Value[:size]
		f(it)
	}
}

// scanGetReply parses a "VALUE <key> <flags> <bytes> [<cas unique>]\r\n"
// header line into item and returns the announced payload size. The cas field
// is only scanned when the line contains more than three spaces. A malformed
// line or a negative size yields an error and a size of -1.
func scanGetReply(line []byte, item *Item) (size int, err error) {
	pattern := "VALUE %s %d %d %d\r\n"
	dest := []interface{}{&item.Key, &item.Flags, &size, &item.cas}
	if bytes.Count(line, space) == 3 {
		pattern = "VALUE %s %d %d\r\n"
		dest = dest[:3]
	}
	n, err := fmt.Sscanf(string(line), pattern, dest...)
	// A negative size would make the caller's make([]byte, size+2) panic, so
	// it is rejected here as an unexpected line.
	if err != nil || n != len(dest) || size < 0 {
		return -1, fmt.Errorf("memcache: unexpected line in get response: %q", line)
	}
	return size, nil
}

// Touch sends a "touch" command updating the expiration of key and converts
// the one-line reply with replyToError.
func (c *asiiConn) Touch(ctx context.Context, key string, expire int32) error {
	line, err := c.writeReadLine(ctx, "touch %s %d\r\n", key, expire)
	if err != nil {
		return err
	}
	return replyToError(line)
}

// IncrDecr sends cmd (the caller supplies the command word) for key with the
// given delta and returns the numeric value from the reply. NOT_FOUND yields
// ErrNotFound and a CLIENT_ERROR line yields a protocolError; any other line
// that is not an unsigned decimal number yields the strconv parse error.
func (c *asiiConn) IncrDecr(ctx context.Context, cmd, key string, delta uint64) (uint64, error) {
	line, err := c.writeReadLine(ctx, "%s %s %d\r\n", cmd, key, delta)
	if err != nil {
		return 0, err
	}
	switch {
	case bytes.Equal(line, replyNotFound):
		return 0, ErrNotFound
	case bytes.HasPrefix(line, replyClientErrorPrefix):
		errMsg := line[len(replyClientErrorPrefix):]
		return 0, pkgerr.WithStack(protocolError(errMsg))
	}
	// TrimSuffix instead of slicing off two bytes: a reply shorter than CRLF
	// (e.g. a bare "\n") must produce a parse error, not an out-of-range panic.
	val, err := strconv.ParseUint(string(bytes.TrimSuffix(line, crlf)), 10, 64)
	if err != nil {
		return 0, err
	}
	return val, nil
}

// Delete sends a "delete" command for key and converts the one-line reply with
// replyToError.
func (c *asiiConn) Delete(ctx context.Context, key string) error {
	line, err := c.writeReadLine(ctx, "delete %s\r\n", key)
	if err != nil {
		return err
	}
	return replyToError(line)
}

// writeReadLine writes one formatted command, flushes it, and reads a single
// reply line (including its trailing '\n'). The returned slice comes from
// bufio's ReadSlice and is therefore only valid until the next read on the
// connection. Any I/O failure is routed through fatal.
func (c *asiiConn) writeReadLine(ctx context.Context, format string, args ...interface{}) ([]byte, error) {
	var err error
	c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout))
	_, err = fmt.Fprintf(c.rw, format, args...)
	if err != nil {
		return nil, c.fatal(pkgerr.WithStack(err))
	}
	if err = c.rw.Flush(); err != nil {
		return nil, c.fatal(pkgerr.WithStack(err))
	}
	c.conn.SetReadDeadline(shrinkDeadline(ctx, c.readTimeout))
	line, err := c.rw.ReadSlice('\n')
	if err != nil {
		return line, c.fatal(pkgerr.WithStack(err))
	}
	return line, nil
}