diff --git a/pkg/cache/memcache/ascii_conn.go b/pkg/cache/memcache/ascii_conn.go index 327629e80..a09b7ed3a 100644 --- a/pkg/cache/memcache/ascii_conn.go +++ b/pkg/cache/memcache/ascii_conn.go @@ -66,9 +66,9 @@ func replyToError(line []byte) error { } func (c *asiiConn) Populate(ctx context.Context, cmd string, key string, flags uint32, expiration int32, cas uint64, data []byte) error { + var err error c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout)) // [noreply]\r\n - var err error if cmd == "cas" { _, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n", cmd, key, flags, expiration, len(data), cas) } else { @@ -127,14 +127,14 @@ func (c *asiiConn) Err() error { } func (c *asiiConn) Get(ctx context.Context, key string) (result *Item, err error) { - c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout)) if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil { return nil, c.fatal(err) } if err = c.rw.Flush(); err != nil { return nil, c.fatal(err) } - if err = c.parseGetReply(func(it *Item) { + if err = c.parseGetReply(ctx, func(it *Item) { result = it }); err != nil { return @@ -155,7 +155,7 @@ func (c *asiiConn) GetMulti(ctx context.Context, keys ...string) (map[string]*It return nil, c.fatal(err) } results := make(map[string]*Item, len(keys)) - if err = c.parseGetReply(func(it *Item) { + if err = c.parseGetReply(ctx, func(it *Item) { results[it.Key] = it }); err != nil { return nil, err @@ -163,8 +163,8 @@ func (c *asiiConn) GetMulti(ctx context.Context, keys ...string) (map[string]*It return results, nil } -func (c *asiiConn) parseGetReply(f func(*Item)) error { - c.conn.SetReadDeadline(shrinkDeadline(context.TODO(), c.readTimeout)) +func (c *asiiConn) parseGetReply(ctx context.Context, f func(*Item)) error { + c.conn.SetReadDeadline(shrinkDeadline(ctx, c.readTimeout)) for { line, err := c.rw.ReadSlice('\n') if err != nil { @@ -209,7 +209,7 @@ func scanGetReply(line []byte, item *Item) (size int, err error) { } func (c *asiiConn) Touch(ctx context.Context, key string, expire int32) error { - line, err := c.writeReadLine("touch %s %d\r\n", key, expire) + line, err := c.writeReadLine(ctx, "touch %s %d\r\n", key, expire) if err != nil { return err } @@ -217,7 +217,7 @@ func (c *asiiConn) Touch(ctx context.Context, key string, expire int32) error { } func (c *asiiConn) IncrDecr(ctx context.Context, cmd, key string, delta uint64) (uint64, error) { - line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta) + line, err := c.writeReadLine(ctx, "%s %s %d\r\n", cmd, key, delta) if err != nil { return 0, err } @@ -236,23 +236,24 @@ func (c *asiiConn) IncrDecr(ctx context.Context, cmd, key string, delta uint64) } func (c *asiiConn) Delete(ctx context.Context, key string) error { - line, err := c.writeReadLine("delete %s\r\n", key) + line, err := c.writeReadLine(ctx, "delete %s\r\n", key) if err != nil { return err } return replyToError(line) } -func (c *asiiConn) writeReadLine(format string, args ...interface{}) ([]byte, error) { - c.conn.SetWriteDeadline(shrinkDeadline(context.TODO(), c.writeTimeout)) - _, err := fmt.Fprintf(c.rw, format, args...) +func (c *asiiConn) writeReadLine(ctx context.Context, format string, args ...interface{}) ([]byte, error) { + var err error + c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout)) + _, err = fmt.Fprintf(c.rw, format, args...) if err != nil { return nil, c.fatal(pkgerr.WithStack(err)) } if err = c.rw.Flush(); err != nil { return nil, c.fatal(pkgerr.WithStack(err)) } - c.conn.SetReadDeadline(shrinkDeadline(context.TODO(), c.readTimeout)) + c.conn.SetReadDeadline(shrinkDeadline(ctx, c.readTimeout)) line, err := c.rw.ReadSlice('\n') if err != nil { return line, c.fatal(pkgerr.WithStack(err)) diff --git a/pkg/cache/memcache/pool_conn.go b/pkg/cache/memcache/pool_conn.go index 4ccff2aac..1ac218932 100644 --- a/pkg/cache/memcache/pool_conn.go +++ b/pkg/cache/memcache/pool_conn.go @@ -86,7 +86,7 @@ func (pc *poolConn) Err() error { } func (pc *poolConn) Set(item *Item) (err error) { - return pc.c.SetContext(pc.ctx, item) + return pc.SetContext(pc.ctx, item) } func (pc *poolConn) Add(item *Item) (err error) { @@ -102,15 +102,15 @@ func (pc *poolConn) CompareAndSwap(item *Item) (err error) { } func (pc *poolConn) Get(key string) (r *Item, err error) { - return pc.c.GetContext(pc.ctx, key) + return pc.GetContext(pc.ctx, key) } func (pc *poolConn) GetMulti(keys []string) (res map[string]*Item, err error) { - return pc.c.GetMultiContext(pc.ctx, keys) + return pc.GetMultiContext(pc.ctx, keys) } func (pc *poolConn) Touch(key string, timeout int32) (err error) { - return pc.c.TouchContext(pc.ctx, key, timeout) + return pc.TouchContext(pc.ctx, key, timeout) } func (pc *poolConn) Scan(item *Item, v interface{}) error { @@ -118,15 +118,15 @@ func (pc *poolConn) Scan(item *Item, v interface{}) error { } func (pc *poolConn) Delete(key string) (err error) { - return pc.c.DeleteContext(pc.ctx, key) + return pc.DeleteContext(pc.ctx, key) } func (pc *poolConn) Increment(key string, delta uint64) (newValue uint64, err error) { - return pc.c.IncrementContext(pc.ctx, key, delta) + return pc.IncrementContext(pc.ctx, key, delta) } func (pc *poolConn) Decrement(key string, delta uint64) (newValue uint64, err error) { - return pc.c.DecrementContext(pc.ctx, key, delta) + return pc.DecrementContext(pc.ctx, key, delta) } func (pc *poolConn) AddContext(ctx context.Context, item *Item) error { diff --git a/pkg/cache/memcache/util.go b/pkg/cache/memcache/util.go index ce64bf1fc..e42d49910 100644 --- a/pkg/cache/memcache/util.go +++ b/pkg/cache/memcache/util.go @@ -80,10 +80,9 @@ func ProtobufItem(key string, message proto.Message, flags uint32, expiration in } func shrinkDeadline(ctx context.Context, timeout time.Duration) time.Time { - // TODO: ignored context deadline to compatible old behaviour. - //deadline, ok := ctx.Deadline() - //if ok { - // return deadline - //} - return time.Now().Add(timeout) + timeoutTime := time.Now().Add(timeout) + if deadline, ok := ctx.Deadline(); ok && timeoutTime.After(deadline) { + return deadline + } + return timeoutTime } diff --git a/pkg/cache/memcache/util_test.go b/pkg/cache/memcache/util_test.go index 34b66c290..080ef4549 100644 --- a/pkg/cache/memcache/util_test.go +++ b/pkg/cache/memcache/util_test.go @@ -1,7 +1,9 @@ package memcache import ( + "context" "testing" + "time" pb "github.com/bilibili/kratos/pkg/cache/memcache/test" @@ -73,3 +75,31 @@ func TestLegalKey(t *testing.T) { }) } } + +func TestShrinkDeadline(t *testing.T) { + t.Run("test not deadline", func(t *testing.T) { + timeout := time.Second + timeoutTime := time.Now().Add(timeout) + tm := shrinkDeadline(context.Background(), timeout) + assert.True(t, tm.After(timeoutTime)) + }) + t.Run("test big deadline", func(t *testing.T) { + timeout := time.Second + timeoutTime := time.Now().Add(timeout) + deadlineTime := time.Now().Add(2 * time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + tm := shrinkDeadline(ctx, timeout) + assert.True(t, tm.After(timeoutTime) && tm.Before(deadlineTime)) + }) + t.Run("test small deadline", func(t *testing.T) { + timeout := time.Second + deadlineTime := time.Now().Add(500 * time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + tm := shrinkDeadline(ctx, timeout) + assert.True(t, tm.After(deadlineTime) && tm.Before(time.Now().Add(timeout))) + }) +}