- update
continuous-integration/drone/push Build is passing Details

master
李光春 2 years ago
parent 8e5b593199
commit 15776f755b

1
.gitignore vendored

@ -5,6 +5,5 @@
.vscode
*.log
gomod.sh
/vendor/
redis_test.go
memcached_test.go

@ -0,0 +1,28 @@
---
codecov:
require_ci_to_pass: true
comment:
behavior: default
layout: reach, diff, flags, files, footer
require_base: false
require_changes: false
require_head: true
coverage:
precision: 2
range:
- 70
- 100
round: down
status:
changes: false
patch: true
project: true
parsers:
gcov:
branch_detection:
conditional: true
loop: true
macro: false
method: false
javascript:
enable_partials: false

@ -0,0 +1,11 @@
.idea
.DS_Store
/server/server.exe
/server/server
/server/server_dar*
/server/server_fre*
/server/server_win*
/server/server_net*
/server/server_ope*
/server/server_lin*
CHANGELOG.md

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@ -0,0 +1,200 @@
# BigCache [![Build Status](https://github.com/allegro/bigcache/workflows/build/badge.svg)](https://github.com/allegro/bigcache/actions?query=workflow%3Abuild) [![Coverage Status](https://coveralls.io/repos/github/allegro/bigcache/badge.svg?branch=master)](https://coveralls.io/github/allegro/bigcache?branch=master) [![GoDoc](https://godoc.org/github.com/allegro/bigcache?status.svg)](https://godoc.org/github.com/allegro/bigcache) [![Go Report Card](https://goreportcard.com/badge/github.com/allegro/bigcache)](https://goreportcard.com/report/github.com/allegro/bigcache)
Fast, concurrent, evicting in-memory cache written to keep big number of entries without impact on performance.
BigCache keeps entries on heap but omits GC for them. To achieve that, operations on byte slices take place,
therefore entries (de)serialization in front of the cache will be needed in most use cases.
Requires Go 1.12 or newer.
## Usage
### Simple initialization
```go
import "github.com/allegro/bigcache/v3"
cache, _ := bigcache.NewBigCache(bigcache.DefaultConfig(10 * time.Minute))
cache.Set("my-unique-key", []byte("value"))
entry, _ := cache.Get("my-unique-key")
fmt.Println(string(entry))
```
### Custom initialization
When cache load can be predicted in advance then it is better to use custom initialization because additional memory
allocation can be avoided in that way.
```go
import (
"log"
"github.com/allegro/bigcache/v3"
)
config := bigcache.Config {
// number of shards (must be a power of 2)
Shards: 1024,
// time after which entry can be evicted
LifeWindow: 10 * time.Minute,
// Interval between removing expired entries (clean up).
// If set to <= 0 then no action is performed.
// Setting to < 1 second is counterproductive bigcache has a one second resolution.
CleanWindow: 5 * time.Minute,
// rps * lifeWindow, used only in initial memory allocation
MaxEntriesInWindow: 1000 * 10 * 60,
// max entry size in bytes, used only in initial memory allocation
MaxEntrySize: 500,
// prints information about additional memory allocation
Verbose: true,
// cache will not allocate more memory than this limit, value in MB
// if value is reached then the oldest entries can be overridden for the new ones
// 0 value means no size limit
HardMaxCacheSize: 8192,
// callback fired when the oldest entry is removed because of its expiration time or no space left
// for the new entry, or because delete was called. A bitmask representing the reason will be returned.
// Default value is nil which means no callback and it prevents from unwrapping the oldest entry.
OnRemove: nil,
// OnRemoveWithReason is a callback fired when the oldest entry is removed because of its expiration time or no space left
// for the new entry, or because delete was called. A constant representing the reason will be passed through.
// Default value is nil which means no callback and it prevents from unwrapping the oldest entry.
// Ignored if OnRemove is specified.
OnRemoveWithReason: nil,
}
cache, initErr := bigcache.NewBigCache(config)
if initErr != nil {
log.Fatal(initErr)
}
cache.Set("my-unique-key", []byte("value"))
if entry, err := cache.Get("my-unique-key"); err == nil {
fmt.Println(string(entry))
}
```
### `LifeWindow` & `CleanWindow`
1. `LifeWindow` is a time. After that time, an entry can be called dead but not deleted.
2. `CleanWindow` is a time. After that time, all the dead entries will be deleted, but not the entries that still have life.
## [Benchmarks](https://github.com/allegro/bigcache-bench)
Three caches were compared: bigcache, [freecache](https://github.com/coocood/freecache) and map.
Benchmark tests were made using an
i7-6700K CPU @ 4.00GHz with 32GB of RAM on Ubuntu 18.04 LTS (5.2.12-050212-generic).
Benchmarks source code can be found [here](https://github.com/allegro/bigcache-bench)
### Writes and reads
```bash
go version
go version go1.13 linux/amd64
go test -bench=. -benchmem -benchtime=4s ./... -timeout 30m
goos: linux
goarch: amd64
pkg: github.com/allegro/bigcache/v3/caches_bench
BenchmarkMapSet-8 12999889 376 ns/op 199 B/op 3 allocs/op
BenchmarkConcurrentMapSet-8 4355726 1275 ns/op 337 B/op 8 allocs/op
BenchmarkFreeCacheSet-8 11068976 703 ns/op 328 B/op 2 allocs/op
BenchmarkBigCacheSet-8 10183717 478 ns/op 304 B/op 2 allocs/op
BenchmarkMapGet-8 16536015 324 ns/op 23 B/op 1 allocs/op
BenchmarkConcurrentMapGet-8 13165708 401 ns/op 24 B/op 2 allocs/op
BenchmarkFreeCacheGet-8 10137682 690 ns/op 136 B/op 2 allocs/op
BenchmarkBigCacheGet-8 11423854 450 ns/op 152 B/op 4 allocs/op
BenchmarkBigCacheSetParallel-8 34233472 148 ns/op 317 B/op 3 allocs/op
BenchmarkFreeCacheSetParallel-8 34222654 268 ns/op 350 B/op 3 allocs/op
BenchmarkConcurrentMapSetParallel-8 19635688 240 ns/op 200 B/op 6 allocs/op
BenchmarkBigCacheGetParallel-8 60547064 86.1 ns/op 152 B/op 4 allocs/op
BenchmarkFreeCacheGetParallel-8 50701280 147 ns/op 136 B/op 3 allocs/op
BenchmarkConcurrentMapGetParallel-8 27353288 175 ns/op 24 B/op 2 allocs/op
PASS
ok github.com/allegro/bigcache/v3/caches_bench 256.257s
```
Writes and reads in bigcache are faster than in freecache.
Writes to map are the slowest.
### GC pause time
```bash
go version
go version go1.13 linux/amd64
go run caches_gc_overhead_comparison.go
Number of entries: 20000000
GC pause for bigcache: 1.506077ms
GC pause for freecache: 5.594416ms
GC pause for map: 9.347015ms
```
```
go version
go version go1.13 linux/arm64
go run caches_gc_overhead_comparison.go
Number of entries: 20000000
GC pause for bigcache: 22.382827ms
GC pause for freecache: 41.264651ms
GC pause for map: 72.236853ms
```
Test shows how long are the GC pauses for caches filled with 20mln of entries.
Bigcache and freecache have very similar GC pause time.
### Memory usage
You may encounter system memory reporting what appears to be an exponential increase, however this is expected behaviour. Go runtime allocates memory in chunks or 'spans' and will inform the OS when they are no longer required by changing their state to 'idle'. The 'spans' will remain part of the process resource usage until the OS needs to repurpose the address. Further reading available [here](https://utcc.utoronto.ca/~cks/space/blog/programming/GoNoMemoryFreeing).
## How it works
BigCache relies on optimization presented in 1.5 version of Go ([issue-9477](https://github.com/golang/go/issues/9477)).
This optimization states that if map without pointers in keys and values is used then GC will omit its content.
Therefore BigCache uses `map[uint64]uint32` where keys are hashed and values are offsets of entries.
Entries are kept in byte slices, to omit GC again.
Byte slices size can grow to gigabytes without impact on performance
because GC will only see single pointer to it.
### Collisions
BigCache does not handle collisions. When new item is inserted and it's hash collides with previously stored item, new item overwrites previously stored value.
## Bigcache vs Freecache
Both caches provide the same core features but they reduce GC overhead in different ways.
Bigcache relies on `map[uint64]uint32`, freecache implements its own mapping built on
slices to reduce number of pointers.
Results from benchmark tests are presented above.
One of the advantage of bigcache over freecache is that you dont need to know
the size of the cache in advance, because when bigcache is full,
it can allocate additional memory for new entries instead of
overwriting existing ones as freecache does currently.
However hard max size in bigcache also can be set, check [HardMaxCacheSize](https://godoc.org/github.com/allegro/bigcache#Config).
## HTTP Server
This package also includes an easily deployable HTTP implementation of BigCache, which can be found in the [server](/server) package.
## More
Bigcache genesis is described in allegro.tech blog post: [writing a very fast cache service in Go](http://allegro.tech/2016/03/writing-fast-cache-service-in-go.html)
## License
BigCache is released under the Apache 2.0 license (see [LICENSE](LICENSE))

@ -0,0 +1,246 @@
package bigcache
import (
"fmt"
"time"
)
const (
minimumEntriesInShard = 10 // Minimum number of entries in single shard
)
// BigCache is fast, concurrent, evicting cache created to keep big number of entries without impact on performance.
// It keeps entries on heap but omits GC for them. To achieve that, operations take place on byte arrays,
// therefore entries (de)serialization in front of the cache will be needed in most use cases.
type BigCache struct {
shards []*cacheShard
lifeWindow uint64
clock clock
hash Hasher
config Config
shardMask uint64
close chan struct{}
}
// Response will contain metadata about the entry for which GetWithInfo(key) was called
type Response struct {
EntryStatus RemoveReason
}
// RemoveReason is a value used to signal to the user why a particular key was removed in the OnRemove callback.
type RemoveReason uint32
const (
// Expired means the key is past its LifeWindow.
Expired = RemoveReason(1)
// NoSpace means the key is the oldest and the cache size was at its maximum when Set was called, or the
// entry exceeded the maximum shard size.
NoSpace = RemoveReason(2)
// Deleted means Delete was called and this key was removed as a result.
Deleted = RemoveReason(3)
)
// NewBigCache initialize new instance of BigCache
func NewBigCache(config Config) (*BigCache, error) {
return newBigCache(config, &systemClock{})
}
func newBigCache(config Config, clock clock) (*BigCache, error) {
if !isPowerOfTwo(config.Shards) {
return nil, fmt.Errorf("Shards number must be power of two")
}
if config.MaxEntrySize < 0 {
return nil, fmt.Errorf("MaxEntrySize must be >= 0")
}
if config.MaxEntriesInWindow < 0 {
return nil, fmt.Errorf("MaxEntriesInWindow must be >= 0")
}
if config.HardMaxCacheSize < 0 {
return nil, fmt.Errorf("HardMaxCacheSize must be >= 0")
}
if config.Hasher == nil {
config.Hasher = newDefaultHasher()
}
cache := &BigCache{
shards: make([]*cacheShard, config.Shards),
lifeWindow: uint64(config.LifeWindow.Seconds()),
clock: clock,
hash: config.Hasher,
config: config,
shardMask: uint64(config.Shards - 1),
close: make(chan struct{}),
}
var onRemove func(wrappedEntry []byte, reason RemoveReason)
if config.OnRemoveWithMetadata != nil {
onRemove = cache.providedOnRemoveWithMetadata
} else if config.OnRemove != nil {
onRemove = cache.providedOnRemove
} else if config.OnRemoveWithReason != nil {
onRemove = cache.providedOnRemoveWithReason
} else {
onRemove = cache.notProvidedOnRemove
}
for i := 0; i < config.Shards; i++ {
cache.shards[i] = initNewShard(config, onRemove, clock)
}
if config.CleanWindow > 0 {
go func() {
ticker := time.NewTicker(config.CleanWindow)
defer ticker.Stop()
for {
select {
case t := <-ticker.C:
cache.cleanUp(uint64(t.Unix()))
case <-cache.close:
return
}
}
}()
}
return cache, nil
}
// Close is used to signal a shutdown of the cache when you are done with it.
// This allows the cleaning goroutines to exit and ensures references are not
// kept to the cache preventing GC of the entire cache.
func (c *BigCache) Close() error {
close(c.close)
return nil
}
// Get reads entry for the key.
// It returns an ErrEntryNotFound when
// no entry exists for the given key.
func (c *BigCache) Get(key string) ([]byte, error) {
hashedKey := c.hash.Sum64(key)
shard := c.getShard(hashedKey)
return shard.get(key, hashedKey)
}
// GetWithInfo reads entry for the key with Response info.
// It returns an ErrEntryNotFound when
// no entry exists for the given key.
func (c *BigCache) GetWithInfo(key string) ([]byte, Response, error) {
hashedKey := c.hash.Sum64(key)
shard := c.getShard(hashedKey)
return shard.getWithInfo(key, hashedKey)
}
// Set saves entry under the key
func (c *BigCache) Set(key string, entry []byte) error {
hashedKey := c.hash.Sum64(key)
shard := c.getShard(hashedKey)
return shard.set(key, hashedKey, entry)
}
// Append appends entry under the key if key exists, otherwise
// it will set the key (same behaviour as Set()). With Append() you can
// concatenate multiple entries under the same key in an lock optimized way.
func (c *BigCache) Append(key string, entry []byte) error {
hashedKey := c.hash.Sum64(key)
shard := c.getShard(hashedKey)
return shard.append(key, hashedKey, entry)
}
// Delete removes the key
func (c *BigCache) Delete(key string) error {
hashedKey := c.hash.Sum64(key)
shard := c.getShard(hashedKey)
return shard.del(hashedKey)
}
// Reset empties all cache shards
func (c *BigCache) Reset() error {
for _, shard := range c.shards {
shard.reset(c.config)
}
return nil
}
// Len computes number of entries in cache
func (c *BigCache) Len() int {
var len int
for _, shard := range c.shards {
len += shard.len()
}
return len
}
// Capacity returns amount of bytes store in the cache.
func (c *BigCache) Capacity() int {
var len int
for _, shard := range c.shards {
len += shard.capacity()
}
return len
}
// Stats returns cache's statistics
func (c *BigCache) Stats() Stats {
var s Stats
for _, shard := range c.shards {
tmp := shard.getStats()
s.Hits += tmp.Hits
s.Misses += tmp.Misses
s.DelHits += tmp.DelHits
s.DelMisses += tmp.DelMisses
s.Collisions += tmp.Collisions
}
return s
}
// KeyMetadata returns number of times a cached resource was requested.
func (c *BigCache) KeyMetadata(key string) Metadata {
hashedKey := c.hash.Sum64(key)
shard := c.getShard(hashedKey)
return shard.getKeyMetadataWithLock(hashedKey)
}
// Iterator returns iterator function to iterate over EntryInfo's from whole cache.
func (c *BigCache) Iterator() *EntryInfoIterator {
return newIterator(c)
}
func (c *BigCache) onEvict(oldestEntry []byte, currentTimestamp uint64, evict func(reason RemoveReason) error) bool {
oldestTimestamp := readTimestampFromEntry(oldestEntry)
if currentTimestamp-oldestTimestamp > c.lifeWindow {
evict(Expired)
return true
}
return false
}
func (c *BigCache) cleanUp(currentTimestamp uint64) {
for _, shard := range c.shards {
shard.cleanUp(currentTimestamp)
}
}
func (c *BigCache) getShard(hashedKey uint64) (shard *cacheShard) {
return c.shards[hashedKey&c.shardMask]
}
func (c *BigCache) providedOnRemove(wrappedEntry []byte, reason RemoveReason) {
c.config.OnRemove(readKeyFromEntry(wrappedEntry), readEntry(wrappedEntry))
}
func (c *BigCache) providedOnRemoveWithReason(wrappedEntry []byte, reason RemoveReason) {
if c.config.onRemoveFilter == 0 || (1<<uint(reason))&c.config.onRemoveFilter > 0 {
c.config.OnRemoveWithReason(readKeyFromEntry(wrappedEntry), readEntry(wrappedEntry), reason)
}
}
func (c *BigCache) notProvidedOnRemove(wrappedEntry []byte, reason RemoveReason) {
}
func (c *BigCache) providedOnRemoveWithMetadata(wrappedEntry []byte, reason RemoveReason) {
hashedKey := c.hash.Sum64(readKeyFromEntry(wrappedEntry))
shard := c.getShard(hashedKey)
c.config.OnRemoveWithMetadata(readKeyFromEntry(wrappedEntry), readEntry(wrappedEntry), shard.getKeyMetadata(hashedKey))
}

@ -0,0 +1,11 @@
// +build !appengine
package bigcache
import (
"unsafe"
)
func bytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}

@ -0,0 +1,7 @@
// +build appengine
package bigcache
func bytesToString(b []byte) string {
return string(b)
}

@ -0,0 +1,14 @@
package bigcache
import "time"
type clock interface {
Epoch() int64
}
type systemClock struct {
}
func (c systemClock) Epoch() int64 {
return time.Now().Unix()
}

@ -0,0 +1,97 @@
package bigcache
import "time"
// Config for BigCache
type Config struct {
// Number of cache shards, value must be a power of two
Shards int
// Time after which entry can be evicted
LifeWindow time.Duration
// Interval between removing expired entries (clean up).
// If set to <= 0 then no action is performed. Setting to < 1 second is counterproductive — bigcache has a one second resolution.
CleanWindow time.Duration
// Max number of entries in life window. Used only to calculate initial size for cache shards.
// When proper value is set then additional memory allocation does not occur.
MaxEntriesInWindow int
// Max size of entry in bytes. Used only to calculate initial size for cache shards.
MaxEntrySize int
// StatsEnabled if true calculate the number of times a cached resource was requested.
StatsEnabled bool
// Verbose mode prints information about new memory allocation
Verbose bool
// Hasher used to map between string keys and unsigned 64bit integers, by default fnv64 hashing is used.
Hasher Hasher
// HardMaxCacheSize is a limit for BytesQueue size in MB.
// It can protect application from consuming all available memory on machine, therefore from running OOM Killer.
// Default value is 0 which means unlimited size. When the limit is higher than 0 and reached then
// the oldest entries are overridden for the new ones. The max memory consumption will be bigger than
// HardMaxCacheSize due to Shards' s additional memory. Every Shard consumes additional memory for map of keys
// and statistics (map[uint64]uint32) the size of this map is equal to number of entries in
// cache ~ 2×(64+32)×n bits + overhead or map itself.
HardMaxCacheSize int
// OnRemove is a callback fired when the oldest entry is removed because of its expiration time or no space left
// for the new entry, or because delete was called.
// Default value is nil which means no callback and it prevents from unwrapping the oldest entry.
// ignored if OnRemoveWithMetadata is specified.
OnRemove func(key string, entry []byte)
// OnRemoveWithMetadata is a callback fired when the oldest entry is removed because of its expiration time or no space left
// for the new entry, or because delete was called. A structure representing details about that specific entry.
// Default value is nil which means no callback and it prevents from unwrapping the oldest entry.
OnRemoveWithMetadata func(key string, entry []byte, keyMetadata Metadata)
// OnRemoveWithReason is a callback fired when the oldest entry is removed because of its expiration time or no space left
// for the new entry, or because delete was called. A constant representing the reason will be passed through.
// Default value is nil which means no callback and it prevents from unwrapping the oldest entry.
// Ignored if OnRemove is specified.
OnRemoveWithReason func(key string, entry []byte, reason RemoveReason)
onRemoveFilter int
// Logger is a logging interface and used in combination with `Verbose`
// Defaults to `DefaultLogger()`
Logger Logger
}
// DefaultConfig initializes config with default values.
// When load for BigCache can be predicted in advance then it is better to use custom config.
func DefaultConfig(eviction time.Duration) Config {
return Config{
Shards: 1024,
LifeWindow: eviction,
CleanWindow: 1 * time.Second,
MaxEntriesInWindow: 1000 * 10 * 60,
MaxEntrySize: 500,
StatsEnabled: false,
Verbose: true,
Hasher: newDefaultHasher(),
HardMaxCacheSize: 0,
Logger: DefaultLogger(),
}
}
// initialShardSize computes initial shard size
func (c Config) initialShardSize() int {
return max(c.MaxEntriesInWindow/c.Shards, minimumEntriesInShard)
}
// maximumShardSizeInBytes computes maximum shard size in bytes
func (c Config) maximumShardSizeInBytes() int {
maxShardSize := 0
if c.HardMaxCacheSize > 0 {
maxShardSize = convertMBToBytes(c.HardMaxCacheSize) / c.Shards
}
return maxShardSize
}
// OnRemoveFilterSet sets which remove reasons will trigger a call to OnRemoveWithReason.
// Filtering out reasons prevents bigcache from unwrapping them, which saves cpu.
func (c Config) OnRemoveFilterSet(reasons ...RemoveReason) Config {
c.onRemoveFilter = 0
for i := range reasons {
c.onRemoveFilter |= 1 << uint(reasons[i])
}
return c
}

@ -0,0 +1,83 @@
package bigcache
import (
"encoding/binary"
)
const (
timestampSizeInBytes = 8 // Number of bytes used for timestamp
hashSizeInBytes = 8 // Number of bytes used for hash
keySizeInBytes = 2 // Number of bytes used for size of entry key
headersSizeInBytes = timestampSizeInBytes + hashSizeInBytes + keySizeInBytes // Number of bytes used for all headers
)
func wrapEntry(timestamp uint64, hash uint64, key string, entry []byte, buffer *[]byte) []byte {
keyLength := len(key)
blobLength := len(entry) + headersSizeInBytes + keyLength
if blobLength > len(*buffer) {
*buffer = make([]byte, blobLength)
}
blob := *buffer
binary.LittleEndian.PutUint64(blob, timestamp)
binary.LittleEndian.PutUint64(blob[timestampSizeInBytes:], hash)
binary.LittleEndian.PutUint16(blob[timestampSizeInBytes+hashSizeInBytes:], uint16(keyLength))
copy(blob[headersSizeInBytes:], key)
copy(blob[headersSizeInBytes+keyLength:], entry)
return blob[:blobLength]
}
func appendToWrappedEntry(timestamp uint64, wrappedEntry []byte, entry []byte, buffer *[]byte) []byte {
blobLength := len(wrappedEntry) + len(entry)
if blobLength > len(*buffer) {
*buffer = make([]byte, blobLength)
}
blob := *buffer
binary.LittleEndian.PutUint64(blob, timestamp)
copy(blob[timestampSizeInBytes:], wrappedEntry[timestampSizeInBytes:])
copy(blob[len(wrappedEntry):], entry)
return blob[:blobLength]
}
func readEntry(data []byte) []byte {
length := binary.LittleEndian.Uint16(data[timestampSizeInBytes+hashSizeInBytes:])
// copy on read
dst := make([]byte, len(data)-int(headersSizeInBytes+length))
copy(dst, data[headersSizeInBytes+length:])
return dst
}
func readTimestampFromEntry(data []byte) uint64 {
return binary.LittleEndian.Uint64(data)
}
func readKeyFromEntry(data []byte) string {
length := binary.LittleEndian.Uint16(data[timestampSizeInBytes+hashSizeInBytes:])
// copy on read
dst := make([]byte, length)
copy(dst, data[headersSizeInBytes:headersSizeInBytes+length])
return bytesToString(dst)
}
func compareKeyFromEntry(data []byte, key string) bool {
length := binary.LittleEndian.Uint16(data[timestampSizeInBytes+hashSizeInBytes:])
return bytesToString(data[headersSizeInBytes:headersSizeInBytes+length]) == key
}
func readHashFromEntry(data []byte) uint64 {
return binary.LittleEndian.Uint64(data[timestampSizeInBytes:])
}
func resetKeyFromEntry(data []byte) {
binary.LittleEndian.PutUint64(data[timestampSizeInBytes:], 0)
}

@ -0,0 +1,8 @@
package bigcache
import "errors"
var (
// ErrEntryNotFound is an error type struct which is returned when entry was not found for provided key
ErrEntryNotFound = errors.New("Entry not found")
)

@ -0,0 +1,28 @@
package bigcache
// newDefaultHasher returns a new 64-bit FNV-1a Hasher which makes no memory allocations.
// Its Sum64 method will lay the value out in big-endian byte order.
// See https://en.wikipedia.org/wiki/FowlerNollVo_hash_function
func newDefaultHasher() Hasher {
return fnv64a{}
}
type fnv64a struct{}
const (
// offset64 FNVa offset basis. See https://en.wikipedia.org/wiki/FowlerNollVo_hash_function#FNV-1a_hash
offset64 = 14695981039346656037
// prime64 FNVa prime value. See https://en.wikipedia.org/wiki/FowlerNollVo_hash_function#FNV-1a_hash
prime64 = 1099511628211
)
// Sum64 gets the string and returns its uint64 hash value.
func (f fnv64a) Sum64(key string) uint64 {
var hash uint64 = offset64
for i := 0; i < len(key); i++ {
hash ^= uint64(key[i])
hash *= prime64
}
return hash
}

@ -0,0 +1,8 @@
package bigcache
// Hasher is responsible for generating unsigned, 64 bit hash of provided string. Hasher should minimize collisions
// (generating same hash for different strings) and while performance is also important fast functions are preferable (i.e.
// you can use FarmHash family).
type Hasher interface {
Sum64(string) uint64
}

@ -0,0 +1,146 @@
package bigcache
import (
"sync"
)
type iteratorError string
func (e iteratorError) Error() string {
return string(e)
}
// ErrInvalidIteratorState is reported when iterator is in invalid state
const ErrInvalidIteratorState = iteratorError("Iterator is in invalid state. Use SetNext() to move to next position")
// ErrCannotRetrieveEntry is reported when entry cannot be retrieved from underlying
const ErrCannotRetrieveEntry = iteratorError("Could not retrieve entry from cache")
var emptyEntryInfo = EntryInfo{}
// EntryInfo holds informations about entry in the cache
type EntryInfo struct {
timestamp uint64
hash uint64
key string
value []byte
err error
}
// Key returns entry's underlying key
func (e EntryInfo) Key() string {
return e.key
}
// Hash returns entry's hash value
func (e EntryInfo) Hash() uint64 {
return e.hash
}
// Timestamp returns entry's timestamp (time of insertion)
func (e EntryInfo) Timestamp() uint64 {
return e.timestamp
}
// Value returns entry's underlying value
func (e EntryInfo) Value() []byte {
return e.value
}
// EntryInfoIterator allows to iterate over entries in the cache
type EntryInfoIterator struct {
mutex sync.Mutex
cache *BigCache
currentShard int
currentIndex int
currentEntryInfo EntryInfo
elements []uint64
elementsCount int
valid bool
}
// SetNext moves to next element and returns true if it exists.
func (it *EntryInfoIterator) SetNext() bool {
it.mutex.Lock()
it.valid = false
it.currentIndex++
if it.elementsCount > it.currentIndex {
it.valid = true
empty := it.setCurrentEntry()
it.mutex.Unlock()
if empty {
return it.SetNext()
}
return true
}
for i := it.currentShard + 1; i < it.cache.config.Shards; i++ {
it.elements, it.elementsCount = it.cache.shards[i].copyHashedKeys()
// Non empty shard - stick with it
if it.elementsCount > 0 {
it.currentIndex = 0
it.currentShard = i
it.valid = true
empty := it.setCurrentEntry()
it.mutex.Unlock()
if empty {
return it.SetNext()
}
return true
}
}
it.mutex.Unlock()
return false
}
func (it *EntryInfoIterator) setCurrentEntry() bool {
var entryNotFound = false
entry, err := it.cache.shards[it.currentShard].getEntry(it.elements[it.currentIndex])
if err == ErrEntryNotFound {
it.currentEntryInfo = emptyEntryInfo
entryNotFound = true
} else if err != nil {
it.currentEntryInfo = EntryInfo{
err: err,
}
} else {
it.currentEntryInfo = EntryInfo{
timestamp: readTimestampFromEntry(entry),
hash: readHashFromEntry(entry),
key: readKeyFromEntry(entry),
value: readEntry(entry),
err: err,
}
}
return entryNotFound
}
func newIterator(cache *BigCache) *EntryInfoIterator {
elements, count := cache.shards[0].copyHashedKeys()
return &EntryInfoIterator{
cache: cache,
currentShard: 0,
currentIndex: -1,
elements: elements,
elementsCount: count,
}
}
// Value returns current value from the iterator
func (it *EntryInfoIterator) Value() (EntryInfo, error) {
if !it.valid {
return emptyEntryInfo, ErrInvalidIteratorState
}
return it.currentEntryInfo, it.currentEntryInfo.err
}

@ -0,0 +1,30 @@
package bigcache
import (
"log"
"os"
)
// Logger is invoked when `Config.Verbose=true`
type Logger interface {
Printf(format string, v ...interface{})
}
// this is a safeguard, breaking on compile time in case
// `log.Logger` does not adhere to our `Logger` interface.
// see https://golang.org/doc/faq#guarantee_satisfies_interface
var _ Logger = &log.Logger{}
// DefaultLogger returns a `Logger` implementation
// backed by stdlib's log
func DefaultLogger() *log.Logger {
return log.New(os.Stdout, "", log.LstdFlags)
}
func newLogger(custom Logger) Logger {
if custom != nil {
return custom
}
return DefaultLogger()
}

@ -0,0 +1,269 @@
package queue
import (
"encoding/binary"
"log"
"time"
)
const (
// Number of bytes to encode 0 in uvarint format
minimumHeaderSize = 17 // 1 byte blobsize + timestampSizeInBytes + hashSizeInBytes
// Bytes before left margin are not used. Zero index means element does not exist in queue, useful while reading slice from index
leftMarginIndex = 1
)
var (
errEmptyQueue = &queueError{"Empty queue"}
errInvalidIndex = &queueError{"Index must be greater than zero. Invalid index."}
errIndexOutOfBounds = &queueError{"Index out of range"}
)
// BytesQueue is a non-thread safe queue type of fifo based on bytes array.
// For every push operation index of entry is returned. It can be used to read the entry later
type BytesQueue struct {
full bool
array []byte
capacity int
maxCapacity int
head int
tail int
count int
rightMargin int
headerBuffer []byte
verbose bool
}
type queueError struct {
message string
}
// getNeededSize returns the number of bytes an entry of length need in the queue
func getNeededSize(length int) int {
var header int
switch {
case length < 127: // 1<<7-1
header = 1
case length < 16382: // 1<<14-2
header = 2
case length < 2097149: // 1<<21 -3
header = 3
case length < 268435452: // 1<<28 -4
header = 4
default:
header = 5
}
return length + header
}
// NewBytesQueue initialize new bytes queue.
// capacity is used in bytes array allocation
// When verbose flag is set then information about memory allocation are printed
func NewBytesQueue(capacity int, maxCapacity int, verbose bool) *BytesQueue {
return &BytesQueue{
array: make([]byte, capacity),
capacity: capacity,
maxCapacity: maxCapacity,
headerBuffer: make([]byte, binary.MaxVarintLen32),
tail: leftMarginIndex,
head: leftMarginIndex,
rightMargin: leftMarginIndex,
verbose: verbose,
}
}
// Reset removes all entries from queue
func (q *BytesQueue) Reset() {
// Just reset indexes
q.tail = leftMarginIndex
q.head = leftMarginIndex
q.rightMargin = leftMarginIndex
q.count = 0
q.full = false
}
// Push copies entry at the end of queue and moves tail pointer. Allocates more space if needed.
// Returns index for pushed data or error if maximum size queue limit is reached.
func (q *BytesQueue) Push(data []byte) (int, error) {
neededSize := getNeededSize(len(data))
if !q.canInsertAfterTail(neededSize) {
if q.canInsertBeforeHead(neededSize) {
q.tail = leftMarginIndex
} else if q.capacity+neededSize >= q.maxCapacity && q.maxCapacity > 0 {
return -1, &queueError{"Full queue. Maximum size limit reached."}
} else {
q.allocateAdditionalMemory(neededSize)
}
}
index := q.tail
q.push(data, neededSize)
return index, nil
}
func (q *BytesQueue) allocateAdditionalMemory(minimum int) {
start := time.Now()
if q.capacity < minimum {
q.capacity += minimum
}
q.capacity = q.capacity * 2
if q.capacity > q.maxCapacity && q.maxCapacity > 0 {
q.capacity = q.maxCapacity
}
oldArray := q.array
q.array = make([]byte, q.capacity)
if leftMarginIndex != q.rightMargin {
copy(q.array, oldArray[:q.rightMargin])
if q.tail <= q.head {
if q.tail != q.head {
// created slice is slightly larger then need but this is fine after only the needed bytes are copied
q.push(make([]byte, q.head-q.tail), q.head-q.tail)
}
q.head = leftMarginIndex
q.tail = q.rightMargin
}
}
q.full = false
if q.verbose {
log.Printf("Allocated new queue in %s; Capacity: %d \n", time.Since(start), q.capacity)
}
}
func (q *BytesQueue) push(data []byte, len int) {
headerEntrySize := binary.PutUvarint(q.headerBuffer, uint64(len))
q.copy(q.headerBuffer, headerEntrySize)
q.copy(data, len-headerEntrySize)
if q.tail > q.head {
q.rightMargin = q.tail
}
if q.tail == q.head {
q.full = true
}
q.count++
}
func (q *BytesQueue) copy(data []byte, len int) {
q.tail += copy(q.array[q.tail:], data[:len])
}
// Pop reads the oldest entry from queue and moves head pointer to the next one
func (q *BytesQueue) Pop() ([]byte, error) {
data, blockSize, err := q.peek(q.head)
if err != nil {
return nil, err
}
q.head += blockSize
q.count--
if q.head == q.rightMargin {
q.head = leftMarginIndex
if q.tail == q.rightMargin {
q.tail = leftMarginIndex
}
q.rightMargin = q.tail
}
q.full = false
return data, nil
}
// Peek reads the oldest entry from list without moving head pointer
func (q *BytesQueue) Peek() ([]byte, error) {
data, _, err := q.peek(q.head)
return data, err
}
// Get reads entry from index
func (q *BytesQueue) Get(index int) ([]byte, error) {
data, _, err := q.peek(index)
return data, err
}
// CheckGet checks if an entry can be read from index
func (q *BytesQueue) CheckGet(index int) error {
return q.peekCheckErr(index)
}
// Capacity returns number of allocated bytes for queue
func (q *BytesQueue) Capacity() int {
return q.capacity
}
// Len returns number of entries kept in queue
func (q *BytesQueue) Len() int {
return q.count
}
// Error returns error message
func (e *queueError) Error() string {
return e.message
}
// peekCheckErr is identical to peek, but does not actually return any data
func (q *BytesQueue) peekCheckErr(index int) error {
if q.count == 0 {
return errEmptyQueue
}
if index <= 0 {
return errInvalidIndex
}
if index >= len(q.array) {
return errIndexOutOfBounds
}
return nil
}
// peek returns the data from index and the number of bytes to encode the length of the data in uvarint format
func (q *BytesQueue) peek(index int) ([]byte, int, error) {
err := q.peekCheckErr(index)
if err != nil {
return nil, 0, err
}
blockSize, n := binary.Uvarint(q.array[index:])
return q.array[index+n : index+int(blockSize)], int(blockSize), nil
}
// canInsertAfterTail returns true if it's possible to insert an entry of size of need after the tail of the queue
func (q *BytesQueue) canInsertAfterTail(need int) bool {
if q.full {
return false
}
if q.tail >= q.head {
return q.capacity-q.tail >= need
}
// 1. there is exactly need bytes between head and tail, so we do not need
// to reserve extra space for a potential empty entry when realloc this queue
// 2. still have unused space between tail and head, then we must reserve
// at least headerEntrySize bytes so we can put an empty entry
return q.head-q.tail == need || q.head-q.tail >= need+minimumHeaderSize
}
// canInsertBeforeHead returns true if it's possible to insert an entry of size of need before the head of the queue
func (q *BytesQueue) canInsertBeforeHead(need int) bool {
if q.full {
return false
}
if q.tail >= q.head {
return q.head-leftMarginIndex == need || q.head-leftMarginIndex >= need+minimumHeaderSize
}
return q.head-q.tail == need || q.head-q.tail >= need+minimumHeaderSize
}

@ -0,0 +1,434 @@
package bigcache
import (
"fmt"
"sync"
"sync/atomic"
"github.com/allegro/bigcache/v3/queue"
)
type onRemoveCallback func(wrappedEntry []byte, reason RemoveReason)
// Metadata contains information of a specific entry
type Metadata struct {
RequestCount uint32
}
type cacheShard struct {
hashmap map[uint64]uint32
entries queue.BytesQueue
lock sync.RWMutex
entryBuffer []byte
onRemove onRemoveCallback
isVerbose bool
statsEnabled bool
logger Logger
clock clock
lifeWindow uint64
hashmapStats map[uint64]uint32
stats Stats
}
func (s *cacheShard) getWithInfo(key string, hashedKey uint64) (entry []byte, resp Response, err error) {
currentTime := uint64(s.clock.Epoch())
s.lock.RLock()
wrappedEntry, err := s.getWrappedEntry(hashedKey)
if err != nil {
s.lock.RUnlock()
return nil, resp, err
}
if entryKey := readKeyFromEntry(wrappedEntry); key != entryKey {
s.lock.RUnlock()
s.collision()
if s.isVerbose {
s.logger.Printf("Collision detected. Both %q and %q have the same hash %x", key, entryKey, hashedKey)
}
return nil, resp, ErrEntryNotFound
}
entry = readEntry(wrappedEntry)
oldestTimeStamp := readTimestampFromEntry(wrappedEntry)
s.lock.RUnlock()
s.hit(hashedKey)
if currentTime-oldestTimeStamp >= s.lifeWindow {
resp.EntryStatus = Expired
}
return entry, resp, nil
}
func (s *cacheShard) get(key string, hashedKey uint64) ([]byte, error) {
s.lock.RLock()
wrappedEntry, err := s.getWrappedEntry(hashedKey)
if err != nil {
s.lock.RUnlock()
return nil, err
}
if entryKey := readKeyFromEntry(wrappedEntry); key != entryKey {
s.lock.RUnlock()
s.collision()
if s.isVerbose {
s.logger.Printf("Collision detected. Both %q and %q have the same hash %x", key, entryKey, hashedKey)
}
return nil, ErrEntryNotFound
}
entry := readEntry(wrappedEntry)
s.lock.RUnlock()
s.hit(hashedKey)
return entry, nil
}
func (s *cacheShard) getWrappedEntry(hashedKey uint64) ([]byte, error) {
itemIndex := s.hashmap[hashedKey]
if itemIndex == 0 {
s.miss()
return nil, ErrEntryNotFound
}
wrappedEntry, err := s.entries.Get(int(itemIndex))
if err != nil {
s.miss()
return nil, err
}
return wrappedEntry, err
}
func (s *cacheShard) getValidWrapEntry(key string, hashedKey uint64) ([]byte, error) {
wrappedEntry, err := s.getWrappedEntry(hashedKey)
if err != nil {
return nil, err
}
if !compareKeyFromEntry(wrappedEntry, key) {
s.collision()
if s.isVerbose {
s.logger.Printf("Collision detected. Both %q and %q have the same hash %x", key, readKeyFromEntry(wrappedEntry), hashedKey)
}
return nil, ErrEntryNotFound
}
s.hitWithoutLock(hashedKey)
return wrappedEntry, nil
}
func (s *cacheShard) set(key string, hashedKey uint64, entry []byte) error {
currentTimestamp := uint64(s.clock.Epoch())
s.lock.Lock()
if previousIndex := s.hashmap[hashedKey]; previousIndex != 0 {
if previousEntry, err := s.entries.Get(int(previousIndex)); err == nil {
resetKeyFromEntry(previousEntry)
//remove hashkey
delete(s.hashmap, hashedKey)
}
}
if oldestEntry, err := s.entries.Peek(); err == nil {
s.onEvict(oldestEntry, currentTimestamp, s.removeOldestEntry)
}
w := wrapEntry(currentTimestamp, hashedKey, key, entry, &s.entryBuffer)
for {
if index, err := s.entries.Push(w); err == nil {
s.hashmap[hashedKey] = uint32(index)
s.lock.Unlock()
return nil
}
if s.removeOldestEntry(NoSpace) != nil {
s.lock.Unlock()
return fmt.Errorf("entry is bigger than max shard size")
}
}
}
func (s *cacheShard) addNewWithoutLock(key string, hashedKey uint64, entry []byte) error {
currentTimestamp := uint64(s.clock.Epoch())
if oldestEntry, err := s.entries.Peek(); err == nil {
s.onEvict(oldestEntry, currentTimestamp, s.removeOldestEntry)
}
w := wrapEntry(currentTimestamp, hashedKey, key, entry, &s.entryBuffer)
for {
if index, err := s.entries.Push(w); err == nil {
s.hashmap[hashedKey] = uint32(index)
return nil
}
if s.removeOldestEntry(NoSpace) != nil {
return fmt.Errorf("entry is bigger than max shard size")
}
}
}
func (s *cacheShard) setWrappedEntryWithoutLock(currentTimestamp uint64, w []byte, hashedKey uint64) error {
if previousIndex := s.hashmap[hashedKey]; previousIndex != 0 {
if previousEntry, err := s.entries.Get(int(previousIndex)); err == nil {
resetKeyFromEntry(previousEntry)
}
}
if oldestEntry, err := s.entries.Peek(); err == nil {
s.onEvict(oldestEntry, currentTimestamp, s.removeOldestEntry)
}
for {
if index, err := s.entries.Push(w); err == nil {
s.hashmap[hashedKey] = uint32(index)
return nil
}
if s.removeOldestEntry(NoSpace) != nil {
return fmt.Errorf("entry is bigger than max shard size")
}
}
}
func (s *cacheShard) append(key string, hashedKey uint64, entry []byte) error {
s.lock.Lock()
wrappedEntry, err := s.getValidWrapEntry(key, hashedKey)
if err == ErrEntryNotFound {
err = s.addNewWithoutLock(key, hashedKey, entry)
s.lock.Unlock()
return err
}
if err != nil {
s.lock.Unlock()
return err
}
currentTimestamp := uint64(s.clock.Epoch())
w := appendToWrappedEntry(currentTimestamp, wrappedEntry, entry, &s.entryBuffer)
err = s.setWrappedEntryWithoutLock(currentTimestamp, w, hashedKey)
s.lock.Unlock()
return err
}
func (s *cacheShard) del(hashedKey uint64) error {
// Optimistic pre-check using only readlock
s.lock.RLock()
{
itemIndex := s.hashmap[hashedKey]
if itemIndex == 0 {
s.lock.RUnlock()
s.delmiss()
return ErrEntryNotFound
}
if err := s.entries.CheckGet(int(itemIndex)); err != nil {
s.lock.RUnlock()
s.delmiss()
return err
}
}
s.lock.RUnlock()
s.lock.Lock()
{
// After obtaining the writelock, we need to read the same again,
// since the data delivered earlier may be stale now
itemIndex := s.hashmap[hashedKey]
if itemIndex == 0 {
s.lock.Unlock()
s.delmiss()
return ErrEntryNotFound
}
wrappedEntry, err := s.entries.Get(int(itemIndex))
if err != nil {
s.lock.Unlock()
s.delmiss()
return err
}
delete(s.hashmap, hashedKey)
s.onRemove(wrappedEntry, Deleted)
if s.statsEnabled {
delete(s.hashmapStats, hashedKey)
}
resetKeyFromEntry(wrappedEntry)
}
s.lock.Unlock()
s.delhit()
return nil
}
func (s *cacheShard) onEvict(oldestEntry []byte, currentTimestamp uint64, evict func(reason RemoveReason) error) bool {
oldestTimestamp := readTimestampFromEntry(oldestEntry)
if currentTimestamp-oldestTimestamp > s.lifeWindow {
evict(Expired)
return true
}
return false
}
func (s *cacheShard) cleanUp(currentTimestamp uint64) {
s.lock.Lock()
for {
if oldestEntry, err := s.entries.Peek(); err != nil {
break
} else if evicted := s.onEvict(oldestEntry, currentTimestamp, s.removeOldestEntry); !evicted {
break
}
}
s.lock.Unlock()
}
func (s *cacheShard) getEntry(hashedKey uint64) ([]byte, error) {
s.lock.RLock()
entry, err := s.getWrappedEntry(hashedKey)
// copy entry
newEntry := make([]byte, len(entry))
copy(newEntry, entry)
s.lock.RUnlock()
return newEntry, err
}
func (s *cacheShard) copyHashedKeys() (keys []uint64, next int) {
s.lock.RLock()
keys = make([]uint64, len(s.hashmap))
for key := range s.hashmap {
keys[next] = key
next++
}
s.lock.RUnlock()
return keys, next
}
func (s *cacheShard) removeOldestEntry(reason RemoveReason) error {
oldest, err := s.entries.Pop()
if err == nil {
hash := readHashFromEntry(oldest)
if hash == 0 {
// entry has been explicitly deleted with resetKeyFromEntry, ignore
return nil
}
delete(s.hashmap, hash)
s.onRemove(oldest, reason)
if s.statsEnabled {
delete(s.hashmapStats, hash)
}
return nil
}
return err
}
func (s *cacheShard) reset(config Config) {
s.lock.Lock()
s.hashmap = make(map[uint64]uint32, config.initialShardSize())
s.entryBuffer = make([]byte, config.MaxEntrySize+headersSizeInBytes)
s.entries.Reset()
s.lock.Unlock()
}
func (s *cacheShard) len() int {
s.lock.RLock()
res := len(s.hashmap)
s.lock.RUnlock()
return res
}
func (s *cacheShard) capacity() int {
s.lock.RLock()
res := s.entries.Capacity()
s.lock.RUnlock()
return res
}
func (s *cacheShard) getStats() Stats {
var stats = Stats{
Hits: atomic.LoadInt64(&s.stats.Hits),
Misses: atomic.LoadInt64(&s.stats.Misses),
DelHits: atomic.LoadInt64(&s.stats.DelHits),
DelMisses: atomic.LoadInt64(&s.stats.DelMisses),
Collisions: atomic.LoadInt64(&s.stats.Collisions),
}
return stats
}
func (s *cacheShard) getKeyMetadataWithLock(key uint64) Metadata {
s.lock.RLock()
c := s.hashmapStats[key]
s.lock.RUnlock()
return Metadata{
RequestCount: c,
}
}
func (s *cacheShard) getKeyMetadata(key uint64) Metadata {
return Metadata{
RequestCount: s.hashmapStats[key],
}
}
func (s *cacheShard) hit(key uint64) {
atomic.AddInt64(&s.stats.Hits, 1)
if s.statsEnabled {
s.lock.Lock()
s.hashmapStats[key]++
s.lock.Unlock()
}
}
func (s *cacheShard) hitWithoutLock(key uint64) {
atomic.AddInt64(&s.stats.Hits, 1)
if s.statsEnabled {
s.hashmapStats[key]++
}
}
func (s *cacheShard) miss() {
atomic.AddInt64(&s.stats.Misses, 1)
}
func (s *cacheShard) delhit() {
atomic.AddInt64(&s.stats.DelHits, 1)
}
func (s *cacheShard) delmiss() {
atomic.AddInt64(&s.stats.DelMisses, 1)
}
func (s *cacheShard) collision() {
atomic.AddInt64(&s.stats.Collisions, 1)
}
func initNewShard(config Config, callback onRemoveCallback, clock clock) *cacheShard {
bytesQueueInitialCapacity := config.initialShardSize() * config.MaxEntrySize
maximumShardSizeInBytes := config.maximumShardSizeInBytes()
if maximumShardSizeInBytes > 0 && bytesQueueInitialCapacity > maximumShardSizeInBytes {
bytesQueueInitialCapacity = maximumShardSizeInBytes
}
return &cacheShard{
hashmap: make(map[uint64]uint32, config.initialShardSize()),
hashmapStats: make(map[uint64]uint32, config.initialShardSize()),
entries: *queue.NewBytesQueue(bytesQueueInitialCapacity, maximumShardSizeInBytes, config.Verbose),
entryBuffer: make([]byte, config.MaxEntrySize+headersSizeInBytes),
onRemove: callback,
isVerbose: config.Verbose,
logger: newLogger(config.Logger),
clock: clock,
lifeWindow: uint64(config.LifeWindow.Seconds()),
statsEnabled: config.StatsEnabled,
}
}

@ -0,0 +1,15 @@
package bigcache
// Stats stores cache statistics
type Stats struct {
// Hits is a number of successfully found keys
Hits int64 `json:"hits"`
// Misses is a number of not found keys
Misses int64 `json:"misses"`
// DelHits is a number of successfully deleted keys
DelHits int64 `json:"delete_hits"`
// DelMisses is a number of not deleted keys
DelMisses int64 `json:"delete_misses"`
// Collisions is a number of happened key-collisions
Collisions int64 `json:"collisions"`
}

@ -0,0 +1,23 @@
package bigcache
func max(a, b int) int {
if a > b {
return a
}
return b
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func convertMBToBytes(value int) int {
return value * 1024 * 1024
}
func isPowerOfTwo(number int) bool {
return (number != 0) && (number&(number-1)) == 0
}

@ -0,0 +1,25 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
/.tags

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Bastien Gysler
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,82 @@
# goxml2json [![CircleCI](https://circleci.com/gh/basgys/goxml2json.svg?style=svg)](https://circleci.com/gh/basgys/goxml2json)
Go package that converts XML to JSON
### Install
go get -u github.com/basgys/goxml2json
### Importing
import github.com/basgys/goxml2json
### Usage
**Code example**
```go
package main
import (
"fmt"
"strings"
xj "github.com/basgys/goxml2json"
)
func main() {
// xml is an io.Reader
xml := strings.NewReader(`<?xml version="1.0" encoding="UTF-8"?><hello>world</hello>`)
json, err := xj.Convert(xml)
if err != nil {
panic("That's embarrassing...")
}
fmt.Println(json.String())
// {"hello": "world"}
}
```
**Input**
```xml
<?xml version="1.0" encoding="UTF-8"?>
<osm version="0.6" generator="CGImap 0.0.2">
<bounds minlat="54.0889580" minlon="12.2487570" maxlat="54.0913900" maxlon="12.2524800"/>
<foo>bar</foo>
</osm>
```
**Output**
```json
{
"osm": {
"-version": "0.6",
"-generator": "CGImap 0.0.2",
"bounds": {
"-minlat": "54.0889580",
"-minlon": "12.2487570",
"-maxlat": "54.0913900",
"-maxlon": "12.2524800"
},
"foo": "bar"
}
}
```
### Contributing
Feel free to contribute to this project if you want to fix/extend/improve it.
### Contributors
- [DirectX](https://github.com/directx)
- [samuelhug](https://github.com/samuelhug)
### TODO
* Extract data types in JSON (numbers, boolean, ...)
* Categorise errors
* Option to prettify the JSON output
* Benchmark

@ -0,0 +1,25 @@
package xml2json
import (
"bytes"
"io"
)
// Convert converts the given XML document to JSON
func Convert(r io.Reader) (*bytes.Buffer, error) {
// Decode XML document
root := &Node{}
err := NewDecoder(r).Decode(root)
if err != nil {
return nil, err
}
// Then encode it in JSON
buf := new(bytes.Buffer)
err = NewEncoder(buf).Encode(root)
if err != nil {
return nil, err
}
return buf, nil
}

@ -0,0 +1,140 @@
package xml2json
import (
"encoding/xml"
"io"
"unicode"
"golang.org/x/net/html/charset"
)
const (
attrPrefix = "-"
contentPrefix = "#"
)
// A Decoder reads and decodes XML objects from an input stream.
type Decoder struct {
r io.Reader
err error
attributePrefix string
contentPrefix string
}
type element struct {
parent *element
n *Node
label string
}
func (dec *Decoder) SetAttributePrefix(prefix string) {
dec.attributePrefix = prefix
}
func (dec *Decoder) SetContentPrefix(prefix string) {
dec.contentPrefix = prefix
}
func (dec *Decoder) DecodeWithCustomPrefixes(root *Node, contentPrefix string, attributePrefix string) error {
dec.contentPrefix = contentPrefix
dec.attributePrefix = attributePrefix
return dec.Decode(root)
}
// NewDecoder returns a new decoder that reads from r.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// Decode reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v.
func (dec *Decoder) Decode(root *Node) error {
if dec.contentPrefix == "" {
dec.contentPrefix = contentPrefix
}
if dec.attributePrefix == "" {
dec.attributePrefix = attrPrefix
}
xmlDec := xml.NewDecoder(dec.r)
// That will convert the charset if the provided XML is non-UTF-8
xmlDec.CharsetReader = charset.NewReaderLabel
// Create first element from the root node
elem := &element{
parent: nil,
n: root,
}
for {
t, _ := xmlDec.Token()
if t == nil {
break
}
switch se := t.(type) {
case xml.StartElement:
// Build new a new current element and link it to its parent
elem = &element{
parent: elem,
n: &Node{},
label: se.Name.Local,
}
// Extract attributes as children
for _, a := range se.Attr {
elem.n.AddChild(dec.attributePrefix+a.Name.Local, &Node{Data: a.Value})
}
case xml.CharData:
// Extract XML data (if any)
elem.n.Data = trimNonGraphic(string(xml.CharData(se)))
case xml.EndElement:
// And add it to its parent list
if elem.parent != nil {
elem.parent.n.AddChild(elem.label, elem.n)
}
// Then change the current element to its parent
elem = elem.parent
}
}
return nil
}
// trimNonGraphic returns a slice of the string s, with all leading and trailing
// non graphic characters and spaces removed.
//
// Graphic characters include letters, marks, numbers, punctuation, symbols,
// and spaces, from categories L, M, N, P, S, Zs.
// Spacing characters are set by category Z and property Pattern_White_Space.
func trimNonGraphic(s string) string {
if s == "" {
return s
}
var first *int
var last int
for i, r := range []rune(s) {
if !unicode.IsGraphic(r) || unicode.IsSpace(r) {
continue
}
if first == nil {
f := i // copy i
first = &f
last = i
} else {
last = i
}
}
// If first is nil, it means there are no graphic characters
if first == nil {
return ""
}
return string([]rune(s)[*first : last+1])
}

@ -0,0 +1,2 @@
// Package xml2json is an XML to JSON converter
package xml2json

@ -0,0 +1,197 @@
package xml2json
import (
"bytes"
"io"
"unicode/utf8"
)
// An Encoder writes JSON objects to an output stream.
type Encoder struct {
w io.Writer
err error
contentPrefix string
attributePrefix string
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w}
}
func (enc *Encoder) SetAttributePrefix(prefix string) {
enc.attributePrefix = prefix
}
func (enc *Encoder) SetContentPrefix(prefix string) {
enc.contentPrefix = prefix
}
func (enc *Encoder) EncodeWithCustomPrefixes(root *Node, contentPrefix string, attributePrefix string) error {
enc.contentPrefix = contentPrefix
enc.attributePrefix = attributePrefix
return enc.Encode(root)
}
// Encode writes the JSON encoding of v to the stream
func (enc *Encoder) Encode(root *Node) error {
if enc.err != nil {
return enc.err
}
if root == nil {
return nil
}
if enc.contentPrefix == "" {
enc.contentPrefix = contentPrefix
}
if enc.attributePrefix == "" {
enc.attributePrefix = attrPrefix
}
enc.err = enc.format(root, 0)
// Terminate each value with a newline.
// This makes the output look a little nicer
// when debugging, and some kind of space
// is required if the encoded value was a number,
// so that the reader knows there aren't more
// digits coming.
enc.write("\n")
return enc.err
}
func (enc *Encoder) format(n *Node, lvl int) error {
if n.IsComplex() {
enc.write("{")
// Add data as an additional attibute (if any)
if len(n.Data) > 0 {
enc.write("\"")
enc.write(enc.contentPrefix)
enc.write("content")
enc.write("\": ")
enc.write(sanitiseString(n.Data))
enc.write(", ")
}
i := 0
tot := len(n.Children)
for label, children := range n.Children {
enc.write("\"")
enc.write(label)
enc.write("\": ")
if len(children) > 1 {
// Array
enc.write("[")
for j, c := range children {
enc.format(c, lvl+1)
if j < len(children)-1 {
enc.write(", ")
}
}
enc.write("]")
} else {
// Map
enc.format(children[0], lvl+1)
}
if i < tot-1 {
enc.write(", ")
}
i++
}
enc.write("}")
} else {
// TODO : Extract data type
enc.write(sanitiseString(n.Data))
}
return nil
}
func (enc *Encoder) write(s string) {
enc.w.Write([]byte(s))
}
// https://golang.org/src/encoding/json/encode.go?s=5584:5627#L788
var hex = "0123456789abcdef"
func sanitiseString(s string) string {
var buf bytes.Buffer
buf.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
i++
continue
}
if start < i {
buf.WriteString(s[start:i])
}
switch b {
case '\\', '"':
buf.WriteByte('\\')
buf.WriteByte(b)
case '\n':
buf.WriteByte('\\')
buf.WriteByte('n')
case '\r':
buf.WriteByte('\\')
buf.WriteByte('r')
case '\t':
buf.WriteByte('\\')
buf.WriteByte('t')
default:
// This encodes bytes < 0x20 except for \n and \r,
// as well as <, > and &. The latter are escaped because they
// can lead to security holes when user-controlled strings
// are rendered into JSON and served to some browsers.
buf.WriteString(`\u00`)
buf.WriteByte(hex[b>>4])
buf.WriteByte(hex[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRuneInString(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\u202`)
buf.WriteByte(hex[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
buf.WriteString(s[start:])
}
buf.WriteByte('"')
return buf.String()
}

@ -0,0 +1,25 @@
package xml2json
// Node is a data element on a tree
type Node struct {
Children map[string]Nodes
Data string
}
// Nodes is a list of nodes
type Nodes []*Node
// AddChild appends a node to the list of children
func (n *Node) AddChild(s string, c *Node) {
// Lazy lazy
if n.Children == nil {
n.Children = map[string]Nodes{}
}
n.Children[s] = append(n.Children[s], c)
}
// IsComplex returns whether it is a complex type (has children)
func (n *Node) IsComplex() bool {
return len(n.Children) > 0
}

@ -0,0 +1,22 @@
Copyright (c) 2016 Caleb Spare
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,69 @@
# xxhash
[![Go Reference](https://pkg.go.dev/badge/github.com/cespare/xxhash/v2.svg)](https://pkg.go.dev/github.com/cespare/xxhash/v2)
[![Test](https://github.com/cespare/xxhash/actions/workflows/test.yml/badge.svg)](https://github.com/cespare/xxhash/actions/workflows/test.yml)
xxhash is a Go implementation of the 64-bit
[xxHash](http://cyan4973.github.io/xxHash/) algorithm, XXH64. This is a
high-quality hashing algorithm that is much faster than anything in the Go
standard library.
This package provides a straightforward API:
```
func Sum64(b []byte) uint64
func Sum64String(s string) uint64
type Digest struct{ ... }
func New() *Digest
```
The `Digest` type implements hash.Hash64. Its key methods are:
```
func (*Digest) Write([]byte) (int, error)
func (*Digest) WriteString(string) (int, error)
func (*Digest) Sum64() uint64
```
This implementation provides a fast pure-Go implementation and an even faster
assembly implementation for amd64.
## Compatibility
This package is in a module and the latest code is in version 2 of the module.
You need a version of Go with at least "minimal module compatibility" to use
github.com/cespare/xxhash/v2:
* 1.9.7+ for Go 1.9
* 1.10.3+ for Go 1.10
* Go 1.11 or later
I recommend using the latest release of Go.
## Benchmarks
Here are some quick benchmarks comparing the pure-Go and assembly
implementations of Sum64.
| input size | purego | asm |
| --- | --- | --- |
| 5 B | 979.66 MB/s | 1291.17 MB/s |
| 100 B | 7475.26 MB/s | 7973.40 MB/s |
| 4 KB | 17573.46 MB/s | 17602.65 MB/s |
| 10 MB | 17131.46 MB/s | 17142.16 MB/s |
These numbers were generated on Ubuntu 18.04 with an Intel i7-8700K CPU using
the following commands under Go 1.11.2:
```
$ go test -tags purego -benchtime 10s -bench '/xxhash,direct,bytes'
$ go test -benchtime 10s -bench '/xxhash,direct,bytes'
```
## Projects using this package
- [InfluxDB](https://github.com/influxdata/influxdb)
- [Prometheus](https://github.com/prometheus/prometheus)
- [VictoriaMetrics](https://github.com/VictoriaMetrics/VictoriaMetrics)
- [FreeCache](https://github.com/coocood/freecache)
- [FastCache](https://github.com/VictoriaMetrics/fastcache)

@ -0,0 +1,235 @@
// Package xxhash implements the 64-bit variant of xxHash (XXH64) as described
// at http://cyan4973.github.io/xxHash/.
package xxhash
import (
"encoding/binary"
"errors"
"math/bits"
)
const (
prime1 uint64 = 11400714785074694791
prime2 uint64 = 14029467366897019727
prime3 uint64 = 1609587929392839161
prime4 uint64 = 9650029242287828579
prime5 uint64 = 2870177450012600261
)
// NOTE(caleb): I'm using both consts and vars of the primes. Using consts where
// possible in the Go code is worth a small (but measurable) performance boost
// by avoiding some MOVQs. Vars are needed for the asm and also are useful for
// convenience in the Go code in a few places where we need to intentionally
// avoid constant arithmetic (e.g., v1 := prime1 + prime2 fails because the
// result overflows a uint64).
var (
prime1v = prime1
prime2v = prime2
prime3v = prime3
prime4v = prime4
prime5v = prime5
)
// Digest implements hash.Hash64.
type Digest struct {
v1 uint64
v2 uint64
v3 uint64
v4 uint64
total uint64
mem [32]byte
n int // how much of mem is used
}
// New creates a new Digest that computes the 64-bit xxHash algorithm.
func New() *Digest {
var d Digest
d.Reset()
return &d
}
// Reset clears the Digest's state so that it can be reused.
func (d *Digest) Reset() {
d.v1 = prime1v + prime2
d.v2 = prime2
d.v3 = 0
d.v4 = -prime1v
d.total = 0
d.n = 0
}
// Size always returns 8 bytes.
func (d *Digest) Size() int { return 8 }
// BlockSize always returns 32 bytes.
func (d *Digest) BlockSize() int { return 32 }
// Write adds more data to d. It always returns len(b), nil.
func (d *Digest) Write(b []byte) (n int, err error) {
n = len(b)
d.total += uint64(n)
if d.n+n < 32 {
// This new data doesn't even fill the current block.
copy(d.mem[d.n:], b)
d.n += n
return
}
if d.n > 0 {
// Finish off the partial block.
copy(d.mem[d.n:], b)
d.v1 = round(d.v1, u64(d.mem[0:8]))
d.v2 = round(d.v2, u64(d.mem[8:16]))
d.v3 = round(d.v3, u64(d.mem[16:24]))
d.v4 = round(d.v4, u64(d.mem[24:32]))
b = b[32-d.n:]
d.n = 0
}
if len(b) >= 32 {
// One or more full blocks left.
nw := writeBlocks(d, b)
b = b[nw:]
}
// Store any remaining partial block.
copy(d.mem[:], b)
d.n = len(b)
return
}
// Sum appends the current hash to b and returns the resulting slice.
func (d *Digest) Sum(b []byte) []byte {
s := d.Sum64()
return append(
b,
byte(s>>56),
byte(s>>48),
byte(s>>40),
byte(s>>32),
byte(s>>24),
byte(s>>16),
byte(s>>8),
byte(s),
)
}
// Sum64 returns the current hash.
func (d *Digest) Sum64() uint64 {
var h uint64
if d.total >= 32 {
v1, v2, v3, v4 := d.v1, d.v2, d.v3, d.v4
h = rol1(v1) + rol7(v2) + rol12(v3) + rol18(v4)
h = mergeRound(h, v1)
h = mergeRound(h, v2)
h = mergeRound(h, v3)
h = mergeRound(h, v4)
} else {
h = d.v3 + prime5
}
h += d.total
i, end := 0, d.n
for ; i+8 <= end; i += 8 {
k1 := round(0, u64(d.mem[i:i+8]))
h ^= k1
h = rol27(h)*prime1 + prime4
}
if i+4 <= end {
h ^= uint64(u32(d.mem[i:i+4])) * prime1
h = rol23(h)*prime2 + prime3
i += 4
}
for i < end {
h ^= uint64(d.mem[i]) * prime5
h = rol11(h) * prime1
i++
}
h ^= h >> 33
h *= prime2
h ^= h >> 29
h *= prime3
h ^= h >> 32
return h
}
const (
magic = "xxh\x06"
marshaledSize = len(magic) + 8*5 + 32
)
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (d *Digest) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaledSize)
b = append(b, magic...)
b = appendUint64(b, d.v1)
b = appendUint64(b, d.v2)
b = appendUint64(b, d.v3)
b = appendUint64(b, d.v4)
b = appendUint64(b, d.total)
b = append(b, d.mem[:d.n]...)
b = b[:len(b)+len(d.mem)-d.n]
return b, nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
func (d *Digest) UnmarshalBinary(b []byte) error {
if len(b) < len(magic) || string(b[:len(magic)]) != magic {
return errors.New("xxhash: invalid hash state identifier")
}
if len(b) != marshaledSize {
return errors.New("xxhash: invalid hash state size")
}
b = b[len(magic):]
b, d.v1 = consumeUint64(b)
b, d.v2 = consumeUint64(b)
b, d.v3 = consumeUint64(b)
b, d.v4 = consumeUint64(b)
b, d.total = consumeUint64(b)
copy(d.mem[:], b)
d.n = int(d.total % uint64(len(d.mem)))
return nil
}
func appendUint64(b []byte, x uint64) []byte {
var a [8]byte
binary.LittleEndian.PutUint64(a[:], x)
return append(b, a[:]...)
}
func consumeUint64(b []byte) ([]byte, uint64) {
x := u64(b)
return b[8:], x
}
func u64(b []byte) uint64 { return binary.LittleEndian.Uint64(b) }
func u32(b []byte) uint32 { return binary.LittleEndian.Uint32(b) }
func round(acc, input uint64) uint64 {
acc += input * prime2
acc = rol31(acc)
acc *= prime1
return acc
}
func mergeRound(acc, val uint64) uint64 {
val = round(0, val)
acc ^= val
acc = acc*prime1 + prime4
return acc
}
func rol1(x uint64) uint64 { return bits.RotateLeft64(x, 1) }
func rol7(x uint64) uint64 { return bits.RotateLeft64(x, 7) }
func rol11(x uint64) uint64 { return bits.RotateLeft64(x, 11) }
func rol12(x uint64) uint64 { return bits.RotateLeft64(x, 12) }
func rol18(x uint64) uint64 { return bits.RotateLeft64(x, 18) }
func rol23(x uint64) uint64 { return bits.RotateLeft64(x, 23) }
func rol27(x uint64) uint64 { return bits.RotateLeft64(x, 27) }
func rol31(x uint64) uint64 { return bits.RotateLeft64(x, 31) }

@ -0,0 +1,13 @@
// +build !appengine
// +build gc
// +build !purego
package xxhash
// Sum64 computes the 64-bit xxHash digest of b.
//
//go:noescape
func Sum64(b []byte) uint64
//go:noescape
func writeBlocks(d *Digest, b []byte) int

@ -0,0 +1,215 @@
// +build !appengine
// +build gc
// +build !purego
#include "textflag.h"
// Register allocation:
// AX h
// SI pointer to advance through b
// DX n
// BX loop end
// R8 v1, k1
// R9 v2
// R10 v3
// R11 v4
// R12 tmp
// R13 prime1v
// R14 prime2v
// DI prime4v
// round reads from and advances the buffer pointer in SI.
// It assumes that R13 has prime1v and R14 has prime2v.
#define round(r) \
MOVQ (SI), R12 \
ADDQ $8, SI \
IMULQ R14, R12 \
ADDQ R12, r \
ROLQ $31, r \
IMULQ R13, r
// mergeRound applies a merge round on the two registers acc and val.
// It assumes that R13 has prime1v, R14 has prime2v, and DI has prime4v.
#define mergeRound(acc, val) \
IMULQ R14, val \
ROLQ $31, val \
IMULQ R13, val \
XORQ val, acc \
IMULQ R13, acc \
ADDQ DI, acc
// func Sum64(b []byte) uint64
TEXT ·Sum64(SB), NOSPLIT, $0-32
// Load fixed primes.
MOVQ ·prime1v(SB), R13
MOVQ ·prime2v(SB), R14
MOVQ ·prime4v(SB), DI
// Load slice.
MOVQ b_base+0(FP), SI
MOVQ b_len+8(FP), DX
LEAQ (SI)(DX*1), BX
// The first loop limit will be len(b)-32.
SUBQ $32, BX
// Check whether we have at least one block.
CMPQ DX, $32
JLT noBlocks
// Set up initial state (v1, v2, v3, v4).
MOVQ R13, R8
ADDQ R14, R8
MOVQ R14, R9
XORQ R10, R10
XORQ R11, R11
SUBQ R13, R11
// Loop until SI > BX.
blockLoop:
round(R8)
round(R9)
round(R10)
round(R11)
CMPQ SI, BX
JLE blockLoop
MOVQ R8, AX
ROLQ $1, AX
MOVQ R9, R12
ROLQ $7, R12
ADDQ R12, AX
MOVQ R10, R12
ROLQ $12, R12
ADDQ R12, AX
MOVQ R11, R12
ROLQ $18, R12
ADDQ R12, AX
mergeRound(AX, R8)
mergeRound(AX, R9)
mergeRound(AX, R10)
mergeRound(AX, R11)
JMP afterBlocks
noBlocks:
MOVQ ·prime5v(SB), AX
afterBlocks:
ADDQ DX, AX
// Right now BX has len(b)-32, and we want to loop until SI > len(b)-8.
ADDQ $24, BX
CMPQ SI, BX
JG fourByte
wordLoop:
// Calculate k1.
MOVQ (SI), R8
ADDQ $8, SI
IMULQ R14, R8
ROLQ $31, R8
IMULQ R13, R8
XORQ R8, AX
ROLQ $27, AX
IMULQ R13, AX
ADDQ DI, AX
CMPQ SI, BX
JLE wordLoop
fourByte:
ADDQ $4, BX
CMPQ SI, BX
JG singles
MOVL (SI), R8
ADDQ $4, SI
IMULQ R13, R8
XORQ R8, AX
ROLQ $23, AX
IMULQ R14, AX
ADDQ ·prime3v(SB), AX
singles:
ADDQ $4, BX
CMPQ SI, BX
JGE finalize
singlesLoop:
MOVBQZX (SI), R12
ADDQ $1, SI
IMULQ ·prime5v(SB), R12
XORQ R12, AX
ROLQ $11, AX
IMULQ R13, AX
CMPQ SI, BX
JL singlesLoop
finalize:
MOVQ AX, R12
SHRQ $33, R12
XORQ R12, AX
IMULQ R14, AX
MOVQ AX, R12
SHRQ $29, R12
XORQ R12, AX
IMULQ ·prime3v(SB), AX
MOVQ AX, R12
SHRQ $32, R12
XORQ R12, AX
MOVQ AX, ret+24(FP)
RET
// writeBlocks uses the same registers as above except that it uses AX to store
// the d pointer.
// func writeBlocks(d *Digest, b []byte) int
TEXT ·writeBlocks(SB), NOSPLIT, $0-40
// Load fixed primes needed for round.
MOVQ ·prime1v(SB), R13
MOVQ ·prime2v(SB), R14
// Load slice.
MOVQ b_base+8(FP), SI
MOVQ b_len+16(FP), DX
LEAQ (SI)(DX*1), BX
SUBQ $32, BX
// Load vN from d.
MOVQ d+0(FP), AX
MOVQ 0(AX), R8 // v1
MOVQ 8(AX), R9 // v2
MOVQ 16(AX), R10 // v3
MOVQ 24(AX), R11 // v4
// We don't need to check the loop condition here; this function is
// always called with at least one block of data to process.
blockLoop:
round(R8)
round(R9)
round(R10)
round(R11)
CMPQ SI, BX
JLE blockLoop
// Copy vN back to d.
MOVQ R8, 0(AX)
MOVQ R9, 8(AX)
MOVQ R10, 16(AX)
MOVQ R11, 24(AX)
// The number of bytes written is SI minus the old base pointer.
SUBQ b_base+8(FP), SI
MOVQ SI, ret+32(FP)
RET

@ -0,0 +1,76 @@
// +build !amd64 appengine !gc purego
package xxhash
// Sum64 computes the 64-bit xxHash digest of b.
func Sum64(b []byte) uint64 {
// A simpler version would be
// d := New()
// d.Write(b)
// return d.Sum64()
// but this is faster, particularly for small inputs.
n := len(b)
var h uint64
if n >= 32 {
v1 := prime1v + prime2
v2 := prime2
v3 := uint64(0)
v4 := -prime1v
for len(b) >= 32 {
v1 = round(v1, u64(b[0:8:len(b)]))
v2 = round(v2, u64(b[8:16:len(b)]))
v3 = round(v3, u64(b[16:24:len(b)]))
v4 = round(v4, u64(b[24:32:len(b)]))
b = b[32:len(b):len(b)]
}
h = rol1(v1) + rol7(v2) + rol12(v3) + rol18(v4)
h = mergeRound(h, v1)
h = mergeRound(h, v2)
h = mergeRound(h, v3)
h = mergeRound(h, v4)
} else {
h = prime5
}
h += uint64(n)
i, end := 0, len(b)
for ; i+8 <= end; i += 8 {
k1 := round(0, u64(b[i:i+8:len(b)]))
h ^= k1
h = rol27(h)*prime1 + prime4
}
if i+4 <= end {
h ^= uint64(u32(b[i:i+4:len(b)])) * prime1
h = rol23(h)*prime2 + prime3
i += 4
}
for ; i < end; i++ {
h ^= uint64(b[i]) * prime5
h = rol11(h) * prime1
}
h ^= h >> 33
h *= prime2
h ^= h >> 29
h *= prime3
h ^= h >> 32
return h
}
func writeBlocks(d *Digest, b []byte) int {
v1, v2, v3, v4 := d.v1, d.v2, d.v3, d.v4
n := len(b)
for len(b) >= 32 {
v1 = round(v1, u64(b[0:8:len(b)]))
v2 = round(v2, u64(b[8:16:len(b)]))
v3 = round(v3, u64(b[16:24:len(b)]))
v4 = round(v4, u64(b[24:32:len(b)]))
b = b[32:len(b):len(b)]
}
d.v1, d.v2, d.v3, d.v4 = v1, v2, v3, v4
return n - len(b)
}

@ -0,0 +1,15 @@
// +build appengine
// This file contains the safe implementations of otherwise unsafe-using code.
package xxhash
// Sum64String computes the 64-bit xxHash digest of s.
func Sum64String(s string) uint64 {
return Sum64([]byte(s))
}
// WriteString adds more data to d. It always returns len(s), nil.
func (d *Digest) WriteString(s string) (n int, err error) {
return d.Write([]byte(s))
}

@ -0,0 +1,57 @@
// +build !appengine
// This file encapsulates usage of unsafe.
// xxhash_safe.go contains the safe implementations.
package xxhash
import (
"unsafe"
)
// In the future it's possible that compiler optimizations will make these
// XxxString functions unnecessary by realizing that calls such as
// Sum64([]byte(s)) don't need to copy s. See https://golang.org/issue/2205.
// If that happens, even if we keep these functions they can be replaced with
// the trivial safe code.
// NOTE: The usual way of doing an unsafe string-to-[]byte conversion is:
//
// var b []byte
// bh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
// bh.Data = (*reflect.StringHeader)(unsafe.Pointer(&s)).Data
// bh.Len = len(s)
// bh.Cap = len(s)
//
// Unfortunately, as of Go 1.15.3 the inliner's cost model assigns a high enough
// weight to this sequence of expressions that any function that uses it will
// not be inlined. Instead, the functions below use a different unsafe
// conversion designed to minimize the inliner weight and allow both to be
// inlined. There is also a test (TestInlining) which verifies that these are
// inlined.
//
// See https://github.com/golang/go/issues/42739 for discussion.
// Sum64String computes the 64-bit xxHash digest of s.
// It may be faster than Sum64([]byte(s)) by avoiding a copy.
func Sum64String(s string) uint64 {
b := *(*[]byte)(unsafe.Pointer(&sliceHeader{s, len(s)}))
return Sum64(b)
}
// WriteString adds more data to d. It always returns len(s), nil.
// It may be faster than Write([]byte(s)) by avoiding a copy.
func (d *Digest) WriteString(s string) (n int, err error) {
d.Write(*(*[]byte)(unsafe.Pointer(&sliceHeader{s, len(s)})))
// d.Write always returns len(s), nil.
// Ignoring the return output and returning these fixed values buys a
// savings of 6 in the inliner's cost model.
return len(s), nil
}
// sliceHeader is similar to reflect.SliceHeader, but it assumes that the layout
// of the first two words is the same as the layout of a string.
type sliceHeader struct {
s string
cap int
}

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2017-2020 Damian Gryski <damian@gryski.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

@ -0,0 +1,79 @@
package rendezvous
type Rendezvous struct {
nodes map[string]int
nstr []string
nhash []uint64
hash Hasher
}
type Hasher func(s string) uint64
func New(nodes []string, hash Hasher) *Rendezvous {
r := &Rendezvous{
nodes: make(map[string]int, len(nodes)),
nstr: make([]string, len(nodes)),
nhash: make([]uint64, len(nodes)),
hash: hash,
}
for i, n := range nodes {
r.nodes[n] = i
r.nstr[i] = n
r.nhash[i] = hash(n)
}
return r
}
func (r *Rendezvous) Lookup(k string) string {
// short-circuit if we're empty
if len(r.nodes) == 0 {
return ""
}
khash := r.hash(k)
var midx int
var mhash = xorshiftMult64(khash ^ r.nhash[0])
for i, nhash := range r.nhash[1:] {
if h := xorshiftMult64(khash ^ nhash); h > mhash {
midx = i + 1
mhash = h
}
}
return r.nstr[midx]
}
func (r *Rendezvous) Add(node string) {
r.nodes[node] = len(r.nstr)
r.nstr = append(r.nstr, node)
r.nhash = append(r.nhash, r.hash(node))
}
func (r *Rendezvous) Remove(node string) {
// find index of node to remove
nidx := r.nodes[node]
// remove from the slices
l := len(r.nstr)
r.nstr[nidx] = r.nstr[l]
r.nstr = r.nstr[:l]
r.nhash[nidx] = r.nhash[l]
r.nhash = r.nhash[:l]
// update the map
delete(r.nodes, node)
moved := r.nstr[nidx]
r.nodes[moved] = nidx
}
func xorshiftMult64(x uint64) uint64 {
x ^= x >> 12 // a
x ^= x << 25 // b
x ^= x >> 27 // c
return x * 2685821657736338717
}

@ -0,0 +1,3 @@
*.rdb
testdata/*/
.idea/

@ -0,0 +1,4 @@
run:
concurrency: 8
deadline: 5m
tests: false

@ -0,0 +1,4 @@
semi: false
singleQuote: true
proseWrap: always
printWidth: 100

@ -0,0 +1,31 @@
# [9.0.0-beta.1](https://github.com/go-redis/redis/compare/v8.11.5...v9.0.0-beta.1) (2022-06-04)
### Bug Fixes
* **#1943:** xInfoConsumer.Idle should be time.Duration instead of int64 ([#2052](https://github.com/go-redis/redis/issues/2052)) ([997ab5e](https://github.com/go-redis/redis/commit/997ab5e7e3ddf53837917013a4babbded73e944f)), closes [#1943](https://github.com/go-redis/redis/issues/1943)
* add XInfoConsumers test ([6f1a1ac](https://github.com/go-redis/redis/commit/6f1a1ac284ea3f683eeb3b06a59969e8424b6376))
* fix tests ([3a722be](https://github.com/go-redis/redis/commit/3a722be81180e4d2a9cf0a29dc9a1ee1421f5859))
* remove test(XInfoConsumer.idle), not a stable return value when tested. ([f5fbb36](https://github.com/go-redis/redis/commit/f5fbb367e7d9dfd7f391fc535a7387002232fa8a))
* update ChannelWithSubscriptions to accept options ([c98c5f0](https://github.com/go-redis/redis/commit/c98c5f0eebf8d254307183c2ce702a48256b718d))
* update COMMAND parser for Redis 7 ([b0bb514](https://github.com/go-redis/redis/commit/b0bb514059249e01ed7328c9094e5b8a439dfb12))
* use redis over ssh channel([#2057](https://github.com/go-redis/redis/issues/2057)) ([#2060](https://github.com/go-redis/redis/issues/2060)) ([3961b95](https://github.com/go-redis/redis/commit/3961b9577f622a3079fe74f8fc8da12ba67a77ff))
### Features
* add ClientUnpause ([91171f5](https://github.com/go-redis/redis/commit/91171f5e19a261dc4cfbf8706626d461b6ba03e4))
* add NewXPendingResult for unit testing XPending ([#2066](https://github.com/go-redis/redis/issues/2066)) ([b7fd09e](https://github.com/go-redis/redis/commit/b7fd09e59479bc6ed5b3b13c4645a3620fd448a3))
* add WriteArg and Scan net.IP([#2062](https://github.com/go-redis/redis/issues/2062)) ([7d5167e](https://github.com/go-redis/redis/commit/7d5167e8624ac1515e146ed183becb97dadb3d1a))
* **pool:** add check for badConnection ([a8a7665](https://github.com/go-redis/redis/commit/a8a7665ddf8cc657c5226b1826a8ee83dab4b8c1)), closes [#2053](https://github.com/go-redis/redis/issues/2053)
* provide a username and password callback method, so that the plaintext username and password will not be stored in the memory, and the username and password will only be generated once when the CredentialsProvider is called. After the method is executed, the username and password strings on the stack will be released. ([#2097](https://github.com/go-redis/redis/issues/2097)) ([56a3dbc](https://github.com/go-redis/redis/commit/56a3dbc7b656525eb88e0735e239d56e04a23bee))
* upgrade to Redis 7 ([d09c27e](https://github.com/go-redis/redis/commit/d09c27e6046129fd27b1d275e5a13a477bd7f778))
## v9 UNRELEASED
- Added support for [RESP3](https://github.com/antirez/RESP3/blob/master/spec.md) protocol.
- Removed `Pipeline.Close` since there is no real need to explicitly manage pipeline resources.
`Pipeline.Discard` is still available if you want to reset commands for some reason.
- Replaced `*redis.Z` with `redis.Z` since it is small enough to be passed as value.

@ -0,0 +1,25 @@
Copyright (c) 2013 The github.com/go-redis/redis Authors.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,34 @@
PACKAGE_DIRS := $(shell find . -mindepth 2 -type f -name 'go.mod' -exec dirname {} \; | sort)
test: testdeps
go test ./...
go test ./... -short -race
go test ./... -run=NONE -bench=. -benchmem
env GOOS=linux GOARCH=386 go test ./...
go vet
testdeps: testdata/redis/src/redis-server
bench: testdeps
go test ./... -test.run=NONE -test.bench=. -test.benchmem
.PHONY: all test testdeps bench
testdata/redis:
mkdir -p $@
wget -qO- https://download.redis.io/releases/redis-7.0.0.tar.gz | tar xvz --strip-components=1 -C $@
testdata/redis/src/redis-server: testdata/redis
cd $< && make all
fmt:
gofmt -w -s ./
goimports -w -local github.com/go-redis/redis ./
go_mod_tidy:
set -e; for dir in $(PACKAGE_DIRS); do \
echo "go mod tidy in $${dir}"; \
(cd "$${dir}" && \
go get -u ./... && \
go mod tidy -compat=1.17); \
done

@ -0,0 +1,195 @@
# Redis client for Go
![build workflow](https://github.com/go-redis/redis/actions/workflows/build.yml/badge.svg)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/go-redis/redis/v8)](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc)
[![Documentation](https://img.shields.io/badge/redis-documentation-informational)](https://redis.uptrace.dev/)
[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)
> go-redis is brought to you by :star: [**uptrace/uptrace**](https://github.com/uptrace/uptrace).
> Uptrace is an open source and blazingly fast
> [distributed tracing tool](https://get.uptrace.dev/compare/distributed-tracing-tools.html) powered
> by OpenTelemetry and ClickHouse. Give it a star as well!
## Sponsors
### Upstash: Serverless Database for Redis
<a href="https://upstash.com/?utm_source=goredis"><img align="right" width="320" src="https://raw.githubusercontent.com/upstash/sponsorship/master/redis.png" alt="Upstash"></a>
Upstash is a Serverless Database with Redis/REST API and durable storage. It is the perfect database
for your applications thanks to its per request pricing and low latency data.
[Start for free in 30 seconds!](https://upstash.com/?utm_source=goredis)
<br clear="both"/>
## Resources
- [Documentation](https://redis.uptrace.dev)
- [Discussions](https://github.com/go-redis/redis/discussions)
- [Chat](https://discord.gg/rWtp5Aj)
- [Reference](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc)
- [Examples](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#pkg-examples)
## Ecosystem
- [Redis Mock](https://github.com/go-redis/redismock)
- [Distributed Locks](https://github.com/bsm/redislock)
- [Redis Cache](https://github.com/go-redis/cache)
- [Rate limiting](https://github.com/go-redis/redis_rate)
This client also works with [kvrocks](https://github.com/KvrocksLabs/kvrocks), a distributed key
value NoSQL database that uses RocksDB as storage engine and is compatible with Redis protocol.
## Features
- Redis 3 commands except QUIT, MONITOR, and SYNC.
- Automatic connection pooling with
- [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html).
- [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html).
- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html).
- [Redis Sentinel](https://redis.uptrace.dev/guide/go-redis-sentinel.html).
- [Redis Cluster](https://redis.uptrace.dev/guide/go-redis-cluster.html).
- [Redis Ring](https://redis.uptrace.dev/guide/ring.html).
- [Redis Performance Monitoring](https://redis.uptrace.dev/guide/redis-performance-monitoring.html).
## Installation
go-redis supports 2 last Go versions and requires a Go version with
[modules](https://github.com/golang/go/wiki/Modules) support. So make sure to initialize a Go
module:
```shell
go mod init github.com/my/repo
```
If you are using **Redis 6**, install go-redis/**v8**:
```shell
go get github.com/go-redis/redis/v8
```
If you are using **Redis 7**, install **go-redis/v9**:
```shell
go get github.com/go-redis/redis/v9
```
## Quickstart
```go
import (
"context"
"github.com/go-redis/redis/v8"
"fmt"
)
var ctx = context.Background()
func ExampleClient() {
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // no password set
DB: 0, // use default DB
})
err := rdb.Set(ctx, "key", "value", 0).Err()
if err != nil {
panic(err)
}
val, err := rdb.Get(ctx, "key").Result()
if err != nil {
panic(err)
}
fmt.Println("key", val)
val2, err := rdb.Get(ctx, "key2").Result()
if err == redis.Nil {
fmt.Println("key2 does not exist")
} else if err != nil {
panic(err)
} else {
fmt.Println("key2", val2)
}
// Output: key value
// key2 does not exist
}
```
## Look and feel
Some corner cases:
```go
// SET key value EX 10 NX
set, err := rdb.SetNX(ctx, "key", "value", 10*time.Second).Result()
// SET key value keepttl NX
set, err := rdb.SetNX(ctx, "key", "value", redis.KeepTTL).Result()
// SORT list LIMIT 0 2 ASC
vals, err := rdb.Sort(ctx, "list", &redis.Sort{Offset: 0, Count: 2, Order: "ASC"}).Result()
// ZRANGEBYSCORE zset -inf +inf WITHSCORES LIMIT 0 2
vals, err := rdb.ZRangeByScoreWithScores(ctx, "zset", &redis.ZRangeBy{
Min: "-inf",
Max: "+inf",
Offset: 0,
Count: 2,
}).Result()
// ZINTERSTORE out 2 zset1 zset2 WEIGHTS 2 3 AGGREGATE SUM
vals, err := rdb.ZInterStore(ctx, "out", &redis.ZStore{
Keys: []string{"zset1", "zset2"},
Weights: []int64{2, 3}
}).Result()
// EVAL "return {KEYS[1],ARGV[1]}" 1 "key" "hello"
vals, err := rdb.Eval(ctx, "return {KEYS[1],ARGV[1]}", []string{"key"}, "hello").Result()
// custom command
res, err := rdb.Do(ctx, "set", "key", "value").Result()
```
## Run the test
go-redis will start a redis-server and run the test cases.
The paths of redis-server bin file and redis config file are defined in `main_test.go`:
```go
var (
redisServerBin, _ = filepath.Abs(filepath.Join("testdata", "redis", "src", "redis-server"))
redisServerConf, _ = filepath.Abs(filepath.Join("testdata", "redis", "redis.conf"))
)
```
For local testing, you can change the variables to refer to your local files, or create a soft link
to the corresponding folder for redis-server and copy the config file to `testdata/redis/`:
```shell
ln -s /usr/bin/redis-server ./go-redis/testdata/redis/src
cp ./go-redis/testdata/redis.conf ./go-redis/testdata/redis/
```
Lastly, run:
```shell
go test
```
## See also
- [Golang ORM](https://bun.uptrace.dev) for PostgreSQL, MySQL, MSSQL, and SQLite
- [Golang PostgreSQL](https://bun.uptrace.dev/postgres/)
- [Golang HTTP router](https://bunrouter.uptrace.dev/)
- [Golang ClickHouse ORM](https://github.com/uptrace/go-clickhouse)
## Contributors
Thanks to all the people who already contributed!
<a href="https://github.com/go-redis/redis/graphs/contributors">
<img src="https://contributors-img.web.app/image?repo=go-redis/redis" />
</a>

@ -0,0 +1,15 @@
# Releasing
1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub:
```shell
TAG=v1.0.0 ./scripts/release.sh
```
2. Open a pull request and wait for the build to finish.
3. Merge the pull request and run `tag.sh` to create tags for packages:
```shell
TAG=v1.0.0 ./scripts/tag.sh
```

File diff suppressed because it is too large Load Diff

@ -0,0 +1,109 @@
package redis
import (
"context"
"sync"
"sync/atomic"
)
func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd {
cmd := NewIntCmd(ctx, "dbsize")
_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error {
var size int64
err := c.ForEachMaster(ctx, func(ctx context.Context, master *Client) error {
n, err := master.DBSize(ctx).Result()
if err != nil {
return err
}
atomic.AddInt64(&size, n)
return nil
})
if err != nil {
cmd.SetErr(err)
} else {
cmd.val = size
}
return nil
})
return cmd
}
func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCmd {
cmd := NewStringCmd(ctx, "script", "load", script)
_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error {
mu := &sync.Mutex{}
err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error {
val, err := shard.ScriptLoad(ctx, script).Result()
if err != nil {
return err
}
mu.Lock()
if cmd.Val() == "" {
cmd.val = val
}
mu.Unlock()
return nil
})
if err != nil {
cmd.SetErr(err)
}
return nil
})
return cmd
}
func (c *ClusterClient) ScriptFlush(ctx context.Context) *StatusCmd {
cmd := NewStatusCmd(ctx, "script", "flush")
_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error {
err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error {
return shard.ScriptFlush(ctx).Err()
})
if err != nil {
cmd.SetErr(err)
}
return nil
})
return cmd
}
func (c *ClusterClient) ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd {
args := make([]interface{}, 2+len(hashes))
args[0] = "script"
args[1] = "exists"
for i, hash := range hashes {
args[2+i] = hash
}
cmd := NewBoolSliceCmd(ctx, args...)
result := make([]bool, len(hashes))
for i := range result {
result[i] = true
}
_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error {
mu := &sync.Mutex{}
err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error {
val, err := shard.ScriptExists(ctx, hashes...).Result()
if err != nil {
return err
}
mu.Lock()
for i, v := range val {
result[i] = result[i] && v
}
mu.Unlock()
return nil
})
if err != nil {
cmd.SetErr(err)
} else {
cmd.val = result
}
return nil
})
return cmd
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,4 @@
/*
Package redis implements a Redis client.
*/
package redis

@ -0,0 +1,144 @@
package redis
import (
"context"
"io"
"net"
"strings"
"github.com/go-redis/redis/v9/internal/pool"
"github.com/go-redis/redis/v9/internal/proto"
)
// ErrClosed performs any operation on the closed client will return this error.
var ErrClosed = pool.ErrClosed
type Error interface {
error
// RedisError is a no-op function but
// serves to distinguish types that are Redis
// errors from ordinary errors: a type is a
// Redis error if it has a RedisError method.
RedisError()
}
var _ Error = proto.RedisError("")
func shouldRetry(err error, retryTimeout bool) bool {
switch err {
case io.EOF, io.ErrUnexpectedEOF:
return true
case nil, context.Canceled, context.DeadlineExceeded:
return false
}
if v, ok := err.(timeoutError); ok {
if v.Timeout() {
return retryTimeout
}
return true
}
s := err.Error()
if s == "ERR max number of clients reached" {
return true
}
if strings.HasPrefix(s, "LOADING ") {
return true
}
if strings.HasPrefix(s, "READONLY ") {
return true
}
if strings.HasPrefix(s, "CLUSTERDOWN ") {
return true
}
if strings.HasPrefix(s, "TRYAGAIN ") {
return true
}
return false
}
func isRedisError(err error) bool {
_, ok := err.(proto.RedisError)
return ok
}
func isBadConn(err error, allowTimeout bool, addr string) bool {
switch err {
case nil:
return false
case context.Canceled, context.DeadlineExceeded:
return true
}
if isRedisError(err) {
switch {
case isReadOnlyError(err):
// Close connections in read only state in case domain addr is used
// and domain resolves to a different Redis Server. See #790.
return true
case isMovedSameConnAddr(err, addr):
// Close connections when we are asked to move to the same addr
// of the connection. Force a DNS resolution when all connections
// of the pool are recycled
return true
default:
return false
}
}
if allowTimeout {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return false
}
}
return true
}
func isMovedError(err error) (moved bool, ask bool, addr string) {
if !isRedisError(err) {
return
}
s := err.Error()
switch {
case strings.HasPrefix(s, "MOVED "):
moved = true
case strings.HasPrefix(s, "ASK "):
ask = true
default:
return
}
ind := strings.LastIndex(s, " ")
if ind == -1 {
return false, false, ""
}
addr = s[ind+1:]
return
}
func isLoadingError(err error) bool {
return strings.HasPrefix(err.Error(), "LOADING ")
}
func isReadOnlyError(err error) bool {
return strings.HasPrefix(err.Error(), "READONLY ")
}
func isMovedSameConnAddr(err error, addr string) bool {
redisError := err.Error()
if !strings.HasPrefix(redisError, "MOVED ") {
return false
}
return strings.HasSuffix(redisError, " "+addr)
}
//------------------------------------------------------------------------------
type timeoutError interface {
Timeout() bool
}

@ -0,0 +1,56 @@
package internal
import (
"fmt"
"strconv"
"time"
)
func AppendArg(b []byte, v interface{}) []byte {
switch v := v.(type) {
case nil:
return append(b, "<nil>"...)
case string:
return appendUTF8String(b, Bytes(v))
case []byte:
return appendUTF8String(b, v)
case int:
return strconv.AppendInt(b, int64(v), 10)
case int8:
return strconv.AppendInt(b, int64(v), 10)
case int16:
return strconv.AppendInt(b, int64(v), 10)
case int32:
return strconv.AppendInt(b, int64(v), 10)
case int64:
return strconv.AppendInt(b, v, 10)
case uint:
return strconv.AppendUint(b, uint64(v), 10)
case uint8:
return strconv.AppendUint(b, uint64(v), 10)
case uint16:
return strconv.AppendUint(b, uint64(v), 10)
case uint32:
return strconv.AppendUint(b, uint64(v), 10)
case uint64:
return strconv.AppendUint(b, v, 10)
case float32:
return strconv.AppendFloat(b, float64(v), 'f', -1, 64)
case float64:
return strconv.AppendFloat(b, v, 'f', -1, 64)
case bool:
if v {
return append(b, "true"...)
}
return append(b, "false"...)
case time.Time:
return v.AppendFormat(b, time.RFC3339Nano)
default:
return append(b, fmt.Sprint(v)...)
}
}
func appendUTF8String(dst []byte, src []byte) []byte {
dst = append(dst, src...)
return dst
}

@ -0,0 +1,78 @@
package hashtag
import (
"strings"
"github.com/go-redis/redis/v9/internal/rand"
)
const slotNumber = 16384
// CRC16 implementation according to CCITT standards.
// Copyright 2001-2010 Georges Menie (www.menie.org)
// Copyright 2013 The Go Authors. All rights reserved.
// http://redis.io/topics/cluster-spec#appendix-a-crc16-reference-implementation-in-ansi-c
var crc16tab = [256]uint16{
0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7,
0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef,
0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6,
0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de,
0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485,
0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d,
0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4,
0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc,
0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823,
0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b,
0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12,
0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a,
0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41,
0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49,
0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70,
0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78,
0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f,
0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067,
0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e,
0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256,
0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d,
0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405,
0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c,
0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634,
0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab,
0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3,
0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a,
0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92,
0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9,
0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1,
0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8,
0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0,
}
func Key(key string) string {
if s := strings.IndexByte(key, '{'); s > -1 {
if e := strings.IndexByte(key[s+1:], '}'); e > 0 {
return key[s+1 : s+e+1]
}
}
return key
}
func RandomSlot() int {
return rand.Intn(slotNumber)
}
// Slot returns a consistent slot number between 0 and 16383
// for any given string key.
func Slot(key string) int {
if key == "" {
return RandomSlot()
}
key = Key(key)
return int(crc16sum(key)) % slotNumber
}
func crc16sum(key string) (crc uint16) {
for i := 0; i < len(key); i++ {
crc = (crc << 8) ^ crc16tab[(byte(crc>>8)^key[i])&0x00ff]
}
return
}

@ -0,0 +1,201 @@
package hscan
import (
"errors"
"fmt"
"reflect"
"strconv"
)
// decoderFunc represents decoding functions for default built-in types.
type decoderFunc func(reflect.Value, string) error
var (
// List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1).
decoders = []decoderFunc{
reflect.Bool: decodeBool,
reflect.Int: decodeInt,
reflect.Int8: decodeInt8,
reflect.Int16: decodeInt16,
reflect.Int32: decodeInt32,
reflect.Int64: decodeInt64,
reflect.Uint: decodeUint,
reflect.Uint8: decodeUint8,
reflect.Uint16: decodeUint16,
reflect.Uint32: decodeUint32,
reflect.Uint64: decodeUint64,
reflect.Float32: decodeFloat32,
reflect.Float64: decodeFloat64,
reflect.Complex64: decodeUnsupported,
reflect.Complex128: decodeUnsupported,
reflect.Array: decodeUnsupported,
reflect.Chan: decodeUnsupported,
reflect.Func: decodeUnsupported,
reflect.Interface: decodeUnsupported,
reflect.Map: decodeUnsupported,
reflect.Ptr: decodeUnsupported,
reflect.Slice: decodeSlice,
reflect.String: decodeString,
reflect.Struct: decodeUnsupported,
reflect.UnsafePointer: decodeUnsupported,
}
// Global map of struct field specs that is populated once for every new
// struct type that is scanned. This caches the field types and the corresponding
// decoder functions to avoid iterating through struct fields on subsequent scans.
globalStructMap = newStructMap()
)
func Struct(dst interface{}) (StructValue, error) {
v := reflect.ValueOf(dst)
// The destination to scan into should be a struct pointer.
if v.Kind() != reflect.Ptr || v.IsNil() {
return StructValue{}, fmt.Errorf("redis.Scan(non-pointer %T)", dst)
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return StructValue{}, fmt.Errorf("redis.Scan(non-struct %T)", dst)
}
return StructValue{
spec: globalStructMap.get(v.Type()),
value: v,
}, nil
}
// Scan scans the results from a key-value Redis map result set to a destination struct.
// The Redis keys are matched to the struct's field with the `redis` tag.
func Scan(dst interface{}, keys []interface{}, vals []interface{}) error {
if len(keys) != len(vals) {
return errors.New("args should have the same number of keys and vals")
}
strct, err := Struct(dst)
if err != nil {
return err
}
// Iterate through the (key, value) sequence.
for i := 0; i < len(vals); i++ {
key, ok := keys[i].(string)
if !ok {
continue
}
val, ok := vals[i].(string)
if !ok {
continue
}
if err := strct.Scan(key, val); err != nil {
return err
}
}
return nil
}
func decodeBool(f reflect.Value, s string) error {
b, err := strconv.ParseBool(s)
if err != nil {
return err
}
f.SetBool(b)
return nil
}
func decodeInt8(f reflect.Value, s string) error {
return decodeNumber(f, s, 8)
}
func decodeInt16(f reflect.Value, s string) error {
return decodeNumber(f, s, 16)
}
func decodeInt32(f reflect.Value, s string) error {
return decodeNumber(f, s, 32)
}
func decodeInt64(f reflect.Value, s string) error {
return decodeNumber(f, s, 64)
}
func decodeInt(f reflect.Value, s string) error {
return decodeNumber(f, s, 0)
}
func decodeNumber(f reflect.Value, s string, bitSize int) error {
v, err := strconv.ParseInt(s, 10, bitSize)
if err != nil {
return err
}
f.SetInt(v)
return nil
}
func decodeUint8(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 8)
}
func decodeUint16(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 16)
}
func decodeUint32(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 32)
}
func decodeUint64(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 64)
}
func decodeUint(f reflect.Value, s string) error {
return decodeUnsignedNumber(f, s, 0)
}
func decodeUnsignedNumber(f reflect.Value, s string, bitSize int) error {
v, err := strconv.ParseUint(s, 10, bitSize)
if err != nil {
return err
}
f.SetUint(v)
return nil
}
func decodeFloat32(f reflect.Value, s string) error {
v, err := strconv.ParseFloat(s, 32)
if err != nil {
return err
}
f.SetFloat(v)
return nil
}
// although the default is float64, but we better define it.
func decodeFloat64(f reflect.Value, s string) error {
v, err := strconv.ParseFloat(s, 64)
if err != nil {
return err
}
f.SetFloat(v)
return nil
}
func decodeString(f reflect.Value, s string) error {
f.SetString(s)
return nil
}
func decodeSlice(f reflect.Value, s string) error {
// []byte slice ([]uint8).
if f.Type().Elem().Kind() == reflect.Uint8 {
f.SetBytes([]byte(s))
}
return nil
}
func decodeUnsupported(v reflect.Value, s string) error {
return fmt.Errorf("redis.Scan(unsupported %s)", v.Type())
}

@ -0,0 +1,93 @@
package hscan
import (
"fmt"
"reflect"
"strings"
"sync"
)
// structMap contains the map of struct fields for target structs
// indexed by the struct type.
type structMap struct {
m sync.Map
}
func newStructMap() *structMap {
return new(structMap)
}
func (s *structMap) get(t reflect.Type) *structSpec {
if v, ok := s.m.Load(t); ok {
return v.(*structSpec)
}
spec := newStructSpec(t, "redis")
s.m.Store(t, spec)
return spec
}
//------------------------------------------------------------------------------
// structSpec contains the list of all fields in a target struct.
type structSpec struct {
m map[string]*structField
}
func (s *structSpec) set(tag string, sf *structField) {
s.m[tag] = sf
}
func newStructSpec(t reflect.Type, fieldTag string) *structSpec {
numField := t.NumField()
out := &structSpec{
m: make(map[string]*structField, numField),
}
for i := 0; i < numField; i++ {
f := t.Field(i)
tag := f.Tag.Get(fieldTag)
if tag == "" || tag == "-" {
continue
}
tag = strings.Split(tag, ",")[0]
if tag == "" {
continue
}
// Use the built-in decoder.
out.set(tag, &structField{index: i, fn: decoders[f.Type.Kind()]})
}
return out
}
//------------------------------------------------------------------------------
// structField represents a single field in a target struct.
type structField struct {
index int
fn decoderFunc
}
//------------------------------------------------------------------------------
type StructValue struct {
spec *structSpec
value reflect.Value
}
func (s StructValue) Scan(key string, value string) error {
field, ok := s.spec.m[key]
if !ok {
return nil
}
if err := field.fn(s.value.Field(field.index), value); err != nil {
t := s.value.Type()
return fmt.Errorf("cannot scan redis.result %s into struct field %s.%s of type %s, error-%s",
value, t.Name(), t.Field(field.index).Name, t.Field(field.index).Type, err.Error())
}
return nil
}

@ -0,0 +1,29 @@
package internal
import (
"time"
"github.com/go-redis/redis/v9/internal/rand"
)
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration {
if retry < 0 {
panic("not reached")
}
if minBackoff == 0 {
return 0
}
d := minBackoff << uint(retry)
if d < minBackoff {
return maxBackoff
}
d = minBackoff + time.Duration(rand.Int63n(int64(d)))
if d > maxBackoff || d < minBackoff {
d = maxBackoff
}
return d
}

@ -0,0 +1,26 @@
package internal
import (
"context"
"fmt"
"log"
"os"
)
type Logging interface {
Printf(ctx context.Context, format string, v ...interface{})
}
type logger struct {
log *log.Logger
}
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
// Logger calls Output to print to the stderr.
// Arguments are handled in the manner of fmt.Print.
var Logger Logging = &logger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}

@ -0,0 +1,60 @@
/*
Copyright 2014 The Camlistore Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package internal
import (
"sync"
"sync/atomic"
)
// A Once will perform a successful action exactly once.
//
// Unlike a sync.Once, this Once's func returns an error
// and is re-armed on failure.
type Once struct {
m sync.Mutex
done uint32
}
// Do calls the function f if and only if Do has not been invoked
// without error for this instance of Once. In other words, given
// var once Once
// if once.Do(f) is called multiple times, only the first call will
// invoke f, even if f has a different value in each invocation unless
// f returns an error. A new instance of Once is required for each
// function to execute.
//
// Do is intended for initialization that must be run exactly once. Since f
// is niladic, it may be necessary to use a function literal to capture the
// arguments to a function to be invoked by Do:
// err := config.once.Do(func() error { return config.init(filename) })
func (o *Once) Do(f func() error) error {
if atomic.LoadUint32(&o.done) == 1 {
return nil
}
// Slow-path.
o.m.Lock()
defer o.m.Unlock()
var err error
if o.done == 0 {
err = f()
if err == nil {
atomic.StoreUint32(&o.done, 1)
}
}
return err
}

@ -0,0 +1,125 @@
package pool
import (
"bufio"
"context"
"net"
"sync/atomic"
"time"
"github.com/go-redis/redis/v9/internal/proto"
)
var noDeadline = time.Time{}
type Conn struct {
usedAt int64 // atomic
netConn net.Conn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
Inited bool
pooled bool
createdAt time.Time
}
func NewConn(netConn net.Conn) *Conn {
cn := &Conn{
netConn: netConn,
createdAt: time.Now(),
}
cn.rd = proto.NewReader(netConn)
cn.bw = bufio.NewWriter(netConn)
cn.wr = proto.NewWriter(cn.bw)
cn.SetUsedAt(time.Now())
return cn
}
func (cn *Conn) UsedAt() time.Time {
unix := atomic.LoadInt64(&cn.usedAt)
return time.Unix(unix, 0)
}
func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
cn.rd.Reset(netConn)
cn.bw.Reset(netConn)
}
func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
}
func (cn *Conn) RemoteAddr() net.Addr {
if cn.netConn != nil {
return cn.netConn.RemoteAddr()
}
return nil
}
func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error {
if timeout != 0 {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
}
}
return fn(cn.rd)
}
func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
if timeout != 0 {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
}
}
if cn.bw.Buffered() > 0 {
cn.bw.Reset(cn.netConn)
}
if err := fn(cn.wr); err != nil {
return err
}
return cn.bw.Flush()
}
func (cn *Conn) Close() error {
return cn.netConn.Close()
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
tm := time.Now()
cn.SetUsedAt(tm)
if timeout > 0 {
tm = tm.Add(timeout)
}
if ctx != nil {
deadline, ok := ctx.Deadline()
if ok {
if timeout == 0 {
return deadline
}
if deadline.Before(tm) {
return deadline
}
return tm
}
}
if timeout > 0 {
return tm
}
return noDeadline
}

@ -0,0 +1,50 @@
//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos
// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
package pool
import (
"errors"
"io"
"net"
"syscall"
"time"
)
var errUnexpectedRead = errors.New("unexpected read from socket")
func connCheck(conn net.Conn) error {
// Reset previous timeout.
_ = conn.SetDeadline(time.Time{})
sysConn, ok := conn.(syscall.Conn)
if !ok {
return nil
}
rawConn, err := sysConn.SyscallConn()
if err != nil {
return err
}
var sysErr error
err = rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
switch {
case n == 0 && err == nil:
sysErr = io.EOF
case n > 0:
sysErr = errUnexpectedRead
case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
sysErr = nil
default:
sysErr = err
}
return true
})
if err != nil {
return err
}
return sysErr
}

@ -0,0 +1,10 @@
//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos
// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos
package pool
import "net"
func connCheck(conn net.Conn) error {
return nil
}

@ -0,0 +1,557 @@
package pool
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/go-redis/redis/v9/internal"
)
var (
// ErrClosed performs any operation on the closed client will return this error.
ErrClosed = errors.New("redis: client is closed")
// ErrPoolTimeout timed out waiting to get a connection from the connection pool.
ErrPoolTimeout = errors.New("redis: connection pool timeout")
)
var timers = sync.Pool{
New: func() interface{} {
t := time.NewTimer(time.Hour)
t.Stop()
return t
},
}
// Stats contains pool state information and accumulated stats.
type Stats struct {
Hits uint32 // number of times free connection was found in the pool
Misses uint32 // number of times free connection was NOT found in the pool
Timeouts uint32 // number of times a wait timeout occurred
TotalConns uint32 // number of total connections in the pool
IdleConns uint32 // number of idle connections in the pool
StaleConns uint32 // number of stale connections removed from the pool
}
type Pooler interface {
NewConn(context.Context) (*Conn, error)
CloseConn(*Conn) error
Get(context.Context) (*Conn, error)
Put(context.Context, *Conn)
Remove(context.Context, *Conn, error)
Len() int
IdleLen() int
Stats() *Stats
Close() error
}
type Options struct {
Dialer func(context.Context) (net.Conn, error)
OnClose func(*Conn) error
PoolFIFO bool
PoolSize int
MinIdleConns int
MaxConnAge time.Duration
PoolTimeout time.Duration
IdleTimeout time.Duration
IdleCheckFrequency time.Duration
}
type lastDialErrorWrap struct {
err error
}
type ConnPool struct {
opt *Options
dialErrorsNum uint32 // atomic
lastDialError atomic.Value
queue chan struct{}
connsMu sync.Mutex
conns []*Conn
idleConns []*Conn
poolSize int
idleConnsLen int
stats Stats
_closed uint32 // atomic
closedCh chan struct{}
}
var _ Pooler = (*ConnPool)(nil)
func NewConnPool(opt *Options) *ConnPool {
p := &ConnPool{
opt: opt,
queue: make(chan struct{}, opt.PoolSize),
conns: make([]*Conn, 0, opt.PoolSize),
idleConns: make([]*Conn, 0, opt.PoolSize),
closedCh: make(chan struct{}),
}
p.connsMu.Lock()
p.checkMinIdleConns()
p.connsMu.Unlock()
if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
go p.reaper(opt.IdleCheckFrequency)
}
return p
}
func (p *ConnPool) checkMinIdleConns() {
if p.opt.MinIdleConns == 0 {
return
}
for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
p.poolSize++
p.idleConnsLen++
go func() {
err := p.addIdleConn()
if err != nil && err != ErrClosed {
p.connsMu.Lock()
p.poolSize--
p.idleConnsLen--
p.connsMu.Unlock()
}
}()
}
}
func (p *ConnPool) addIdleConn() error {
cn, err := p.dialConn(context.TODO(), true)
if err != nil {
return err
}
p.connsMu.Lock()
defer p.connsMu.Unlock()
// It is not allowed to add new connections to the closed connection pool.
if p.closed() {
_ = cn.Close()
return ErrClosed
}
p.conns = append(p.conns, cn)
p.idleConns = append(p.idleConns, cn)
return nil
}
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.newConn(ctx, false)
}
func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
cn, err := p.dialConn(ctx, pooled)
if err != nil {
return nil, err
}
p.connsMu.Lock()
defer p.connsMu.Unlock()
// It is not allowed to add new connections to the closed connection pool.
if p.closed() {
_ = cn.Close()
return nil, ErrClosed
}
p.conns = append(p.conns, cn)
if pooled {
// If pool is full remove the cn on next Put.
if p.poolSize >= p.opt.PoolSize {
cn.pooled = false
} else {
p.poolSize++
}
}
return cn, nil
}
func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
return nil, p.getLastDialError()
}
netConn, err := p.opt.Dialer(ctx)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
go p.tryDial()
}
return nil, err
}
cn := NewConn(netConn)
cn.pooled = pooled
return cn, nil
}
func (p *ConnPool) tryDial() {
for {
if p.closed() {
return
}
conn, err := p.opt.Dialer(context.Background())
if err != nil {
p.setLastDialError(err)
time.Sleep(time.Second)
continue
}
atomic.StoreUint32(&p.dialErrorsNum, 0)
_ = conn.Close()
return
}
}
func (p *ConnPool) setLastDialError(err error) {
p.lastDialError.Store(&lastDialErrorWrap{err: err})
}
func (p *ConnPool) getLastDialError() error {
err, _ := p.lastDialError.Load().(*lastDialErrorWrap)
if err != nil {
return err.err
}
return nil
}
// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
if err := p.waitTurn(ctx); err != nil {
return nil, err
}
for {
p.connsMu.Lock()
cn, err := p.popIdle()
p.connsMu.Unlock()
if err != nil {
return nil, err
}
if cn == nil {
break
}
if p.isStaleConn(cn) {
_ = p.CloseConn(cn)
continue
}
atomic.AddUint32(&p.stats.Hits, 1)
return cn, nil
}
atomic.AddUint32(&p.stats.Misses, 1)
newcn, err := p.newConn(ctx, true)
if err != nil {
p.freeTurn()
return nil, err
}
return newcn, nil
}
func (p *ConnPool) getTurn() {
p.queue <- struct{}{}
}
func (p *ConnPool) waitTurn(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
select {
case p.queue <- struct{}{}:
return nil
default:
}
timer := timers.Get().(*time.Timer)
timer.Reset(p.opt.PoolTimeout)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return ctx.Err()
case p.queue <- struct{}{}:
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return nil
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return ErrPoolTimeout
}
}
func (p *ConnPool) freeTurn() {
<-p.queue
}
func (p *ConnPool) popIdle() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
n := len(p.idleConns)
if n == 0 {
return nil, nil
}
var cn *Conn
if p.opt.PoolFIFO {
cn = p.idleConns[0]
copy(p.idleConns, p.idleConns[1:])
p.idleConns = p.idleConns[:n-1]
} else {
idx := n - 1
cn = p.idleConns[idx]
p.idleConns = p.idleConns[:idx]
}
p.idleConnsLen--
p.checkMinIdleConns()
return cn, nil
}
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
if cn.rd.Buffered() > 0 {
internal.Logger.Printf(ctx, "Conn has unread data")
p.Remove(ctx, cn, BadConnError{})
return
}
if !cn.pooled {
p.Remove(ctx, cn, nil)
return
}
p.connsMu.Lock()
p.idleConns = append(p.idleConns, cn)
p.idleConnsLen++
p.connsMu.Unlock()
p.freeTurn()
}
func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.removeConnWithLock(cn)
p.freeTurn()
_ = p.closeConn(cn)
}
func (p *ConnPool) CloseConn(cn *Conn) error {
p.removeConnWithLock(cn)
return p.closeConn(cn)
}
func (p *ConnPool) removeConnWithLock(cn *Conn) {
p.connsMu.Lock()
p.removeConn(cn)
p.connsMu.Unlock()
}
func (p *ConnPool) removeConn(cn *Conn) {
for i, c := range p.conns {
if c == cn {
p.conns = append(p.conns[:i], p.conns[i+1:]...)
if cn.pooled {
p.poolSize--
p.checkMinIdleConns()
}
return
}
}
}
func (p *ConnPool) closeConn(cn *Conn) error {
if p.opt.OnClose != nil {
_ = p.opt.OnClose(cn)
}
return cn.Close()
}
// Len returns total number of connections.
func (p *ConnPool) Len() int {
p.connsMu.Lock()
n := len(p.conns)
p.connsMu.Unlock()
return n
}
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
p.connsMu.Lock()
n := p.idleConnsLen
p.connsMu.Unlock()
return n
}
func (p *ConnPool) Stats() *Stats {
idleLen := p.IdleLen()
return &Stats{
Hits: atomic.LoadUint32(&p.stats.Hits),
Misses: atomic.LoadUint32(&p.stats.Misses),
Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
TotalConns: uint32(p.Len()),
IdleConns: uint32(idleLen),
StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
}
}
func (p *ConnPool) closed() bool {
return atomic.LoadUint32(&p._closed) == 1
}
func (p *ConnPool) Filter(fn func(*Conn) bool) error {
p.connsMu.Lock()
defer p.connsMu.Unlock()
var firstErr error
for _, cn := range p.conns {
if fn(cn) {
if err := p.closeConn(cn); err != nil && firstErr == nil {
firstErr = err
}
}
}
return firstErr
}
func (p *ConnPool) Close() error {
if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
return ErrClosed
}
close(p.closedCh)
var firstErr error
p.connsMu.Lock()
for _, cn := range p.conns {
if err := p.closeConn(cn); err != nil && firstErr == nil {
firstErr = err
}
}
p.conns = nil
p.poolSize = 0
p.idleConns = nil
p.idleConnsLen = 0
p.connsMu.Unlock()
return firstErr
}
func (p *ConnPool) reaper(frequency time.Duration) {
ticker := time.NewTicker(frequency)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// It is possible that ticker and closedCh arrive together,
// and select pseudo-randomly pick ticker case, we double
// check here to prevent being executed after closed.
if p.closed() {
return
}
_, err := p.ReapStaleConns()
if err != nil {
internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err)
continue
}
case <-p.closedCh:
return
}
}
}
func (p *ConnPool) ReapStaleConns() (int, error) {
var n int
for {
p.getTurn()
p.connsMu.Lock()
cn := p.reapStaleConn()
p.connsMu.Unlock()
p.freeTurn()
if cn != nil {
_ = p.closeConn(cn)
n++
} else {
break
}
}
atomic.AddUint32(&p.stats.StaleConns, uint32(n))
return n, nil
}
func (p *ConnPool) reapStaleConn() *Conn {
if len(p.idleConns) == 0 {
return nil
}
cn := p.idleConns[0]
if !p.isStaleConn(cn) {
return nil
}
p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...)
p.idleConnsLen--
p.removeConn(cn)
return cn
}
func (p *ConnPool) isStaleConn(cn *Conn) bool {
if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 {
return connCheck(cn.netConn) != nil
}
now := time.Now()
if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout {
return true
}
if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge {
return true
}
return connCheck(cn.netConn) != nil
}

@ -0,0 +1,58 @@
package pool
import "context"
type SingleConnPool struct {
pool Pooler
cn *Conn
stickyErr error
}
var _ Pooler = (*SingleConnPool)(nil)
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
return &SingleConnPool{
pool: pool,
cn: cn,
}
}
func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.pool.NewConn(ctx)
}
func (p *SingleConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn)
}
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
if p.stickyErr != nil {
return nil, p.stickyErr
}
return p.cn, nil
}
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.cn = nil
p.stickyErr = reason
}
func (p *SingleConnPool) Close() error {
p.cn = nil
p.stickyErr = ErrClosed
return nil
}
func (p *SingleConnPool) Len() int {
return 0
}
func (p *SingleConnPool) IdleLen() int {
return 0
}
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}

@ -0,0 +1,201 @@
package pool
import (
"context"
"errors"
"fmt"
"sync/atomic"
)
const (
stateDefault = 0
stateInited = 1
stateClosed = 2
)
type BadConnError struct {
wrapped error
}
var _ error = (*BadConnError)(nil)
func (e BadConnError) Error() string {
s := "redis: Conn is in a bad state"
if e.wrapped != nil {
s += ": " + e.wrapped.Error()
}
return s
}
func (e BadConnError) Unwrap() error {
return e.wrapped
}
//------------------------------------------------------------------------------
type StickyConnPool struct {
pool Pooler
shared int32 // atomic
state uint32 // atomic
ch chan *Conn
_badConnError atomic.Value
}
var _ Pooler = (*StickyConnPool)(nil)
func NewStickyConnPool(pool Pooler) *StickyConnPool {
p, ok := pool.(*StickyConnPool)
if !ok {
p = &StickyConnPool{
pool: pool,
ch: make(chan *Conn, 1),
}
}
atomic.AddInt32(&p.shared, 1)
return p
}
func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
return p.pool.NewConn(ctx)
}
func (p *StickyConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn)
}
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
// In worst case this races with Close which is not a very common operation.
for i := 0; i < 1000; i++ {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
cn, err := p.pool.Get(ctx)
if err != nil {
return nil, err
}
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
return cn, nil
}
p.pool.Remove(ctx, cn, ErrClosed)
case stateInited:
if err := p.badConnError(); err != nil {
return nil, err
}
cn, ok := <-p.ch
if !ok {
return nil, ErrClosed
}
return cn, nil
case stateClosed:
return nil, ErrClosed
default:
panic("not reached")
}
}
return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop")
}
func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
defer func() {
if recover() != nil {
p.freeConn(ctx, cn)
}
}()
p.ch <- cn
}
func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
if err := p.badConnError(); err != nil {
p.pool.Remove(ctx, cn, err)
} else {
p.pool.Put(ctx, cn)
}
}
func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
defer func() {
if recover() != nil {
p.pool.Remove(ctx, cn, ErrClosed)
}
}()
p._badConnError.Store(BadConnError{wrapped: reason})
p.ch <- cn
}
func (p *StickyConnPool) Close() error {
if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
return nil
}
for i := 0; i < 1000; i++ {
state := atomic.LoadUint32(&p.state)
if state == stateClosed {
return ErrClosed
}
if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
close(p.ch)
cn, ok := <-p.ch
if ok {
p.freeConn(context.TODO(), cn)
}
return nil
}
}
return errors.New("redis: StickyConnPool.Close: infinite loop")
}
func (p *StickyConnPool) Reset(ctx context.Context) error {
if p.badConnError() == nil {
return nil
}
select {
case cn, ok := <-p.ch:
if !ok {
return ErrClosed
}
p.pool.Remove(ctx, cn, ErrClosed)
p._badConnError.Store(BadConnError{wrapped: nil})
default:
return errors.New("redis: StickyConnPool does not have a Conn")
}
if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
state := atomic.LoadUint32(&p.state)
return fmt.Errorf("redis: invalid StickyConnPool state: %d", state)
}
return nil
}
func (p *StickyConnPool) badConnError() error {
if v := p._badConnError.Load(); v != nil {
if err := v.(BadConnError); err.wrapped != nil {
return err
}
}
return nil
}
func (p *StickyConnPool) Len() int {
switch atomic.LoadUint32(&p.state) {
case stateDefault:
return 0
case stateInited:
return 1
case stateClosed:
return 0
default:
panic("not reached")
}
}
func (p *StickyConnPool) IdleLen() int {
return len(p.ch)
}
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}

@ -0,0 +1,523 @@
package proto
import (
"bufio"
"errors"
"fmt"
"io"
"math"
"math/big"
"strconv"
"github.com/go-redis/redis/v9/internal/util"
)
// redis resp protocol data type.
const (
RespStatus = '+' // +<string>\r\n
RespError = '-' // -<string>\r\n
RespString = '$' // $<length>\r\n<bytes>\r\n
RespInt = ':' // :<number>\r\n
RespNil = '_' // _\r\n
RespFloat = ',' // ,<floating-point-number>\r\n (golang float)
RespBool = '#' // true: #t\r\n false: #f\r\n
RespBlobError = '!' // !<length>\r\n<bytes>\r\n
RespVerbatim = '=' // =<length>\r\nFORMAT:<bytes>\r\n
RespBigInt = '(' // (<big number>\r\n
RespArray = '*' // *<len>\r\n... (same as resp2)
RespMap = '%' // %<len>\r\n(key)\r\n(value)\r\n... (golang map)
RespSet = '~' // ~<len>\r\n... (same as Array)
RespAttr = '|' // |<len>\r\n(key)\r\n(value)\r\n... + command reply
RespPush = '>' // ><len>\r\n... (same as Array)
)
// Not used temporarily.
// Redis has not used these two data types for the time being, and will implement them later.
// Streamed = "EOF:"
// StreamedAggregated = '?'
//------------------------------------------------------------------------------
const Nil = RedisError("redis: nil") // nolint:errname
type RedisError string
func (e RedisError) Error() string { return string(e) }
func (RedisError) RedisError() {}
func ParseErrorReply(line []byte) error {
return RedisError(line[1:])
}
//------------------------------------------------------------------------------
type Reader struct {
rd *bufio.Reader
}
func NewReader(rd io.Reader) *Reader {
return &Reader{
rd: bufio.NewReader(rd),
}
}
func (r *Reader) Buffered() int {
return r.rd.Buffered()
}
func (r *Reader) Peek(n int) ([]byte, error) {
return r.rd.Peek(n)
}
func (r *Reader) Reset(rd io.Reader) {
r.rd.Reset(rd)
}
// PeekReplyType returns the data type of the next response without advancing the Reader,
// and discard the attribute type.
func (r *Reader) PeekReplyType() (byte, error) {
b, err := r.rd.Peek(1)
if err != nil {
return 0, err
}
if b[0] == RespAttr {
if err = r.DiscardNext(); err != nil {
return 0, err
}
return r.PeekReplyType()
}
return b[0], nil
}
// ReadLine Return a valid reply, it will check the protocol or redis error,
// and discard the attribute type.
func (r *Reader) ReadLine() ([]byte, error) {
line, err := r.readLine()
if err != nil {
return nil, err
}
switch line[0] {
case RespError:
return nil, ParseErrorReply(line)
case RespNil:
return nil, Nil
case RespBlobError:
var blobErr string
blobErr, err = r.readStringReply(line)
if err == nil {
err = RedisError(blobErr)
}
return nil, err
case RespAttr:
if err = r.Discard(line); err != nil {
return nil, err
}
return r.ReadLine()
}
// Compatible with RESP2
if IsNilReply(line) {
return nil, Nil
}
return line, nil
}
// readLine returns an error if:
// - there is a pending read error;
// - or line does not end with \r\n.
func (r *Reader) readLine() ([]byte, error) {
b, err := r.rd.ReadSlice('\n')
if err != nil {
if err != bufio.ErrBufferFull {
return nil, err
}
full := make([]byte, len(b))
copy(full, b)
b, err = r.rd.ReadBytes('\n')
if err != nil {
return nil, err
}
full = append(full, b...) //nolint:makezero
b = full
}
if len(b) <= 2 || b[len(b)-1] != '\n' || b[len(b)-2] != '\r' {
return nil, fmt.Errorf("redis: invalid reply: %q", b)
}
return b[:len(b)-2], nil
}
func (r *Reader) ReadReply() (interface{}, error) {
line, err := r.ReadLine()
if err != nil {
return nil, err
}
switch line[0] {
case RespStatus:
return string(line[1:]), nil
case RespInt:
return util.ParseInt(line[1:], 10, 64)
case RespFloat:
return r.readFloat(line)
case RespBool:
return r.readBool(line)
case RespBigInt:
return r.readBigInt(line)
case RespString:
return r.readStringReply(line)
case RespVerbatim:
return r.readVerb(line)
case RespArray, RespSet, RespPush:
return r.readSlice(line)
case RespMap:
return r.readMap(line)
}
return nil, fmt.Errorf("redis: can't parse %.100q", line)
}
func (r *Reader) readFloat(line []byte) (float64, error) {
v := string(line[1:])
switch string(line[1:]) {
case "inf":
return math.Inf(1), nil
case "-inf":
return math.Inf(-1), nil
}
return strconv.ParseFloat(v, 64)
}
func (r *Reader) readBool(line []byte) (bool, error) {
switch string(line[1:]) {
case "t":
return true, nil
case "f":
return false, nil
}
return false, fmt.Errorf("redis: can't parse bool reply: %q", line)
}
func (r *Reader) readBigInt(line []byte) (*big.Int, error) {
i := new(big.Int)
if i, ok := i.SetString(string(line[1:]), 10); ok {
return i, nil
}
return nil, fmt.Errorf("redis: can't parse bigInt reply: %q", line)
}
func (r *Reader) readStringReply(line []byte) (string, error) {
n, err := replyLen(line)
if err != nil {
return "", err
}
b := make([]byte, n+2)
_, err = io.ReadFull(r.rd, b)
if err != nil {
return "", err
}
return util.BytesToString(b[:n]), nil
}
func (r *Reader) readVerb(line []byte) (string, error) {
s, err := r.readStringReply(line)
if err != nil {
return "", err
}
if len(s) < 4 || s[3] != ':' {
return "", fmt.Errorf("redis: can't parse verbatim string reply: %q", line)
}
return s[4:], nil
}
func (r *Reader) readSlice(line []byte) ([]interface{}, error) {
n, err := replyLen(line)
if err != nil {
return nil, err
}
val := make([]interface{}, n)
for i := 0; i < len(val); i++ {
v, err := r.ReadReply()
if err != nil {
if err == Nil {
val[i] = nil
continue
}
if err, ok := err.(RedisError); ok {
val[i] = err
continue
}
return nil, err
}
val[i] = v
}
return val, nil
}
func (r *Reader) readMap(line []byte) (map[interface{}]interface{}, error) {
n, err := replyLen(line)
if err != nil {
return nil, err
}
m := make(map[interface{}]interface{}, n)
for i := 0; i < n; i++ {
k, err := r.ReadReply()
if err != nil {
return nil, err
}
v, err := r.ReadReply()
if err != nil {
if err == Nil {
m[k] = nil
continue
}
if err, ok := err.(RedisError); ok {
m[k] = err
continue
}
return nil, err
}
m[k] = v
}
return m, nil
}
// -------------------------------
func (r *Reader) ReadInt() (int64, error) {
line, err := r.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case RespInt, RespStatus:
return util.ParseInt(line[1:], 10, 64)
case RespString:
s, err := r.readStringReply(line)
if err != nil {
return 0, err
}
return util.ParseInt([]byte(s), 10, 64)
case RespBigInt:
b, err := r.readBigInt(line)
if err != nil {
return 0, err
}
if !b.IsInt64() {
return 0, fmt.Errorf("bigInt(%s) value out of range", b.String())
}
return b.Int64(), nil
}
return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line)
}
func (r *Reader) ReadFloat() (float64, error) {
line, err := r.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case RespFloat:
return r.readFloat(line)
case RespStatus:
return strconv.ParseFloat(string(line[1:]), 64)
case RespString:
s, err := r.readStringReply(line)
if err != nil {
return 0, err
}
return strconv.ParseFloat(s, 64)
}
return 0, fmt.Errorf("redis: can't parse float reply: %.100q", line)
}
func (r *Reader) ReadString() (string, error) {
line, err := r.ReadLine()
if err != nil {
return "", err
}
switch line[0] {
case RespStatus, RespInt, RespFloat:
return string(line[1:]), nil
case RespString:
return r.readStringReply(line)
case RespBool:
b, err := r.readBool(line)
return strconv.FormatBool(b), err
case RespVerbatim:
return r.readVerb(line)
case RespBigInt:
b, err := r.readBigInt(line)
if err != nil {
return "", err
}
return b.String(), nil
}
return "", fmt.Errorf("redis: can't parse reply=%.100q reading string", line)
}
func (r *Reader) ReadBool() (bool, error) {
s, err := r.ReadString()
if err != nil {
return false, err
}
return s == "OK" || s == "1" || s == "true", nil
}
func (r *Reader) ReadSlice() ([]interface{}, error) {
line, err := r.ReadLine()
if err != nil {
return nil, err
}
return r.readSlice(line)
}
// ReadFixedArrayLen read fixed array length.
func (r *Reader) ReadFixedArrayLen(fixedLen int) error {
n, err := r.ReadArrayLen()
if err != nil {
return err
}
if n != fixedLen {
return fmt.Errorf("redis: got %d elements in the array, wanted %d", n, fixedLen)
}
return nil
}
// ReadArrayLen Read and return the length of the array.
func (r *Reader) ReadArrayLen() (int, error) {
line, err := r.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case RespArray, RespSet, RespPush:
return replyLen(line)
default:
return 0, fmt.Errorf("redis: can't parse array/set/push reply: %.100q", line)
}
}
// ReadFixedMapLen reads fixed map length.
func (r *Reader) ReadFixedMapLen(fixedLen int) error {
n, err := r.ReadMapLen()
if err != nil {
return err
}
if n != fixedLen {
return fmt.Errorf("redis: got %d elements in the map, wanted %d", n, fixedLen)
}
return nil
}
// ReadMapLen reads the length of the map type.
// If responding to the array type (RespArray/RespSet/RespPush),
// it must be a multiple of 2 and return n/2.
// Other types will return an error.
func (r *Reader) ReadMapLen() (int, error) {
line, err := r.ReadLine()
if err != nil {
return 0, err
}
switch line[0] {
case RespMap:
return replyLen(line)
case RespArray, RespSet, RespPush:
// Some commands and RESP2 protocol may respond to array types.
n, err := replyLen(line)
if err != nil {
return 0, err
}
if n%2 != 0 {
return 0, fmt.Errorf("redis: the length of the array must be a multiple of 2, got: %d", n)
}
return n / 2, nil
default:
return 0, fmt.Errorf("redis: can't parse map reply: %.100q", line)
}
}
// DiscardNext read and discard the data represented by the next line.
func (r *Reader) DiscardNext() error {
line, err := r.readLine()
if err != nil {
return err
}
return r.Discard(line)
}
// Discard the data represented by line.
func (r *Reader) Discard(line []byte) (err error) {
if len(line) == 0 {
return errors.New("redis: invalid line")
}
switch line[0] {
case RespStatus, RespError, RespInt, RespNil, RespFloat, RespBool, RespBigInt:
return nil
}
n, err := replyLen(line)
if err != nil && err != Nil {
return err
}
switch line[0] {
case RespBlobError, RespString, RespVerbatim:
// +\r\n
_, err = r.rd.Discard(n + 2)
return err
case RespArray, RespSet, RespPush:
for i := 0; i < n; i++ {
if err = r.DiscardNext(); err != nil {
return err
}
}
return nil
case RespMap, RespAttr:
// Read key & value.
for i := 0; i < n*2; i++ {
if err = r.DiscardNext(); err != nil {
return err
}
}
return nil
}
return fmt.Errorf("redis: can't parse %.100q", line)
}
func replyLen(line []byte) (n int, err error) {
n, err = util.Atoi(line[1:])
if err != nil {
return 0, err
}
if n < -1 {
return 0, fmt.Errorf("redis: invalid reply: %q", line)
}
switch line[0] {
case RespString, RespVerbatim, RespBlobError,
RespArray, RespSet, RespPush, RespMap, RespAttr:
if n == -1 {
return 0, Nil
}
}
return n, nil
}
// IsNilReply detects redis.Nil of RESP2.
func IsNilReply(line []byte) bool {
return len(line) == 3 &&
(line[0] == RespString || line[0] == RespArray) &&
line[1] == '-' && line[2] == '1'
}

@ -0,0 +1,184 @@
package proto
import (
"encoding"
"fmt"
"net"
"reflect"
"time"
"github.com/go-redis/redis/v9/internal/util"
)
// Scan parses bytes `b` to `v` with appropriate type.
//nolint:gocyclo
func Scan(b []byte, v interface{}) error {
switch v := v.(type) {
case nil:
return fmt.Errorf("redis: Scan(nil)")
case *string:
*v = util.BytesToString(b)
return nil
case *[]byte:
*v = b
return nil
case *int:
var err error
*v, err = util.Atoi(b)
return err
case *int8:
n, err := util.ParseInt(b, 10, 8)
if err != nil {
return err
}
*v = int8(n)
return nil
case *int16:
n, err := util.ParseInt(b, 10, 16)
if err != nil {
return err
}
*v = int16(n)
return nil
case *int32:
n, err := util.ParseInt(b, 10, 32)
if err != nil {
return err
}
*v = int32(n)
return nil
case *int64:
n, err := util.ParseInt(b, 10, 64)
if err != nil {
return err
}
*v = n
return nil
case *uint:
n, err := util.ParseUint(b, 10, 64)
if err != nil {
return err
}
*v = uint(n)
return nil
case *uint8:
n, err := util.ParseUint(b, 10, 8)
if err != nil {
return err
}
*v = uint8(n)
return nil
case *uint16:
n, err := util.ParseUint(b, 10, 16)
if err != nil {
return err
}
*v = uint16(n)
return nil
case *uint32:
n, err := util.ParseUint(b, 10, 32)
if err != nil {
return err
}
*v = uint32(n)
return nil
case *uint64:
n, err := util.ParseUint(b, 10, 64)
if err != nil {
return err
}
*v = n
return nil
case *float32:
n, err := util.ParseFloat(b, 32)
if err != nil {
return err
}
*v = float32(n)
return err
case *float64:
var err error
*v, err = util.ParseFloat(b, 64)
return err
case *bool:
*v = len(b) == 1 && b[0] == '1'
return nil
case *time.Time:
var err error
*v, err = time.Parse(time.RFC3339Nano, util.BytesToString(b))
return err
case *time.Duration:
n, err := util.ParseInt(b, 10, 64)
if err != nil {
return err
}
*v = time.Duration(n)
return nil
case encoding.BinaryUnmarshaler:
return v.UnmarshalBinary(b)
case *net.IP:
*v = b
return nil
default:
return fmt.Errorf(
"redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", v)
}
}
func ScanSlice(data []string, slice interface{}) error {
v := reflect.ValueOf(slice)
if !v.IsValid() {
return fmt.Errorf("redis: ScanSlice(nil)")
}
if v.Kind() != reflect.Ptr {
return fmt.Errorf("redis: ScanSlice(non-pointer %T)", slice)
}
v = v.Elem()
if v.Kind() != reflect.Slice {
return fmt.Errorf("redis: ScanSlice(non-slice %T)", slice)
}
next := makeSliceNextElemFunc(v)
for i, s := range data {
elem := next()
if err := Scan([]byte(s), elem.Addr().Interface()); err != nil {
err = fmt.Errorf("redis: ScanSlice index=%d value=%q failed: %w", i, s, err)
return err
}
}
return nil
}
func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
elemType := v.Type().Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
elem := v.Index(v.Len() - 1)
if elem.IsNil() {
elem.Set(reflect.New(elemType))
}
return elem.Elem()
}
elem := reflect.New(elemType)
v.Set(reflect.Append(v, elem))
return elem.Elem()
}
}
zero := reflect.Zero(elemType)
return func() reflect.Value {
if v.Len() < v.Cap() {
v.Set(v.Slice(0, v.Len()+1))
return v.Index(v.Len() - 1)
}
v.Set(reflect.Append(v, zero))
return v.Index(v.Len() - 1)
}
}

@ -0,0 +1,158 @@
package proto
import (
"encoding"
"fmt"
"io"
"net"
"strconv"
"time"
"github.com/go-redis/redis/v9/internal/util"
)
type writer interface {
io.Writer
io.ByteWriter
// WriteString implement io.StringWriter.
WriteString(s string) (n int, err error)
}
type Writer struct {
writer
lenBuf []byte
numBuf []byte
}
func NewWriter(wr writer) *Writer {
return &Writer{
writer: wr,
lenBuf: make([]byte, 64),
numBuf: make([]byte, 64),
}
}
func (w *Writer) WriteArgs(args []interface{}) error {
if err := w.WriteByte(RespArray); err != nil {
return err
}
if err := w.writeLen(len(args)); err != nil {
return err
}
for _, arg := range args {
if err := w.WriteArg(arg); err != nil {
return err
}
}
return nil
}
func (w *Writer) writeLen(n int) error {
w.lenBuf = strconv.AppendUint(w.lenBuf[:0], uint64(n), 10)
w.lenBuf = append(w.lenBuf, '\r', '\n')
_, err := w.Write(w.lenBuf)
return err
}
func (w *Writer) WriteArg(v interface{}) error {
switch v := v.(type) {
case nil:
return w.string("")
case string:
return w.string(v)
case []byte:
return w.bytes(v)
case int:
return w.int(int64(v))
case int8:
return w.int(int64(v))
case int16:
return w.int(int64(v))
case int32:
return w.int(int64(v))
case int64:
return w.int(v)
case uint:
return w.uint(uint64(v))
case uint8:
return w.uint(uint64(v))
case uint16:
return w.uint(uint64(v))
case uint32:
return w.uint(uint64(v))
case uint64:
return w.uint(v)
case float32:
return w.float(float64(v))
case float64:
return w.float(v)
case bool:
if v {
return w.int(1)
}
return w.int(0)
case time.Time:
w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
return w.bytes(w.numBuf)
case time.Duration:
return w.int(v.Nanoseconds())
case encoding.BinaryMarshaler:
b, err := v.MarshalBinary()
if err != nil {
return err
}
return w.bytes(b)
case net.IP:
return w.bytes(v)
default:
return fmt.Errorf(
"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)
}
}
func (w *Writer) bytes(b []byte) error {
if err := w.WriteByte(RespString); err != nil {
return err
}
if err := w.writeLen(len(b)); err != nil {
return err
}
if _, err := w.Write(b); err != nil {
return err
}
return w.crlf()
}
func (w *Writer) string(s string) error {
return w.bytes(util.StringToBytes(s))
}
func (w *Writer) uint(n uint64) error {
w.numBuf = strconv.AppendUint(w.numBuf[:0], n, 10)
return w.bytes(w.numBuf)
}
func (w *Writer) int(n int64) error {
w.numBuf = strconv.AppendInt(w.numBuf[:0], n, 10)
return w.bytes(w.numBuf)
}
func (w *Writer) float(f float64) error {
w.numBuf = strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64)
return w.bytes(w.numBuf)
}
func (w *Writer) crlf() error {
if err := w.WriteByte('\r'); err != nil {
return err
}
return w.WriteByte('\n')
}

@ -0,0 +1,50 @@
package rand
import (
"math/rand"
"sync"
)
// Int returns a non-negative pseudo-random int.
func Int() int { return pseudo.Int() }
// Intn returns, as an int, a non-negative pseudo-random number in [0,n).
// It panics if n <= 0.
func Intn(n int) int { return pseudo.Intn(n) }
// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n).
// It panics if n <= 0.
func Int63n(n int64) int64 { return pseudo.Int63n(n) }
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n).
func Perm(n int) []int { return pseudo.Perm(n) }
// Seed uses the provided seed value to initialize the default Source to a
// deterministic state. If Seed is not called, the generator behaves as if
// seeded by Seed(1).
func Seed(n int64) { pseudo.Seed(n) }
var pseudo = rand.New(&source{src: rand.NewSource(1)})
type source struct {
src rand.Source
mu sync.Mutex
}
func (s *source) Int63() int64 {
s.mu.Lock()
n := s.src.Int63()
s.mu.Unlock()
return n
}
func (s *source) Seed(seed int64) {
s.mu.Lock()
s.src.Seed(seed)
s.mu.Unlock()
}
// Shuffle pseudo-randomizes the order of elements.
// n is the number of elements.
// swap swaps the elements with indexes i and j.
func Shuffle(n int, swap func(i, j int)) { pseudo.Shuffle(n, swap) }

@ -0,0 +1,12 @@
//go:build appengine
// +build appengine
package internal
func String(b []byte) string {
return string(b)
}
func Bytes(s string) []byte {
return []byte(s)
}

@ -0,0 +1,21 @@
//go:build !appengine
// +build !appengine
package internal
import "unsafe"
// String converts byte slice to string.
func String(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// Bytes converts string to byte slice.
func Bytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
}

@ -0,0 +1,46 @@
package internal
import (
"context"
"time"
"github.com/go-redis/redis/v9/internal/util"
)
func Sleep(ctx context.Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func ToLower(s string) string {
if isLower(s) {
return s
}
b := make([]byte, len(s))
for i := range b {
c := s[i]
if c >= 'A' && c <= 'Z' {
c += 'a' - 'A'
}
b[i] = c
}
return util.BytesToString(b)
}
func isLower(s string) bool {
for i := 0; i < len(s); i++ {
c := s[i]
if c >= 'A' && c <= 'Z' {
return false
}
}
return true
}

@ -0,0 +1,12 @@
//go:build appengine
// +build appengine
package util
func BytesToString(b []byte) string {
return string(b)
}
func StringToBytes(s string) []byte {
return []byte(s)
}

@ -0,0 +1,19 @@
package util
import "strconv"
func Atoi(b []byte) (int, error) {
return strconv.Atoi(BytesToString(b))
}
func ParseInt(b []byte, base int, bitSize int) (int64, error) {
return strconv.ParseInt(BytesToString(b), base, bitSize)
}
func ParseUint(b []byte, base int, bitSize int) (uint64, error) {
return strconv.ParseUint(BytesToString(b), base, bitSize)
}
func ParseFloat(b []byte, bitSize int) (float64, error) {
return strconv.ParseFloat(BytesToString(b), bitSize)
}

@ -0,0 +1,23 @@
//go:build !appengine
// +build !appengine
package util
import (
"unsafe"
)
// BytesToString converts byte slice to string.
func BytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// StringToBytes converts string to byte slice.
func StringToBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
}

@ -0,0 +1,77 @@
package redis
import (
"context"
"sync"
)
// ScanIterator is used to incrementally iterate over a collection of elements.
// It's safe for concurrent use by multiple goroutines.
type ScanIterator struct {
mu sync.Mutex // protects Scanner and pos
cmd *ScanCmd
pos int
}
// Err returns the last iterator error, if any.
func (it *ScanIterator) Err() error {
it.mu.Lock()
err := it.cmd.Err()
it.mu.Unlock()
return err
}
// Next advances the cursor and returns true if more values can be read.
func (it *ScanIterator) Next(ctx context.Context) bool {
it.mu.Lock()
defer it.mu.Unlock()
// Instantly return on errors.
if it.cmd.Err() != nil {
return false
}
// Advance cursor, check if we are still within range.
if it.pos < len(it.cmd.page) {
it.pos++
return true
}
for {
// Return if there is no more data to fetch.
if it.cmd.cursor == 0 {
return false
}
// Fetch next page.
switch it.cmd.args[0] {
case "scan", "qscan":
it.cmd.args[1] = it.cmd.cursor
default:
it.cmd.args[2] = it.cmd.cursor
}
err := it.cmd.process(ctx, it.cmd)
if err != nil {
return false
}
it.pos = 1
// Redis can occasionally return empty page.
if len(it.cmd.page) > 0 {
return true
}
}
}
// Val returns the key/field at the current cursor position.
func (it *ScanIterator) Val() string {
var v string
it.mu.Lock()
if it.cmd.Err() == nil && it.pos > 0 && it.pos <= len(it.cmd.page) {
v = it.cmd.page[it.pos-1]
}
it.mu.Unlock()
return v
}

@ -0,0 +1,435 @@
package redis
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"runtime"
"sort"
"strconv"
"strings"
"time"
"github.com/go-redis/redis/v9/internal/pool"
)
// Limiter is the interface of a rate limiter or a circuit breaker.
type Limiter interface {
// Allow returns nil if operation is allowed or an error otherwise.
// If operation is allowed client must ReportResult of the operation
// whether it is a success or a failure.
Allow() error
// ReportResult reports the result of the previously allowed operation.
// nil indicates a success, non-nil error usually indicates a failure.
ReportResult(result error)
}
// Options keeps the settings to setup redis connection.
type Options struct {
// The network type, either tcp or unix.
// Default is tcp.
Network string
// host:port address.
Addr string
// Dialer creates new network connection and has priority over
// Network and Addr options.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Hook that is called when new connection is established.
OnConnect func(ctx context.Context, cn *Conn) error
// Use the specified Username to authenticate the current connection
// with one of the connections defined in the ACL list when connecting
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
Username string
// Optional password. Must match the password specified in the
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
// or the User Password when connecting to a Redis 6.0 instance, or greater,
// that is using the Redis ACL system.
Password string
// CredentialsProvider allows the username and password to be updated
// before reconnecting. It should return the current username and password.
CredentialsProvider func() (username string, password string)
// Database to be selected after connecting to the server.
DB int
// Maximum number of retries before giving up.
// Default is 3 retries; -1 (not 0) disables retries.
MaxRetries int
// Minimum backoff between each retry.
// Default is 8 milliseconds; -1 disables backoff.
MinRetryBackoff time.Duration
// Maximum backoff between each retry.
// Default is 512 milliseconds; -1 disables backoff.
MaxRetryBackoff time.Duration
// Dial timeout for establishing new connections.
// Default is 5 seconds.
DialTimeout time.Duration
// Timeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking. Use value -1 for no timeout and 0 for default.
// Default is 3 seconds.
ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands will fail
// with a timeout instead of blocking.
// Default is ReadTimeout.
WriteTimeout time.Duration
// Type of connection pool.
// true for FIFO pool, false for LIFO pool.
// Note that fifo has higher overhead compared to lifo.
PoolFIFO bool
// Maximum number of socket connections.
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
PoolSize int
// Minimum number of idle connections which is useful when establishing
// new connection is slow.
MinIdleConns int
// Connection age at which client retires (closes) the connection.
// Default is to not close aged connections.
MaxConnAge time.Duration
// Amount of time client waits for connection if all connections
// are busy before returning an error.
// Default is ReadTimeout + 1 second.
PoolTimeout time.Duration
// Amount of time after which client closes idle connections.
// Should be less than server's timeout.
// Default is 5 minutes. -1 disables idle timeout check.
IdleTimeout time.Duration
// Frequency of idle checks made by idle connections reaper.
// Default is 1 minute. -1 disables idle connections reaper,
// but idle connections are still discarded by the client
// if IdleTimeout is set.
IdleCheckFrequency time.Duration
// Enables read only queries on slave nodes.
readOnly bool
// TLS Config to use. When set TLS will be negotiated.
TLSConfig *tls.Config
// Limiter interface used to implemented circuit breaker or rate limiter.
Limiter Limiter
}
func (opt *Options) init() {
if opt.Addr == "" {
opt.Addr = "localhost:6379"
}
if opt.Network == "" {
if strings.HasPrefix(opt.Addr, "/") {
opt.Network = "unix"
} else {
opt.Network = "tcp"
}
}
if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second
}
if opt.Dialer == nil {
opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{
Timeout: opt.DialTimeout,
KeepAlive: 5 * time.Minute,
}
if opt.TLSConfig == nil {
return netDialer.DialContext(ctx, network, addr)
}
return tls.DialWithDialer(netDialer, network, addr, opt.TLSConfig)
}
}
if opt.PoolSize == 0 {
opt.PoolSize = 10 * runtime.GOMAXPROCS(0)
}
switch opt.ReadTimeout {
case -1:
opt.ReadTimeout = 0
case 0:
opt.ReadTimeout = 3 * time.Second
}
switch opt.WriteTimeout {
case -1:
opt.WriteTimeout = 0
case 0:
opt.WriteTimeout = opt.ReadTimeout
}
if opt.PoolTimeout == 0 {
opt.PoolTimeout = opt.ReadTimeout + time.Second
}
if opt.IdleTimeout == 0 {
opt.IdleTimeout = 5 * time.Minute
}
if opt.IdleCheckFrequency == 0 {
opt.IdleCheckFrequency = time.Minute
}
if opt.MaxRetries == -1 {
opt.MaxRetries = 0
} else if opt.MaxRetries == 0 {
opt.MaxRetries = 3
}
switch opt.MinRetryBackoff {
case -1:
opt.MinRetryBackoff = 0
case 0:
opt.MinRetryBackoff = 8 * time.Millisecond
}
switch opt.MaxRetryBackoff {
case -1:
opt.MaxRetryBackoff = 0
case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond
}
}
func (opt *Options) clone() *Options {
clone := *opt
return &clone
}
// ParseURL parses an URL into Options that can be used to connect to Redis.
// Scheme is required.
// There are two connection types: by tcp socket and by unix socket.
// Tcp connection:
// redis://<user>:<password>@<host>:<port>/<db_number>
// Unix connection:
// unix://<user>:<password>@</path/to/redis.sock>?db=<db_number>
// Most Option fields can be set using query parameters, with the following restrictions:
// - field names are mapped using snake-case conversion: to set MaxRetries, use max_retries
// - only scalar type fields are supported (bool, int, time.Duration)
// - for time.Duration fields, values must be a valid input for time.ParseDuration();
// additionally a plain integer as value (i.e. without unit) is intepreted as seconds
// - to disable a duration field, use value less than or equal to 0; to use the default
// value, leave the value blank or remove the parameter
// - only the last value is interpreted if a parameter is given multiple times
// - fields "network", "addr", "username" and "password" can only be set using other
// URL attributes (scheme, host, userinfo, resp.), query paremeters using these
// names will be treated as unknown parameters
// - unknown parameter names will result in an error
// Examples:
// redis://user:password@localhost:6789/3?dial_timeout=3&db=1&read_timeout=6s&max_retries=2
// is equivalent to:
// &Options{
// Network: "tcp",
// Addr: "localhost:6789",
// DB: 1, // path "/3" was overridden by "&db=1"
// DialTimeout: 3 * time.Second, // no time unit = seconds
// ReadTimeout: 6 * time.Second,
// MaxRetries: 2,
// }
func ParseURL(redisURL string) (*Options, error) {
u, err := url.Parse(redisURL)
if err != nil {
return nil, err
}
switch u.Scheme {
case "redis", "rediss":
return setupTCPConn(u)
case "unix":
return setupUnixConn(u)
default:
return nil, fmt.Errorf("redis: invalid URL scheme: %s", u.Scheme)
}
}
func setupTCPConn(u *url.URL) (*Options, error) {
o := &Options{Network: "tcp"}
o.Username, o.Password = getUserPassword(u)
h, p, err := net.SplitHostPort(u.Host)
if err != nil {
h = u.Host
}
if h == "" {
h = "localhost"
}
if p == "" {
p = "6379"
}
o.Addr = net.JoinHostPort(h, p)
f := strings.FieldsFunc(u.Path, func(r rune) bool {
return r == '/'
})
switch len(f) {
case 0:
o.DB = 0
case 1:
if o.DB, err = strconv.Atoi(f[0]); err != nil {
return nil, fmt.Errorf("redis: invalid database number: %q", f[0])
}
default:
return nil, fmt.Errorf("redis: invalid URL path: %s", u.Path)
}
if u.Scheme == "rediss" {
o.TLSConfig = &tls.Config{
ServerName: h,
MinVersion: tls.VersionTLS12,
}
}
return setupConnParams(u, o)
}
func setupUnixConn(u *url.URL) (*Options, error) {
o := &Options{
Network: "unix",
}
if strings.TrimSpace(u.Path) == "" { // path is required with unix connection
return nil, errors.New("redis: empty unix socket path")
}
o.Addr = u.Path
o.Username, o.Password = getUserPassword(u)
return setupConnParams(u, o)
}
type queryOptions struct {
q url.Values
err error
}
func (o *queryOptions) string(name string) string {
vs := o.q[name]
if len(vs) == 0 {
return ""
}
delete(o.q, name) // enable detection of unknown parameters
return vs[len(vs)-1]
}
func (o *queryOptions) int(name string) int {
s := o.string(name)
if s == "" {
return 0
}
i, err := strconv.Atoi(s)
if err == nil {
return i
}
if o.err == nil {
o.err = fmt.Errorf("redis: invalid %s number: %s", name, err)
}
return 0
}
func (o *queryOptions) duration(name string) time.Duration {
s := o.string(name)
if s == "" {
return 0
}
// try plain number first
if i, err := strconv.Atoi(s); err == nil {
if i <= 0 {
// disable timeouts
return -1
}
return time.Duration(i) * time.Second
}
dur, err := time.ParseDuration(s)
if err == nil {
return dur
}
if o.err == nil {
o.err = fmt.Errorf("redis: invalid %s duration: %w", name, err)
}
return 0
}
func (o *queryOptions) bool(name string) bool {
switch s := o.string(name); s {
case "true", "1":
return true
case "false", "0", "":
return false
default:
if o.err == nil {
o.err = fmt.Errorf("redis: invalid %s boolean: expected true/false/1/0 or an empty string, got %q", name, s)
}
return false
}
}
func (o *queryOptions) remaining() []string {
if len(o.q) == 0 {
return nil
}
keys := make([]string, 0, len(o.q))
for k := range o.q {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
// setupConnParams converts query parameters in u to option value in o.
func setupConnParams(u *url.URL, o *Options) (*Options, error) {
q := queryOptions{q: u.Query()}
// compat: a future major release may use q.int("db")
if tmp := q.string("db"); tmp != "" {
db, err := strconv.Atoi(tmp)
if err != nil {
return nil, fmt.Errorf("redis: invalid database number: %w", err)
}
o.DB = db
}
o.MaxRetries = q.int("max_retries")
o.MinRetryBackoff = q.duration("min_retry_backoff")
o.MaxRetryBackoff = q.duration("max_retry_backoff")
o.DialTimeout = q.duration("dial_timeout")
o.ReadTimeout = q.duration("read_timeout")
o.WriteTimeout = q.duration("write_timeout")
o.PoolFIFO = q.bool("pool_fifo")
o.PoolSize = q.int("pool_size")
o.MinIdleConns = q.int("min_idle_conns")
o.MaxConnAge = q.duration("max_conn_age")
o.PoolTimeout = q.duration("pool_timeout")
o.IdleTimeout = q.duration("idle_timeout")
o.IdleCheckFrequency = q.duration("idle_check_frequency")
if q.err != nil {
return nil, q.err
}
// any parameters left?
if r := q.remaining(); len(r) > 0 {
return nil, fmt.Errorf("redis: unexpected option: %s", strings.Join(r, ", "))
}
return o, nil
}
func getUserPassword(u *url.URL) (string, string) {
var user, password string
if u.User != nil {
user = u.User.Username()
if p, ok := u.User.Password(); ok {
password = p
}
}
return user, password
}
func newConnPool(opt *Options) *pool.ConnPool {
return pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return opt.Dialer(ctx, opt.Network, opt.Addr)
},
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
MinIdleConns: opt.MinIdleConns,
MaxConnAge: opt.MaxConnAge,
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout,
IdleCheckFrequency: opt.IdleCheckFrequency,
})
}

@ -0,0 +1,8 @@
{
"name": "redis",
"version": "9.0.0-beta.1",
"main": "index.js",
"repository": "git@github.com:go-redis/redis.git",
"author": "Vladimir Mihailenco <vladimir.webdev@gmail.com>",
"license": "BSD-2-clause"
}

@ -0,0 +1,119 @@
package redis
import (
"context"
"sync"
)
type pipelineExecer func(context.Context, []Cmder) error
// Pipeliner is an mechanism to realise Redis Pipeline technique.
//
// Pipelining is a technique to extremely speed up processing by packing
// operations to batches, send them at once to Redis and read a replies in a
// singe step.
// See https://redis.io/topics/pipelining
//
// Pay attention, that Pipeline is not a transaction, so you can get unexpected
// results in case of big pipelines and small read/write timeouts.
// Redis client has retransmission logic in case of timeouts, pipeline
// can be retransmitted and commands can be executed more then once.
// To avoid this: it is good idea to use reasonable bigger read/write timeouts
// depends of your batch size and/or use TxPipeline.
type Pipeliner interface {
StatefulCmdable
Len() int
Do(ctx context.Context, args ...interface{}) *Cmd
Process(ctx context.Context, cmd Cmder) error
Discard()
Exec(ctx context.Context) ([]Cmder, error)
}
var _ Pipeliner = (*Pipeline)(nil)
// Pipeline implements pipelining as described in
// http://redis.io/topics/pipelining. It's safe for concurrent use
// by multiple goroutines.
type Pipeline struct {
cmdable
statefulCmdable
ctx context.Context
exec pipelineExecer
mu sync.Mutex
cmds []Cmder
}
func (c *Pipeline) init() {
c.cmdable = c.Process
c.statefulCmdable = c.Process
}
// Len returns the number of queued commands.
func (c *Pipeline) Len() int {
c.mu.Lock()
ln := len(c.cmds)
c.mu.Unlock()
return ln
}
// Do queues the custom command for later execution.
func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
// Process queues the cmd for later execution.
func (c *Pipeline) Process(ctx context.Context, cmd Cmder) error {
c.mu.Lock()
c.cmds = append(c.cmds, cmd)
c.mu.Unlock()
return nil
}
// Discard resets the pipeline and discards queued commands.
func (c *Pipeline) Discard() {
c.mu.Lock()
c.cmds = c.cmds[:0]
c.mu.Unlock()
}
// Exec executes all previously queued commands using one
// client-server roundtrip.
//
// Exec always returns list of commands and error of the first failed
// command if any.
func (c *Pipeline) Exec(ctx context.Context) ([]Cmder, error) {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.cmds) == 0 {
return nil, nil
}
cmds := c.cmds
c.cmds = nil
return cmds, c.exec(ctx, cmds)
}
func (c *Pipeline) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
if err := fn(c); err != nil {
return nil, err
}
return c.Exec(ctx)
}
func (c *Pipeline) Pipeline() Pipeliner {
return c
}
func (c *Pipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipelined(ctx, fn)
}
func (c *Pipeline) TxPipeline() Pipeliner {
return c
}

@ -0,0 +1,668 @@
package redis
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/go-redis/redis/v9/internal"
"github.com/go-redis/redis/v9/internal/pool"
"github.com/go-redis/redis/v9/internal/proto"
)
// PubSub implements Pub/Sub commands as described in
// http://redis.io/topics/pubsub. Message receiving is NOT safe
// for concurrent use by multiple goroutines.
//
// PubSub automatically reconnects to Redis Server and resubscribes
// to the channels in case of network errors.
type PubSub struct {
opt *Options
newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
closeConn func(*pool.Conn) error
mu sync.Mutex
cn *pool.Conn
channels map[string]struct{}
patterns map[string]struct{}
closed bool
exit chan struct{}
cmd *Cmd
chOnce sync.Once
msgCh *channel
allCh *channel
}
func (c *PubSub) init() {
c.exit = make(chan struct{})
}
func (c *PubSub) String() string {
channels := mapKeys(c.channels)
channels = append(channels, mapKeys(c.patterns)...)
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
}
func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
c.mu.Lock()
cn, err := c.conn(ctx, nil)
c.mu.Unlock()
return cn, err
}
func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
if c.closed {
return nil, pool.ErrClosed
}
if c.cn != nil {
return c.cn, nil
}
channels := mapKeys(c.channels)
channels = append(channels, newChannels...)
cn, err := c.newConn(ctx, channels)
if err != nil {
return nil, err
}
if err := c.resubscribe(ctx, cn); err != nil {
_ = c.closeConn(cn)
return nil, err
}
c.cn = cn
return cn, nil
}
func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
})
}
func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
var firstErr error
if len(c.channels) > 0 {
firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
}
if len(c.patterns) > 0 {
err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
if err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
func mapKeys(m map[string]struct{}) []string {
s := make([]string, len(m))
i := 0
for k := range m {
s[i] = k
i++
}
return s
}
func (c *PubSub) _subscribe(
ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
) error {
args := make([]interface{}, 0, 1+len(channels))
args = append(args, redisCmd)
for _, channel := range channels {
args = append(args, channel)
}
cmd := NewSliceCmd(ctx, args...)
return c.writeCmd(ctx, cn, cmd)
}
func (c *PubSub) releaseConnWithLock(
ctx context.Context,
cn *pool.Conn,
err error,
allowTimeout bool,
) {
c.mu.Lock()
c.releaseConn(ctx, cn, err, allowTimeout)
c.mu.Unlock()
}
func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
if c.cn != cn {
return
}
if isBadConn(err, allowTimeout, c.opt.Addr) {
c.reconnect(ctx, err)
}
}
func (c *PubSub) reconnect(ctx context.Context, reason error) {
_ = c.closeTheCn(reason)
_, _ = c.conn(ctx, nil)
}
func (c *PubSub) closeTheCn(reason error) error {
if c.cn == nil {
return nil
}
if !c.closed {
internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
}
err := c.closeConn(c.cn)
c.cn = nil
return err
}
func (c *PubSub) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return pool.ErrClosed
}
c.closed = true
close(c.exit)
return c.closeTheCn(pool.ErrClosed)
}
// Subscribe the client to the specified channels. It returns
// empty subscription if there are no channels.
func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
err := c.subscribe(ctx, "subscribe", channels...)
if c.channels == nil {
c.channels = make(map[string]struct{})
}
for _, s := range channels {
c.channels[s] = struct{}{}
}
return err
}
// PSubscribe the client to the given patterns. It returns
// empty subscription if there are no patterns.
func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
err := c.subscribe(ctx, "psubscribe", patterns...)
if c.patterns == nil {
c.patterns = make(map[string]struct{})
}
for _, s := range patterns {
c.patterns[s] = struct{}{}
}
return err
}
// Unsubscribe the client from the given channels, or from all of
// them if none is given.
func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
for _, channel := range channels {
delete(c.channels, channel)
}
err := c.subscribe(ctx, "unsubscribe", channels...)
return err
}
// PUnsubscribe the client from the given patterns, or from all of
// them if none is given.
func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
for _, pattern := range patterns {
delete(c.patterns, pattern)
}
err := c.subscribe(ctx, "punsubscribe", patterns...)
return err
}
func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
cn, err := c.conn(ctx, channels)
if err != nil {
return err
}
err = c._subscribe(ctx, cn, redisCmd, channels)
c.releaseConn(ctx, cn, err, false)
return err
}
func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
args := []interface{}{"ping"}
if len(payload) == 1 {
args = append(args, payload[0])
}
cmd := NewCmd(ctx, args...)
c.mu.Lock()
defer c.mu.Unlock()
cn, err := c.conn(ctx, nil)
if err != nil {
return err
}
err = c.writeCmd(ctx, cn, cmd)
c.releaseConn(ctx, cn, err, false)
return err
}
// Subscription received after a successful subscription to channel.
type Subscription struct {
// Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
Kind string
// Channel name we have subscribed to.
Channel string
// Number of channels we are currently subscribed to.
Count int
}
func (m *Subscription) String() string {
return fmt.Sprintf("%s: %s", m.Kind, m.Channel)
}
// Message received as result of a PUBLISH command issued by another client.
type Message struct {
Channel string
Pattern string
Payload string
PayloadSlice []string
}
func (m *Message) String() string {
return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload)
}
// Pong received as result of a PING command issued by another client.
type Pong struct {
Payload string
}
func (p *Pong) String() string {
if p.Payload != "" {
return fmt.Sprintf("Pong<%s>", p.Payload)
}
return "Pong"
}
func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
switch reply := reply.(type) {
case string:
return &Pong{
Payload: reply,
}, nil
case []interface{}:
switch kind := reply[0].(string); kind {
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
// Can be nil in case of "unsubscribe".
channel, _ := reply[1].(string)
return &Subscription{
Kind: kind,
Channel: channel,
Count: int(reply[2].(int64)),
}, nil
case "message":
switch payload := reply[2].(type) {
case string:
return &Message{
Channel: reply[1].(string),
Payload: payload,
}, nil
case []interface{}:
ss := make([]string, len(payload))
for i, s := range payload {
ss[i] = s.(string)
}
return &Message{
Channel: reply[1].(string),
PayloadSlice: ss,
}, nil
default:
return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
}
case "pmessage":
return &Message{
Pattern: reply[1].(string),
Channel: reply[2].(string),
Payload: reply[3].(string),
}, nil
case "pong":
return &Pong{
Payload: reply[1].(string),
}, nil
default:
return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
}
default:
return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
}
}
// ReceiveTimeout acts like Receive but returns an error if message
// is not received in time. This is low-level API and in most cases
// Channel should be used instead.
func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
if c.cmd == nil {
c.cmd = NewCmd(ctx)
}
// Don't hold the lock to allow subscriptions and pings.
cn, err := c.connWithLock(ctx)
if err != nil {
return nil, err
}
err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
return c.cmd.readReply(rd)
})
c.releaseConnWithLock(ctx, cn, err, timeout > 0)
if err != nil {
return nil, err
}
return c.newMessage(c.cmd.Val())
}
// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
return c.ReceiveTimeout(ctx, 0)
}
// ReceiveMessage returns a Message or error ignoring Subscription and Pong
// messages. This is low-level API and in most cases Channel should be used
// instead.
func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
for {
msg, err := c.Receive(ctx)
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *Subscription:
// Ignore.
case *Pong:
// Ignore.
case *Message:
return msg, nil
default:
err := fmt.Errorf("redis: unknown message: %T", msg)
return nil, err
}
}
}
func (c *PubSub) getContext() context.Context {
if c.cmd != nil {
return c.cmd.ctx
}
return context.Background()
}
//------------------------------------------------------------------------------
// Channel returns a Go channel for concurrently receiving messages.
// The channel is closed together with the PubSub. If the Go channel
// is blocked full for 30 seconds the message is dropped.
// Receive* APIs can not be used after channel is created.
//
// go-redis periodically sends ping messages to test connection health
// and re-subscribes if ping can not not received for 30 seconds.
func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
c.chOnce.Do(func() {
c.msgCh = newChannel(c, opts...)
c.msgCh.initMsgChan()
})
if c.msgCh == nil {
err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
panic(err)
}
return c.msgCh.msgCh
}
// ChannelSize is like Channel, but creates a Go channel
// with specified buffer size.
//
// Deprecated: use Channel(WithChannelSize(size)), remove in v9.
func (c *PubSub) ChannelSize(size int) <-chan *Message {
return c.Channel(WithChannelSize(size))
}
// ChannelWithSubscriptions is like Channel, but message type can be either
// *Subscription or *Message. Subscription messages can be used to detect
// reconnections.
//
// ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interface{} {
c.chOnce.Do(func() {
c.allCh = newChannel(c, opts...)
c.allCh.initAllChan()
})
if c.allCh == nil {
err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
panic(err)
}
return c.allCh.allCh
}
type ChannelOption func(c *channel)
// WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
//
// The default is 100 messages.
func WithChannelSize(size int) ChannelOption {
return func(c *channel) {
c.chanSize = size
}
}
// WithChannelHealthCheckInterval specifies the health check interval.
// PubSub will ping Redis Server if it does not receive any messages within the interval.
// To disable health check, use zero interval.
//
// The default is 3 seconds.
func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
return func(c *channel) {
c.checkInterval = d
}
}
// WithChannelSendTimeout specifies the channel send timeout after which
// the message is dropped.
//
// The default is 60 seconds.
func WithChannelSendTimeout(d time.Duration) ChannelOption {
return func(c *channel) {
c.chanSendTimeout = d
}
}
type channel struct {
pubSub *PubSub
msgCh chan *Message
allCh chan interface{}
ping chan struct{}
chanSize int
chanSendTimeout time.Duration
checkInterval time.Duration
}
func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
c := &channel{
pubSub: pubSub,
chanSize: 100,
chanSendTimeout: time.Minute,
checkInterval: 3 * time.Second,
}
for _, opt := range opts {
opt(c)
}
if c.checkInterval > 0 {
c.initHealthCheck()
}
return c
}
func (c *channel) initHealthCheck() {
ctx := context.TODO()
c.ping = make(chan struct{}, 1)
go func() {
timer := time.NewTimer(time.Minute)
timer.Stop()
for {
timer.Reset(c.checkInterval)
select {
case <-c.ping:
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
if pingErr := c.pubSub.Ping(ctx); pingErr != nil {
c.pubSub.mu.Lock()
c.pubSub.reconnect(ctx, pingErr)
c.pubSub.mu.Unlock()
}
case <-c.pubSub.exit:
return
}
}
}()
}
// initMsgChan must be in sync with initAllChan.
func (c *channel) initMsgChan() {
ctx := context.TODO()
c.msgCh = make(chan *Message, c.chanSize)
go func() {
timer := time.NewTimer(time.Minute)
timer.Stop()
var errCount int
for {
msg, err := c.pubSub.Receive(ctx)
if err != nil {
if err == pool.ErrClosed {
close(c.msgCh)
return
}
if errCount > 0 {
time.Sleep(100 * time.Millisecond)
}
errCount++
continue
}
errCount = 0
// Any message is as good as a ping.
select {
case c.ping <- struct{}{}:
default:
}
switch msg := msg.(type) {
case *Subscription:
// Ignore.
case *Pong:
// Ignore.
case *Message:
timer.Reset(c.chanSendTimeout)
select {
case c.msgCh <- msg:
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
internal.Logger.Printf(
ctx, "redis: %s channel is full for %s (message is dropped)",
c, c.chanSendTimeout)
}
default:
internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
}
}
}()
}
// initAllChan must be in sync with initMsgChan.
func (c *channel) initAllChan() {
ctx := context.TODO()
c.allCh = make(chan interface{}, c.chanSize)
go func() {
timer := time.NewTimer(time.Minute)
timer.Stop()
var errCount int
for {
msg, err := c.pubSub.Receive(ctx)
if err != nil {
if err == pool.ErrClosed {
close(c.allCh)
return
}
if errCount > 0 {
time.Sleep(100 * time.Millisecond)
}
errCount++
continue
}
errCount = 0
// Any message is as good as a ping.
select {
case c.ping <- struct{}{}:
default:
}
switch msg := msg.(type) {
case *Pong:
// Ignore.
case *Subscription, *Message:
timer.Reset(c.chanSendTimeout)
select {
case c.allCh <- msg:
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
internal.Logger.Printf(
ctx, "redis: %s channel is full for %s (message is dropped)",
c, c.chanSendTimeout)
}
default:
internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
}
}
}()
}

@ -0,0 +1,779 @@
package redis
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/go-redis/redis/v9/internal"
"github.com/go-redis/redis/v9/internal/pool"
"github.com/go-redis/redis/v9/internal/proto"
)
// Nil reply returned by Redis when key does not exist.
const Nil = proto.Nil
// SetLogger set custom log
func SetLogger(logger internal.Logging) {
internal.Logger = logger
}
//------------------------------------------------------------------------------
type Hook interface {
BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
AfterProcess(ctx context.Context, cmd Cmder) error
BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
}
type hooks struct {
hooks []Hook
}
func (hs *hooks) lock() {
hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}
func (hs hooks) clone() hooks {
clone := hs
clone.lock()
return clone
}
func (hs *hooks) AddHook(hook Hook) {
hs.hooks = append(hs.hooks, hook)
}
func (hs hooks) process(
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error {
if len(hs.hooks) == 0 {
err := fn(ctx, cmd)
cmd.SetErr(err)
return err
}
var hookIndex int
var retErr error
for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
if retErr != nil {
cmd.SetErr(retErr)
}
}
if retErr == nil {
retErr = fn(ctx, cmd)
cmd.SetErr(retErr)
}
for hookIndex--; hookIndex >= 0; hookIndex-- {
if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
retErr = err
cmd.SetErr(retErr)
}
}
return retErr
}
func (hs hooks) processPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
if len(hs.hooks) == 0 {
err := fn(ctx, cmds)
return err
}
var hookIndex int
var retErr error
for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
if retErr != nil {
setCmdsErr(cmds, retErr)
}
}
if retErr == nil {
retErr = fn(ctx, cmds)
}
for hookIndex--; hookIndex >= 0; hookIndex-- {
if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
retErr = err
setCmdsErr(cmds, retErr)
}
}
return retErr
}
func (hs hooks) processTxPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
cmds = wrapMultiExec(ctx, cmds)
return hs.processPipeline(ctx, cmds, fn)
}
//------------------------------------------------------------------------------
type baseClient struct {
opt *Options
connPool pool.Pooler
onClose func() error // hook called when client is closed
}
func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
return &baseClient{
opt: opt,
connPool: connPool,
}
}
func (c *baseClient) clone() *baseClient {
clone := *c
return &clone
}
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
opt := c.opt.clone()
opt.ReadTimeout = timeout
opt.WriteTimeout = timeout
clone := c.clone()
clone.opt = opt
return clone
}
func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.NewConn(ctx)
if err != nil {
return nil, err
}
err = c.initConn(ctx, cn)
if err != nil {
_ = c.connPool.CloseConn(cn)
return nil, err
}
return cn, nil
}
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow()
if err != nil {
return nil, err
}
}
cn, err := c._getConn(ctx)
if err != nil {
if c.opt.Limiter != nil {
c.opt.Limiter.ReportResult(err)
}
return nil, err
}
return cn, nil
}
func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.Get(ctx)
if err != nil {
return nil, err
}
if cn.Inited {
return cn, nil
}
if err := c.initConn(ctx, cn); err != nil {
c.connPool.Remove(ctx, cn, err)
if err := errors.Unwrap(err); err != nil {
return nil, err
}
return nil, err
}
return cn, nil
}
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
return nil
}
cn.Inited = true
username, password := c.opt.Username, c.opt.Password
if c.opt.CredentialsProvider != nil {
username, password = c.opt.CredentialsProvider()
}
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(ctx, c.opt, connPool)
var auth bool
// For redis-server <6.0 that does not support the Hello command,
// we continue to provide services with RESP2.
if err := conn.Hello(ctx, 3, username, password, "").Err(); err == nil {
auth = true
} else if err.Error() != "ERR unknown command 'hello'" {
return err
}
_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
if !auth && password != "" {
if username != "" {
pipe.AuthACL(ctx, username, password)
} else {
pipe.Auth(ctx, password)
}
}
if c.opt.DB > 0 {
pipe.Select(ctx, c.opt.DB)
}
if c.opt.readOnly {
pipe.ReadOnly(ctx)
}
return nil
})
if err != nil {
return err
}
if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn)
}
return nil
}
func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
if c.opt.Limiter != nil {
c.opt.Limiter.ReportResult(err)
}
if isBadConn(err, false, c.opt.Addr) {
c.connPool.Remove(ctx, cn, err)
} else {
c.connPool.Put(ctx, cn)
}
}
func (c *baseClient) withConn(
ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
cn, err := c.getConn(ctx)
if err != nil {
return err
}
defer func() {
c.releaseConn(ctx, cn, err)
}()
done := ctx.Done() //nolint:ifshort
if done == nil {
err = fn(ctx, cn)
return err
}
errc := make(chan error, 1)
go func() { errc <- fn(ctx, cn) }()
select {
case <-done:
_ = cn.Close()
// Wait for the goroutine to finish and send something.
<-errc
err = ctx.Err()
return err
case err = <-errc:
return err
}
}
func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
var lastErr error
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
attempt := attempt
retry, err := c._process(ctx, cmd, attempt)
if err == nil || !retry {
return err
}
lastErr = err
}
return lastErr
}
func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) {
if attempt > 0 {
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return false, err
}
}
retryTimeout := uint32(1)
err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
})
if err != nil {
return err
}
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
if err != nil {
if cmd.readTimeout() == nil {
atomic.StoreUint32(&retryTimeout, 1)
}
return err
}
return nil
})
if err == nil {
return false, nil
}
retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
return retry, err
}
func (c *baseClient) retryBackoff(attempt int) time.Duration {
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}
func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
if timeout := cmd.readTimeout(); timeout != nil {
t := *timeout
if t == 0 {
return 0
}
return t + 10*time.Second
}
return c.opt.ReadTimeout
}
// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var firstErr error
if c.onClose != nil {
if err := c.onClose(); err != nil {
firstErr = err
}
}
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
return firstErr
}
func (c *baseClient) getAddr() string {
return c.opt.Addr
}
func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
}
func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
}
type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
func (c *baseClient) generalProcessPipeline(
ctx context.Context, cmds []Cmder, p pipelineProcessor,
) error {
err := c._generalProcessPipeline(ctx, cmds, p)
if err != nil {
setCmdsErr(cmds, err)
return err
}
return cmdsFirstErr(cmds)
}
func (c *baseClient) _generalProcessPipeline(
ctx context.Context, cmds []Cmder, p pipelineProcessor,
) error {
var lastErr error
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 {
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
}
var canRetry bool
lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
var err error
canRetry, err = p(ctx, cn, cmds)
return err
})
if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
return lastErr
}
}
return lastErr
}
func (c *baseClient) pipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
})
if err != nil {
return true, err
}
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return pipelineReadCmds(rd, cmds)
})
return true, err
}
func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
for _, cmd := range cmds {
err := cmd.readReply(rd)
cmd.SetErr(err)
if err != nil && !isRedisError(err) {
return err
}
}
return nil
}
func (c *baseClient) txPipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
})
if err != nil {
return true, err
}
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
statusCmd := cmds[0].(*StatusCmd)
// Trim multi and exec.
cmds = cmds[1 : len(cmds)-1]
err := txPipelineReadQueued(rd, statusCmd, cmds)
if err != nil {
return err
}
return pipelineReadCmds(rd, cmds)
})
return false, err
}
func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
if len(cmds) == 0 {
panic("not reached")
}
cmdCopy := make([]Cmder, len(cmds)+2)
cmdCopy[0] = NewStatusCmd(ctx, "multi")
copy(cmdCopy[1:], cmds)
cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec")
return cmdCopy
}
func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
// Parse +OK.
if err := statusCmd.readReply(rd); err != nil {
return err
}
// Parse +QUEUED.
for range cmds {
if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) {
return err
}
}
// Parse number of replies.
line, err := rd.ReadLine()
if err != nil {
if err == Nil {
err = TxFailedErr
}
return err
}
if line[0] != proto.RespArray {
return fmt.Errorf("redis: expected '*', but got line %q", line)
}
return nil
}
//------------------------------------------------------------------------------
// Client is a Redis client representing a pool of zero or more underlying connections.
// It's safe for concurrent use by multiple goroutines.
//
// Client creates and frees connections automatically; it also maintains a free pool
// of idle connections. You can control the pool size with Config.PoolSize option.
type Client struct {
*baseClient
cmdable
hooks
ctx context.Context
}
// NewClient returns a client to the Redis Server specified by Options.
func NewClient(opt *Options) *Client {
opt.init()
c := Client{
baseClient: newBaseClient(opt, newConnPool(opt)),
ctx: context.Background(),
}
c.cmdable = c.Process
return &c
}
func (c *Client) clone() *Client {
clone := *c
clone.cmdable = clone.Process
clone.hooks.lock()
return &clone
}
func (c *Client) WithTimeout(timeout time.Duration) *Client {
clone := c.clone()
clone.baseClient = c.baseClient.withTimeout(timeout)
return clone
}
func (c *Client) Context() context.Context {
return c.ctx
}
func (c *Client) WithContext(ctx context.Context) *Client {
if ctx == nil {
panic("nil context")
}
clone := c.clone()
clone.ctx = ctx
return clone
}
func (c *Client) Conn(ctx context.Context) *Conn {
return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
}
// Do creates a Cmd from the args and processes the cmd.
func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
func (c *Client) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.baseClient.process)
}
func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}
func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}
// Options returns read-only Options that were used to create the client.
func (c *Client) Options() *Options {
return c.opt
}
type PoolStats pool.Stats
// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
stats := c.connPool.Stats()
return (*PoolStats)(stats)
}
func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
func (c *Client) Pipeline() Pipeliner {
pipe := Pipeline{
ctx: c.ctx,
exec: c.processPipeline,
}
pipe.init()
return &pipe
}
func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *Client) TxPipeline() Pipeliner {
pipe := Pipeline{
ctx: c.ctx,
exec: c.processTxPipeline,
}
pipe.init()
return &pipe
}
func (c *Client) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
},
closeConn: c.connPool.CloseConn,
}
pubsub.init()
return pubsub
}
// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create empty subscription.
// Note that this method does not wait on a response from Redis, so the
// subscription may not be active immediately. To force the connection to wait,
// you may call the Receive() method on the returned *PubSub like so:
//
// sub := client.Subscribe(queryResp)
// iface, err := sub.Receive()
// if err != nil {
// // handle error
// }
//
// // Should be *Subscription, but others are possible if other actions have been
// // taken on sub since it was created.
// switch iface.(type) {
// case *Subscription:
// // subscribe succeeded
// case *Message:
// // received first message
// case *Pong:
// // pong received
// default:
// // handle error
// }
//
// ch := sub.Channel()
func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.Subscribe(ctx, channels...)
}
return pubsub
}
// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription.
func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.PSubscribe(ctx, channels...)
}
return pubsub
}
//------------------------------------------------------------------------------
type conn struct {
baseClient
cmdable
statefulCmdable
hooks // TODO: inherit hooks
}
// Conn represents a single Redis connection rather than a pool of connections.
// Prefer running commands from Client unless there is a specific need
// for a continuous single Redis connection.
type Conn struct {
*conn
ctx context.Context
}
func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
c := Conn{
conn: &conn{
baseClient: baseClient{
opt: opt,
connPool: connPool,
},
},
ctx: ctx,
}
c.cmdable = c.Process
c.statefulCmdable = c.Process
return &c
}
func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.baseClient.process)
}
func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}
func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}
func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
func (c *Conn) Pipeline() Pipeliner {
pipe := Pipeline{
ctx: c.ctx,
exec: c.processPipeline,
}
pipe.init()
return &pipe
}
func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *Conn) TxPipeline() Pipeliner {
pipe := Pipeline{
ctx: c.ctx,
exec: c.processTxPipeline,
}
pipe.init()
return &pipe
}

@ -0,0 +1,188 @@
package redis
import "time"
// NewCmdResult returns a Cmd initialised with val and err for testing.
func NewCmdResult(val interface{}, err error) *Cmd {
var cmd Cmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewSliceResult returns a SliceCmd initialised with val and err for testing.
func NewSliceResult(val []interface{}, err error) *SliceCmd {
var cmd SliceCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewStatusResult returns a StatusCmd initialised with val and err for testing.
func NewStatusResult(val string, err error) *StatusCmd {
var cmd StatusCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewIntResult returns an IntCmd initialised with val and err for testing.
func NewIntResult(val int64, err error) *IntCmd {
var cmd IntCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewDurationResult returns a DurationCmd initialised with val and err for testing.
func NewDurationResult(val time.Duration, err error) *DurationCmd {
var cmd DurationCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewBoolResult returns a BoolCmd initialised with val and err for testing.
func NewBoolResult(val bool, err error) *BoolCmd {
var cmd BoolCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewStringResult returns a StringCmd initialised with val and err for testing.
func NewStringResult(val string, err error) *StringCmd {
var cmd StringCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewFloatResult returns a FloatCmd initialised with val and err for testing.
func NewFloatResult(val float64, err error) *FloatCmd {
var cmd FloatCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewStringSliceResult returns a StringSliceCmd initialised with val and err for testing.
func NewStringSliceResult(val []string, err error) *StringSliceCmd {
var cmd StringSliceCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewBoolSliceResult returns a BoolSliceCmd initialised with val and err for testing.
func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd {
var cmd BoolSliceCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewStringStringMapResult returns a StringStringMapCmd initialised with val and err for testing.
func NewStringStringMapResult(val map[string]string, err error) *MapStringStringCmd {
var cmd MapStringStringCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewStringIntMapCmdResult returns a StringIntMapCmd initialised with val and err for testing.
func NewStringIntMapCmdResult(val map[string]int64, err error) *StringIntMapCmd {
var cmd StringIntMapCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewTimeCmdResult returns a TimeCmd initialised with val and err for testing.
func NewTimeCmdResult(val time.Time, err error) *TimeCmd {
var cmd TimeCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewZSliceCmdResult returns a ZSliceCmd initialised with val and err for testing.
func NewZSliceCmdResult(val []Z, err error) *ZSliceCmd {
var cmd ZSliceCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewZWithKeyCmdResult returns a NewZWithKeyCmd initialised with val and err for testing.
func NewZWithKeyCmdResult(val *ZWithKey, err error) *ZWithKeyCmd {
var cmd ZWithKeyCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewScanCmdResult returns a ScanCmd initialised with val and err for testing.
func NewScanCmdResult(keys []string, cursor uint64, err error) *ScanCmd {
var cmd ScanCmd
cmd.page = keys
cmd.cursor = cursor
cmd.SetErr(err)
return &cmd
}
// NewClusterSlotsCmdResult returns a ClusterSlotsCmd initialised with val and err for testing.
func NewClusterSlotsCmdResult(val []ClusterSlot, err error) *ClusterSlotsCmd {
var cmd ClusterSlotsCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewGeoLocationCmdResult returns a GeoLocationCmd initialised with val and err for testing.
func NewGeoLocationCmdResult(val []GeoLocation, err error) *GeoLocationCmd {
var cmd GeoLocationCmd
cmd.locations = val
cmd.SetErr(err)
return &cmd
}
// NewGeoPosCmdResult returns a GeoPosCmd initialised with val and err for testing.
func NewGeoPosCmdResult(val []*GeoPos, err error) *GeoPosCmd {
var cmd GeoPosCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewCommandsInfoCmdResult returns a CommandsInfoCmd initialised with val and err for testing.
func NewCommandsInfoCmdResult(val map[string]*CommandInfo, err error) *CommandsInfoCmd {
var cmd CommandsInfoCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewXMessageSliceCmdResult returns a XMessageSliceCmd initialised with val and err for testing.
func NewXMessageSliceCmdResult(val []XMessage, err error) *XMessageSliceCmd {
var cmd XMessageSliceCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewXStreamSliceCmdResult returns a XStreamSliceCmd initialised with val and err for testing.
func NewXStreamSliceCmdResult(val []XStream, err error) *XStreamSliceCmd {
var cmd XStreamSliceCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}
// NewXPendingResult returns a XPendingCmd initialised with val and err for testing.
func NewXPendingResult(val *XPending, err error) *XPendingCmd {
var cmd XPendingCmd
cmd.val = val
cmd.SetErr(err)
return &cmd
}

@ -0,0 +1,736 @@
package redis
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/cespare/xxhash/v2"
rendezvous "github.com/dgryski/go-rendezvous" //nolint
"github.com/go-redis/redis/v9/internal"
"github.com/go-redis/redis/v9/internal/hashtag"
"github.com/go-redis/redis/v9/internal/pool"
"github.com/go-redis/redis/v9/internal/rand"
)
var errRingShardsDown = errors.New("redis: all ring shards are down")
//------------------------------------------------------------------------------
type ConsistentHash interface {
Get(string) string
}
type rendezvousWrapper struct {
*rendezvous.Rendezvous
}
func (w rendezvousWrapper) Get(key string) string {
return w.Lookup(key)
}
func newRendezvous(shards []string) ConsistentHash {
return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)}
}
//------------------------------------------------------------------------------
// RingOptions are used to configure a ring client and should be
// passed to NewRing.
type RingOptions struct {
// Map of name => host:port addresses of ring shards.
Addrs map[string]string
// NewClient creates a shard client with provided name and options.
NewClient func(name string, opt *Options) *Client
// Frequency of PING commands sent to check shards availability.
// Shard is considered down after 3 subsequent failed checks.
HeartbeatFrequency time.Duration
// NewConsistentHash returns a consistent hash that is used
// to distribute keys across the shards.
//
// See https://medium.com/@dgryski/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8
// for consistent hashing algorithmic tradeoffs.
NewConsistentHash func(shards []string) ConsistentHash
// Following options are copied from Options struct.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
OnConnect func(ctx context.Context, cn *Conn) error
Username string
Password string
DB int
MaxRetries int
MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration
DialTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
PoolFIFO bool
PoolSize int
MinIdleConns int
MaxConnAge time.Duration
PoolTimeout time.Duration
IdleTimeout time.Duration
IdleCheckFrequency time.Duration
TLSConfig *tls.Config
Limiter Limiter
}
func (opt *RingOptions) init() {
if opt.NewClient == nil {
opt.NewClient = func(name string, opt *Options) *Client {
return NewClient(opt)
}
}
if opt.HeartbeatFrequency == 0 {
opt.HeartbeatFrequency = 500 * time.Millisecond
}
if opt.NewConsistentHash == nil {
opt.NewConsistentHash = newRendezvous
}
if opt.MaxRetries == -1 {
opt.MaxRetries = 0
} else if opt.MaxRetries == 0 {
opt.MaxRetries = 3
}
switch opt.MinRetryBackoff {
case -1:
opt.MinRetryBackoff = 0
case 0:
opt.MinRetryBackoff = 8 * time.Millisecond
}
switch opt.MaxRetryBackoff {
case -1:
opt.MaxRetryBackoff = 0
case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond
}
}
func (opt *RingOptions) clientOptions() *Options {
return &Options{
Dialer: opt.Dialer,
OnConnect: opt.OnConnect,
Username: opt.Username,
Password: opt.Password,
DB: opt.DB,
MaxRetries: -1,
DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
MinIdleConns: opt.MinIdleConns,
MaxConnAge: opt.MaxConnAge,
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout,
IdleCheckFrequency: opt.IdleCheckFrequency,
TLSConfig: opt.TLSConfig,
Limiter: opt.Limiter,
}
}
//------------------------------------------------------------------------------
type ringShard struct {
Client *Client
down int32
}
func newRingShard(opt *RingOptions, name, addr string) *ringShard {
clopt := opt.clientOptions()
clopt.Addr = addr
return &ringShard{
Client: opt.NewClient(name, clopt),
}
}
func (shard *ringShard) String() string {
var state string
if shard.IsUp() {
state = "up"
} else {
state = "down"
}
return fmt.Sprintf("%s is %s", shard.Client, state)
}
func (shard *ringShard) IsDown() bool {
const threshold = 3
return atomic.LoadInt32(&shard.down) >= threshold
}
func (shard *ringShard) IsUp() bool {
return !shard.IsDown()
}
// Vote votes to set shard state and returns true if state was changed.
func (shard *ringShard) Vote(up bool) bool {
if up {
changed := shard.IsDown()
atomic.StoreInt32(&shard.down, 0)
return changed
}
if shard.IsDown() {
return false
}
atomic.AddInt32(&shard.down, 1)
return shard.IsDown()
}
//------------------------------------------------------------------------------
type ringShards struct {
opt *RingOptions
mu sync.RWMutex
hash ConsistentHash
shards map[string]*ringShard // read only
list []*ringShard // read only
numShard int
closed bool
}
func newRingShards(opt *RingOptions) *ringShards {
shards := make(map[string]*ringShard, len(opt.Addrs))
list := make([]*ringShard, 0, len(shards))
for name, addr := range opt.Addrs {
shard := newRingShard(opt, name, addr)
shards[name] = shard
list = append(list, shard)
}
c := &ringShards{
opt: opt,
shards: shards,
list: list,
}
c.rebalance()
return c
}
func (c *ringShards) List() []*ringShard {
var list []*ringShard
c.mu.RLock()
if !c.closed {
list = c.list
}
c.mu.RUnlock()
return list
}
func (c *ringShards) Hash(key string) string {
key = hashtag.Key(key)
var hash string
c.mu.RLock()
if c.numShard > 0 {
hash = c.hash.Get(key)
}
c.mu.RUnlock()
return hash
}
func (c *ringShards) GetByKey(key string) (*ringShard, error) {
key = hashtag.Key(key)
c.mu.RLock()
if c.closed {
c.mu.RUnlock()
return nil, pool.ErrClosed
}
if c.numShard == 0 {
c.mu.RUnlock()
return nil, errRingShardsDown
}
hash := c.hash.Get(key)
if hash == "" {
c.mu.RUnlock()
return nil, errRingShardsDown
}
shard := c.shards[hash]
c.mu.RUnlock()
return shard, nil
}
func (c *ringShards) GetByName(shardName string) (*ringShard, error) {
if shardName == "" {
return c.Random()
}
c.mu.RLock()
shard := c.shards[shardName]
c.mu.RUnlock()
return shard, nil
}
func (c *ringShards) Random() (*ringShard, error) {
return c.GetByKey(strconv.Itoa(rand.Int()))
}
// Heartbeat monitors state of each shard in the ring.
func (c *ringShards) Heartbeat(frequency time.Duration) {
ticker := time.NewTicker(frequency)
defer ticker.Stop()
ctx := context.Background()
for range ticker.C {
var rebalance bool
for _, shard := range c.List() {
err := shard.Client.Ping(ctx).Err()
isUp := err == nil || err == pool.ErrPoolTimeout
if shard.Vote(isUp) {
internal.Logger.Printf(context.Background(), "ring shard state changed: %s", shard)
rebalance = true
}
}
if rebalance {
c.rebalance()
}
}
}
// rebalance removes dead shards from the Ring.
func (c *ringShards) rebalance() {
c.mu.RLock()
shards := c.shards
c.mu.RUnlock()
liveShards := make([]string, 0, len(shards))
for name, shard := range shards {
if shard.IsUp() {
liveShards = append(liveShards, name)
}
}
hash := c.opt.NewConsistentHash(liveShards)
c.mu.Lock()
c.hash = hash
c.numShard = len(liveShards)
c.mu.Unlock()
}
func (c *ringShards) Len() int {
c.mu.RLock()
l := c.numShard
c.mu.RUnlock()
return l
}
func (c *ringShards) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil
}
c.closed = true
var firstErr error
for _, shard := range c.shards {
if err := shard.Client.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
c.hash = nil
c.shards = nil
c.list = nil
return firstErr
}
//------------------------------------------------------------------------------
type ring struct {
opt *RingOptions
shards *ringShards
cmdsInfoCache *cmdsInfoCache //nolint:structcheck
}
// Ring is a Redis client that uses consistent hashing to distribute
// keys across multiple Redis servers (shards). It's safe for
// concurrent use by multiple goroutines.
//
// Ring monitors the state of each shard and removes dead shards from
// the ring. When a shard comes online it is added back to the ring. This
// gives you maximum availability and partition tolerance, but no
// consistency between different shards or even clients. Each client
// uses shards that are available to the client and does not do any
// coordination when shard state is changed.
//
// Ring should be used when you need multiple Redis servers for caching
// and can tolerate losing data when one of the servers dies.
// Otherwise you should use Redis Cluster.
type Ring struct {
*ring
cmdable
hooks
ctx context.Context
}
func NewRing(opt *RingOptions) *Ring {
opt.init()
ring := Ring{
ring: &ring{
opt: opt,
shards: newRingShards(opt),
},
ctx: context.Background(),
}
ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
ring.cmdable = ring.Process
go ring.shards.Heartbeat(opt.HeartbeatFrequency)
return &ring
}
func (c *Ring) Context() context.Context {
return c.ctx
}
func (c *Ring) WithContext(ctx context.Context) *Ring {
if ctx == nil {
panic("nil context")
}
clone := *c
clone.cmdable = clone.Process
clone.hooks.lock()
clone.ctx = ctx
return &clone
}
// Do creates a Cmd from the args and processes the cmd.
func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process)
}
// Options returns read-only Options that were used to create the client.
func (c *Ring) Options() *RingOptions {
return c.opt
}
func (c *Ring) retryBackoff(attempt int) time.Duration {
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}
// PoolStats returns accumulated connection pool stats.
func (c *Ring) PoolStats() *PoolStats {
shards := c.shards.List()
var acc PoolStats
for _, shard := range shards {
s := shard.Client.connPool.Stats()
acc.Hits += s.Hits
acc.Misses += s.Misses
acc.Timeouts += s.Timeouts
acc.TotalConns += s.TotalConns
acc.IdleConns += s.IdleConns
}
return &acc
}
// Len returns the current number of shards in the ring.
func (c *Ring) Len() int {
return c.shards.Len()
}
// Subscribe subscribes the client to the specified channels.
func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
if len(channels) == 0 {
panic("at least one channel is required")
}
shard, err := c.shards.GetByKey(channels[0])
if err != nil {
// TODO: return PubSub with sticky error
panic(err)
}
return shard.Client.Subscribe(ctx, channels...)
}
// PSubscribe subscribes the client to the given patterns.
func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
if len(channels) == 0 {
panic("at least one channel is required")
}
shard, err := c.shards.GetByKey(channels[0])
if err != nil {
// TODO: return PubSub with sticky error
panic(err)
}
return shard.Client.PSubscribe(ctx, channels...)
}
// ForEachShard concurrently calls the fn on each live shard in the ring.
// It returns the first error if any.
func (c *Ring) ForEachShard(
ctx context.Context,
fn func(ctx context.Context, client *Client) error,
) error {
shards := c.shards.List()
var wg sync.WaitGroup
errCh := make(chan error, 1)
for _, shard := range shards {
if shard.IsDown() {
continue
}
wg.Add(1)
go func(shard *ringShard) {
defer wg.Done()
err := fn(ctx, shard.Client)
if err != nil {
select {
case errCh <- err:
default:
}
}
}(shard)
}
wg.Wait()
select {
case err := <-errCh:
return err
default:
return nil
}
}
func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) {
shards := c.shards.List()
var firstErr error
for _, shard := range shards {
cmdsInfo, err := shard.Client.Command(ctx).Result()
if err == nil {
return cmdsInfo, nil
}
if firstErr == nil {
firstErr = err
}
}
if firstErr == nil {
return nil, errRingShardsDown
}
return nil, firstErr
}
func (c *Ring) cmdInfo(ctx context.Context, name string) *CommandInfo {
cmdsInfo, err := c.cmdsInfoCache.Get(ctx)
if err != nil {
return nil
}
info := cmdsInfo[name]
if info == nil {
internal.Logger.Printf(ctx, "info for cmd=%s not found", name)
}
return info
}
func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) {
cmdInfo := c.cmdInfo(ctx, cmd.Name())
pos := cmdFirstKeyPos(cmd, cmdInfo)
if pos == 0 {
return c.shards.Random()
}
firstKey := cmd.stringArg(pos)
return c.shards.GetByKey(firstKey)
}
func (c *Ring) process(ctx context.Context, cmd Cmder) error {
var lastErr error
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 {
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
return err
}
}
shard, err := c.cmdShard(ctx, cmd)
if err != nil {
return err
}
lastErr = shard.Client.Process(ctx, cmd)
if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) {
return lastErr
}
}
return lastErr
}
func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
func (c *Ring) Pipeline() Pipeliner {
pipe := Pipeline{
ctx: c.ctx,
exec: c.processPipeline,
}
pipe.init()
return &pipe
}
func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(ctx, cmds, false)
})
}
func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
func (c *Ring) TxPipeline() Pipeliner {
pipe := Pipeline{
ctx: c.ctx,
exec: c.processTxPipeline,
}
pipe.init()
return &pipe
}
func (c *Ring) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(ctx, cmds, true)
})
}
func (c *Ring) generalProcessPipeline(
ctx context.Context, cmds []Cmder, tx bool,
) error {
cmdsMap := make(map[string][]Cmder)
for _, cmd := range cmds {
cmdInfo := c.cmdInfo(ctx, cmd.Name())
hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo))
if hash != "" {
hash = c.shards.Hash(hash)
}
cmdsMap[hash] = append(cmdsMap[hash], cmd)
}
var wg sync.WaitGroup
for hash, cmds := range cmdsMap {
wg.Add(1)
go func(hash string, cmds []Cmder) {
defer wg.Done()
_ = c.processShardPipeline(ctx, hash, cmds, tx)
}(hash, cmds)
}
wg.Wait()
return cmdsFirstErr(cmds)
}
func (c *Ring) processShardPipeline(
ctx context.Context, hash string, cmds []Cmder, tx bool,
) error {
// TODO: retry?
shard, err := c.shards.GetByName(hash)
if err != nil {
setCmdsErr(cmds, err)
return err
}
if tx {
return shard.Client.processTxPipeline(ctx, cmds)
}
return shard.Client.processPipeline(ctx, cmds)
}
func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
if len(keys) == 0 {
return fmt.Errorf("redis: Watch requires at least one key")
}
var shards []*ringShard
for _, key := range keys {
if key != "" {
shard, err := c.shards.GetByKey(hashtag.Key(key))
if err != nil {
return err
}
shards = append(shards, shard)
}
}
if len(shards) == 0 {
return fmt.Errorf("redis: Watch requires at least one shard")
}
if len(shards) > 1 {
for _, shard := range shards[1:] {
if shard.Client != shards[0].Client {
err := fmt.Errorf("redis: Watch requires all keys to be in the same shard")
return err
}
}
}
return shards[0].Client.Watch(ctx, fn, keys...)
}
// Close closes the ring client, releasing any open resources.
//
// It is rare to Close a Ring, as the Ring is meant to be long-lived
// and shared between many goroutines.
func (c *Ring) Close() error {
return c.shards.Close()
}

@ -0,0 +1,65 @@
package redis
import (
"context"
"crypto/sha1"
"encoding/hex"
"io"
"strings"
)
type Scripter interface {
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *Cmd
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *Cmd
ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd
ScriptLoad(ctx context.Context, script string) *StringCmd
}
var (
_ Scripter = (*Client)(nil)
_ Scripter = (*Ring)(nil)
_ Scripter = (*ClusterClient)(nil)
)
type Script struct {
src, hash string
}
func NewScript(src string) *Script {
h := sha1.New()
_, _ = io.WriteString(h, src)
return &Script{
src: src,
hash: hex.EncodeToString(h.Sum(nil)),
}
}
func (s *Script) Hash() string {
return s.hash
}
func (s *Script) Load(ctx context.Context, c Scripter) *StringCmd {
return c.ScriptLoad(ctx, s.src)
}
func (s *Script) Exists(ctx context.Context, c Scripter) *BoolSliceCmd {
return c.ScriptExists(ctx, s.hash)
}
func (s *Script) Eval(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd {
return c.Eval(ctx, s.src, keys, args...)
}
func (s *Script) EvalSha(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd {
return c.EvalSha(ctx, s.hash, keys, args...)
}
// Run optimistically uses EVALSHA to run the script. If script does not exist
// it is retried using EVAL.
func (s *Script) Run(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd {
r := s.EvalSha(ctx, c, keys, args...)
if err := r.Err(); err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") {
return s.Eval(ctx, c, keys, args...)
}
return r
}

@ -0,0 +1,777 @@
package redis
import (
"context"
"crypto/tls"
"errors"
"net"
"strings"
"sync"
"time"
"github.com/go-redis/redis/v9/internal"
"github.com/go-redis/redis/v9/internal/pool"
"github.com/go-redis/redis/v9/internal/rand"
)
//------------------------------------------------------------------------------
// FailoverOptions are used to configure a failover client and should
// be passed to NewFailoverClient.
type FailoverOptions struct {
// The master name.
MasterName string
// A seed list of host:port addresses of sentinel nodes.
SentinelAddrs []string
// If specified with SentinelPassword, enables ACL-based authentication (via
// AUTH <user> <pass>).
SentinelUsername string
// Sentinel password from "requirepass <password>" (if enabled) in Sentinel
// configuration, or, if SentinelUsername is also supplied, used for ACL-based
// authentication.
SentinelPassword string
// Allows routing read-only commands to the closest master or replica node.
// This option only works with NewFailoverClusterClient.
RouteByLatency bool
// Allows routing read-only commands to the random master or replica node.
// This option only works with NewFailoverClusterClient.
RouteRandomly bool
// Route all commands to replica read-only nodes.
ReplicaOnly bool
// Use replicas disconnected with master when cannot get connected replicas
// Now, this option only works in RandomReplicaAddr function.
UseDisconnectedReplicas bool
// Following options are copied from Options struct.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
OnConnect func(ctx context.Context, cn *Conn) error
Username string
Password string
DB int
MaxRetries int
MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration
DialTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
PoolFIFO bool
PoolSize int
MinIdleConns int
MaxConnAge time.Duration
PoolTimeout time.Duration
IdleTimeout time.Duration
IdleCheckFrequency time.Duration
TLSConfig *tls.Config
}
func (opt *FailoverOptions) clientOptions() *Options {
return &Options{
Addr: "FailoverClient",
Dialer: opt.Dialer,
OnConnect: opt.OnConnect,
DB: opt.DB,
Username: opt.Username,
Password: opt.Password,
MaxRetries: opt.MaxRetries,
MinRetryBackoff: opt.MinRetryBackoff,
MaxRetryBackoff: opt.MaxRetryBackoff,
DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout,
IdleCheckFrequency: opt.IdleCheckFrequency,
MinIdleConns: opt.MinIdleConns,
MaxConnAge: opt.MaxConnAge,
TLSConfig: opt.TLSConfig,
}
}
func (opt *FailoverOptions) sentinelOptions(addr string) *Options {
return &Options{
Addr: addr,
Dialer: opt.Dialer,
OnConnect: opt.OnConnect,
DB: 0,
Username: opt.SentinelUsername,
Password: opt.SentinelPassword,
MaxRetries: opt.MaxRetries,
MinRetryBackoff: opt.MinRetryBackoff,
MaxRetryBackoff: opt.MaxRetryBackoff,
DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout,
IdleCheckFrequency: opt.IdleCheckFrequency,
MinIdleConns: opt.MinIdleConns,
MaxConnAge: opt.MaxConnAge,
TLSConfig: opt.TLSConfig,
}
}
func (opt *FailoverOptions) clusterOptions() *ClusterOptions {
return &ClusterOptions{
Dialer: opt.Dialer,
OnConnect: opt.OnConnect,
Username: opt.Username,
Password: opt.Password,
MaxRedirects: opt.MaxRetries,
RouteByLatency: opt.RouteByLatency,
RouteRandomly: opt.RouteRandomly,
MinRetryBackoff: opt.MinRetryBackoff,
MaxRetryBackoff: opt.MaxRetryBackoff,
DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout,
IdleCheckFrequency: opt.IdleCheckFrequency,
MinIdleConns: opt.MinIdleConns,
MaxConnAge: opt.MaxConnAge,
TLSConfig: opt.TLSConfig,
}
}
// NewFailoverClient returns a Redis client that uses Redis Sentinel
// for automatic failover. It's safe for concurrent use by multiple
// goroutines.
func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
if failoverOpt.RouteByLatency {
panic("to route commands by latency, use NewFailoverClusterClient")
}
if failoverOpt.RouteRandomly {
panic("to route commands randomly, use NewFailoverClusterClient")
}
sentinelAddrs := make([]string, len(failoverOpt.SentinelAddrs))
copy(sentinelAddrs, failoverOpt.SentinelAddrs)
rand.Shuffle(len(sentinelAddrs), func(i, j int) {
sentinelAddrs[i], sentinelAddrs[j] = sentinelAddrs[j], sentinelAddrs[i]
})
failover := &sentinelFailover{
opt: failoverOpt,
sentinelAddrs: sentinelAddrs,
}
opt := failoverOpt.clientOptions()
opt.Dialer = masterReplicaDialer(failover)
opt.init()
connPool := newConnPool(opt)
failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
}
failover.mu.Unlock()
c := Client{
baseClient: newBaseClient(opt, connPool),
ctx: context.Background(),
}
c.cmdable = c.Process
c.onClose = failover.Close
return &c
}
func masterReplicaDialer(
failover *sentinelFailover,
) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, _ string) (net.Conn, error) {
var addr string
var err error
if failover.opt.ReplicaOnly {
addr, err = failover.RandomReplicaAddr(ctx)
} else {
addr, err = failover.MasterAddr(ctx)
if err == nil {
failover.trySwitchMaster(ctx, addr)
}
}
if err != nil {
return nil, err
}
if failover.opt.Dialer != nil {
return failover.opt.Dialer(ctx, network, addr)
}
netDialer := &net.Dialer{
Timeout: failover.opt.DialTimeout,
KeepAlive: 5 * time.Minute,
}
if failover.opt.TLSConfig == nil {
return netDialer.DialContext(ctx, network, addr)
}
return tls.DialWithDialer(netDialer, network, addr, failover.opt.TLSConfig)
}
}
//------------------------------------------------------------------------------
// SentinelClient is a client for a Redis Sentinel.
type SentinelClient struct {
*baseClient
hooks
ctx context.Context
}
func NewSentinelClient(opt *Options) *SentinelClient {
opt.init()
c := &SentinelClient{
baseClient: &baseClient{
opt: opt,
connPool: newConnPool(opt),
},
ctx: context.Background(),
}
return c
}
func (c *SentinelClient) Context() context.Context {
return c.ctx
}
func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
if ctx == nil {
panic("nil context")
}
clone := *c
clone.ctx = ctx
return &clone
}
func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.baseClient.process)
}
func (c *SentinelClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
},
closeConn: c.connPool.CloseConn,
}
pubsub.init()
return pubsub
}
// Ping is used to test if a connection is still alive, or to
// measure latency.
func (c *SentinelClient) Ping(ctx context.Context) *StringCmd {
cmd := NewStringCmd(ctx, "ping")
_ = c.Process(ctx, cmd)
return cmd
}
// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create empty subscription.
func (c *SentinelClient) Subscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.Subscribe(ctx, channels...)
}
return pubsub
}
// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription.
func (c *SentinelClient) PSubscribe(ctx context.Context, channels ...string) *PubSub {
pubsub := c.pubSub()
if len(channels) > 0 {
_ = pubsub.PSubscribe(ctx, channels...)
}
return pubsub
}
func (c *SentinelClient) GetMasterAddrByName(ctx context.Context, name string) *StringSliceCmd {
cmd := NewStringSliceCmd(ctx, "sentinel", "get-master-addr-by-name", name)
_ = c.Process(ctx, cmd)
return cmd
}
func (c *SentinelClient) Sentinels(ctx context.Context, name string) *MapStringStringSliceCmd {
cmd := NewMapStringStringSliceCmd(ctx, "sentinel", "sentinels", name)
_ = c.Process(ctx, cmd)
return cmd
}
// Failover forces a failover as if the master was not reachable, and without
// asking for agreement to other Sentinels.
func (c *SentinelClient) Failover(ctx context.Context, name string) *StatusCmd {
cmd := NewStatusCmd(ctx, "sentinel", "failover", name)
_ = c.Process(ctx, cmd)
return cmd
}
// Reset resets all the masters with matching name. The pattern argument is a
// glob-style pattern. The reset process clears any previous state in a master
// (including a failover in progress), and removes every replica and sentinel
// already discovered and associated with the master.
func (c *SentinelClient) Reset(ctx context.Context, pattern string) *IntCmd {
cmd := NewIntCmd(ctx, "sentinel", "reset", pattern)
_ = c.Process(ctx, cmd)
return cmd
}
// FlushConfig forces Sentinel to rewrite its configuration on disk, including
// the current Sentinel state.
func (c *SentinelClient) FlushConfig(ctx context.Context) *StatusCmd {
cmd := NewStatusCmd(ctx, "sentinel", "flushconfig")
_ = c.Process(ctx, cmd)
return cmd
}
// Master shows the state and info of the specified master.
func (c *SentinelClient) Master(ctx context.Context, name string) *MapStringStringCmd {
cmd := NewMapStringStringCmd(ctx, "sentinel", "master", name)
_ = c.Process(ctx, cmd)
return cmd
}
// Masters shows a list of monitored masters and their state.
func (c *SentinelClient) Masters(ctx context.Context) *SliceCmd {
cmd := NewSliceCmd(ctx, "sentinel", "masters")
_ = c.Process(ctx, cmd)
return cmd
}
// Replicas shows a list of replicas for the specified master and their state.
func (c *SentinelClient) Replicas(ctx context.Context, name string) *MapStringStringSliceCmd {
cmd := NewMapStringStringSliceCmd(ctx, "sentinel", "replicas", name)
_ = c.Process(ctx, cmd)
return cmd
}
// CkQuorum checks if the current Sentinel configuration is able to reach the
// quorum needed to failover a master, and the majority needed to authorize the
// failover. This command should be used in monitoring systems to check if a
// Sentinel deployment is ok.
func (c *SentinelClient) CkQuorum(ctx context.Context, name string) *StringCmd {
cmd := NewStringCmd(ctx, "sentinel", "ckquorum", name)
_ = c.Process(ctx, cmd)
return cmd
}
// Monitor tells the Sentinel to start monitoring a new master with the specified
// name, ip, port, and quorum.
func (c *SentinelClient) Monitor(ctx context.Context, name, ip, port, quorum string) *StringCmd {
cmd := NewStringCmd(ctx, "sentinel", "monitor", name, ip, port, quorum)
_ = c.Process(ctx, cmd)
return cmd
}
// Set is used in order to change configuration parameters of a specific master.
func (c *SentinelClient) Set(ctx context.Context, name, option, value string) *StringCmd {
cmd := NewStringCmd(ctx, "sentinel", "set", name, option, value)
_ = c.Process(ctx, cmd)
return cmd
}
// Remove is used in order to remove the specified master: the master will no
// longer be monitored, and will totally be removed from the internal state of
// the Sentinel.
func (c *SentinelClient) Remove(ctx context.Context, name string) *StringCmd {
cmd := NewStringCmd(ctx, "sentinel", "remove", name)
_ = c.Process(ctx, cmd)
return cmd
}
//------------------------------------------------------------------------------
type sentinelFailover struct {
opt *FailoverOptions
sentinelAddrs []string
onFailover func(ctx context.Context, addr string)
onUpdate func(ctx context.Context)
mu sync.RWMutex
_masterAddr string
sentinel *SentinelClient
pubsub *PubSub
}
func (c *sentinelFailover) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.sentinel != nil {
return c.closeSentinel()
}
return nil
}
func (c *sentinelFailover) closeSentinel() error {
firstErr := c.pubsub.Close()
c.pubsub = nil
err := c.sentinel.Close()
if err != nil && firstErr == nil {
firstErr = err
}
c.sentinel = nil
return firstErr
}
func (c *sentinelFailover) RandomReplicaAddr(ctx context.Context) (string, error) {
if c.opt == nil {
return "", errors.New("opt is nil")
}
addresses, err := c.replicaAddrs(ctx, false)
if err != nil {
return "", err
}
if len(addresses) == 0 && c.opt.UseDisconnectedReplicas {
addresses, err = c.replicaAddrs(ctx, true)
if err != nil {
return "", err
}
}
if len(addresses) == 0 {
return c.MasterAddr(ctx)
}
return addresses[rand.Intn(len(addresses))], nil
}
func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) {
c.mu.RLock()
sentinel := c.sentinel
c.mu.RUnlock()
if sentinel != nil {
addr := c.getMasterAddr(ctx, sentinel)
if addr != "" {
return addr, nil
}
}
c.mu.Lock()
defer c.mu.Unlock()
if c.sentinel != nil {
addr := c.getMasterAddr(ctx, c.sentinel)
if addr != "" {
return addr, nil
}
_ = c.closeSentinel()
}
for i, sentinelAddr := range c.sentinelAddrs {
sentinel := NewSentinelClient(c.opt.sentinelOptions(sentinelAddr))
masterAddr, err := sentinel.GetMasterAddrByName(ctx, c.opt.MasterName).Result()
if err != nil {
internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName master=%q failed: %s",
c.opt.MasterName, err)
_ = sentinel.Close()
continue
}
// Push working sentinel to the top.
c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0]
c.setSentinel(ctx, sentinel)
addr := net.JoinHostPort(masterAddr[0], masterAddr[1])
return addr, nil
}
return "", errors.New("redis: all sentinels specified in configuration are unreachable")
}
func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected bool) ([]string, error) {
c.mu.RLock()
sentinel := c.sentinel
c.mu.RUnlock()
if sentinel != nil {
addrs := c.getReplicaAddrs(ctx, sentinel)
if len(addrs) > 0 {
return addrs, nil
}
}
c.mu.Lock()
defer c.mu.Unlock()
if c.sentinel != nil {
addrs := c.getReplicaAddrs(ctx, c.sentinel)
if len(addrs) > 0 {
return addrs, nil
}
_ = c.closeSentinel()
}
var sentinelReachable bool
for i, sentinelAddr := range c.sentinelAddrs {
sentinel := NewSentinelClient(c.opt.sentinelOptions(sentinelAddr))
replicas, err := sentinel.Replicas(ctx, c.opt.MasterName).Result()
if err != nil {
internal.Logger.Printf(ctx, "sentinel: Replicas master=%q failed: %s",
c.opt.MasterName, err)
_ = sentinel.Close()
continue
}
sentinelReachable = true
addrs := parseReplicaAddrs(replicas, useDisconnected)
if len(addrs) == 0 {
continue
}
// Push working sentinel to the top.
c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0]
c.setSentinel(ctx, sentinel)
return addrs, nil
}
if sentinelReachable {
return []string{}, nil
}
return []string{}, errors.New("redis: all sentinels specified in configuration are unreachable")
}
func (c *sentinelFailover) getMasterAddr(ctx context.Context, sentinel *SentinelClient) string {
addr, err := sentinel.GetMasterAddrByName(ctx, c.opt.MasterName).Result()
if err != nil {
internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s",
c.opt.MasterName, err)
return ""
}
return net.JoinHostPort(addr[0], addr[1])
}
func (c *sentinelFailover) getReplicaAddrs(ctx context.Context, sentinel *SentinelClient) []string {
addrs, err := sentinel.Replicas(ctx, c.opt.MasterName).Result()
if err != nil {
internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s",
c.opt.MasterName, err)
return []string{}
}
return parseReplicaAddrs(addrs, false)
}
func parseReplicaAddrs(addrs []map[string]string, keepDisconnected bool) []string {
nodes := make([]string, 0, len(addrs))
for _, node := range addrs {
isDown := false
if flags, ok := node["flags"]; ok {
for _, flag := range strings.Split(flags, ",") {
switch flag {
case "s_down", "o_down":
isDown = true
case "disconnected":
if !keepDisconnected {
isDown = true
}
}
}
}
if !isDown && node["ip"] != "" && node["port"] != "" {
nodes = append(nodes, net.JoinHostPort(node["ip"], node["port"]))
}
}
return nodes
}
func (c *sentinelFailover) trySwitchMaster(ctx context.Context, addr string) {
c.mu.RLock()
currentAddr := c._masterAddr //nolint:ifshort
c.mu.RUnlock()
if addr == currentAddr {
return
}
c.mu.Lock()
defer c.mu.Unlock()
if addr == c._masterAddr {
return
}
c._masterAddr = addr
internal.Logger.Printf(ctx, "sentinel: new master=%q addr=%q",
c.opt.MasterName, addr)
if c.onFailover != nil {
c.onFailover(ctx, addr)
}
}
func (c *sentinelFailover) setSentinel(ctx context.Context, sentinel *SentinelClient) {
if c.sentinel != nil {
panic("not reached")
}
c.sentinel = sentinel
c.discoverSentinels(ctx)
c.pubsub = sentinel.Subscribe(ctx, "+switch-master", "+replica-reconf-done")
go c.listen(c.pubsub)
}
func (c *sentinelFailover) discoverSentinels(ctx context.Context) {
sentinels, err := c.sentinel.Sentinels(ctx, c.opt.MasterName).Result()
if err != nil {
internal.Logger.Printf(ctx, "sentinel: Sentinels master=%q failed: %s", c.opt.MasterName, err)
return
}
for _, sentinel := range sentinels {
ip, ok := sentinel["ip"]
if !ok {
continue
}
port, ok := sentinel["port"]
if !ok {
continue
}
if ip != "" && port != "" {
sentinelAddr := net.JoinHostPort(ip, port)
if !contains(c.sentinelAddrs, sentinelAddr) {
internal.Logger.Printf(ctx, "sentinel: discovered new sentinel=%q for master=%q",
sentinelAddr, c.opt.MasterName)
c.sentinelAddrs = append(c.sentinelAddrs, sentinelAddr)
}
}
}
}
func (c *sentinelFailover) listen(pubsub *PubSub) {
ctx := context.TODO()
if c.onUpdate != nil {
c.onUpdate(ctx)
}
ch := pubsub.Channel()
for msg := range ch {
if msg.Channel == "+switch-master" {
parts := strings.Split(msg.Payload, " ")
if parts[0] != c.opt.MasterName {
internal.Logger.Printf(pubsub.getContext(), "sentinel: ignore addr for master=%q", parts[0])
continue
}
addr := net.JoinHostPort(parts[3], parts[4])
c.trySwitchMaster(pubsub.getContext(), addr)
}
if c.onUpdate != nil {
c.onUpdate(ctx)
}
}
}
func contains(slice []string, str string) bool {
for _, s := range slice {
if s == str {
return true
}
}
return false
}
//------------------------------------------------------------------------------
// NewFailoverClusterClient returns a client that supports routing read-only commands
// to a replica node.
func NewFailoverClusterClient(failoverOpt *FailoverOptions) *ClusterClient {
sentinelAddrs := make([]string, len(failoverOpt.SentinelAddrs))
copy(sentinelAddrs, failoverOpt.SentinelAddrs)
failover := &sentinelFailover{
opt: failoverOpt,
sentinelAddrs: sentinelAddrs,
}
opt := failoverOpt.clusterOptions()
opt.ClusterSlots = func(ctx context.Context) ([]ClusterSlot, error) {
masterAddr, err := failover.MasterAddr(ctx)
if err != nil {
return nil, err
}
nodes := []ClusterNode{{
Addr: masterAddr,
}}
replicaAddrs, err := failover.replicaAddrs(ctx, false)
if err != nil {
return nil, err
}
for _, replicaAddr := range replicaAddrs {
nodes = append(nodes, ClusterNode{
Addr: replicaAddr,
})
}
slots := []ClusterSlot{
{
Start: 0,
End: 16383,
Nodes: nodes,
},
}
return slots, nil
}
c := NewClusterClient(opt)
failover.mu.Lock()
failover.onUpdate = func(ctx context.Context) {
c.ReloadState(ctx)
}
failover.mu.Unlock()
return c
}

@ -0,0 +1,149 @@
package redis
import (
"context"
"github.com/go-redis/redis/v9/internal/pool"
"github.com/go-redis/redis/v9/internal/proto"
)
// TxFailedErr transaction redis failed.
const TxFailedErr = proto.RedisError("redis: transaction failed")
// Tx implements Redis transactions as described in
// http://redis.io/topics/transactions. It's NOT safe for concurrent use
// by multiple goroutines, because Exec resets list of watched keys.
//
// If you don't need WATCH, use Pipeline instead.
type Tx struct {
baseClient
cmdable
statefulCmdable
hooks
ctx context.Context
}
func (c *Client) newTx(ctx context.Context) *Tx {
tx := Tx{
baseClient: baseClient{
opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool),
},
hooks: c.hooks.clone(),
ctx: ctx,
}
tx.init()
return &tx
}
func (c *Tx) init() {
c.cmdable = c.Process
c.statefulCmdable = c.Process
}
func (c *Tx) Context() context.Context {
return c.ctx
}
func (c *Tx) WithContext(ctx context.Context) *Tx {
if ctx == nil {
panic("nil context")
}
clone := *c
clone.init()
clone.hooks.lock()
clone.ctx = ctx
return &clone
}
func (c *Tx) Process(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.baseClient.process)
}
// Watch prepares a transaction and marks the keys to be watched
// for conditional execution if there are any keys.
//
// The transaction is automatically closed when fn exits.
func (c *Client) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
tx := c.newTx(ctx)
defer tx.Close(ctx)
if len(keys) > 0 {
if err := tx.Watch(ctx, keys...).Err(); err != nil {
return err
}
}
return fn(tx)
}
// Close closes the transaction, releasing any open resources.
func (c *Tx) Close(ctx context.Context) error {
_ = c.Unwatch(ctx).Err()
return c.baseClient.Close()
}
// Watch marks the keys to be watched for conditional execution
// of a transaction.
func (c *Tx) Watch(ctx context.Context, keys ...string) *StatusCmd {
args := make([]interface{}, 1+len(keys))
args[0] = "watch"
for i, key := range keys {
args[1+i] = key
}
cmd := NewStatusCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
// Unwatch flushes all the previously watched keys for a transaction.
func (c *Tx) Unwatch(ctx context.Context, keys ...string) *StatusCmd {
args := make([]interface{}, 1+len(keys))
args[0] = "unwatch"
for i, key := range keys {
args[1+i] = key
}
cmd := NewStatusCmd(ctx, args...)
_ = c.Process(ctx, cmd)
return cmd
}
// Pipeline creates a pipeline. Usually it is more convenient to use Pipelined.
func (c *Tx) Pipeline() Pipeliner {
pipe := Pipeline{
ctx: c.ctx,
exec: func(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
},
}
pipe.init()
return &pipe
}
// Pipelined executes commands queued in the fn outside of the transaction.
// Use TxPipelined if you need transactional behavior.
func (c *Tx) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
// TxPipelined executes commands queued in the fn in the transaction.
//
// When using WATCH, EXEC will execute commands only if the watched keys
// were not modified, allowing for a check-and-set mechanism.
//
// Exec always returns list of commands. If transaction fails
// TxFailedErr is returned. Otherwise Exec returns an error of the first
// failed command or nil.
func (c *Tx) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline creates a pipeline. Usually it is more convenient to use TxPipelined.
func (c *Tx) TxPipeline() Pipeliner {
pipe := Pipeline{
ctx: c.ctx,
exec: func(ctx context.Context, cmds []Cmder) error {
return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
},
}
pipe.init()
return &pipe
}

@ -0,0 +1,215 @@
package redis
import (
"context"
"crypto/tls"
"net"
"time"
)
// UniversalOptions information is required by UniversalClient to establish
// connections.
type UniversalOptions struct {
// Either a single address or a seed list of host:port addresses
// of cluster/sentinel nodes.
Addrs []string
// Database to be selected after connecting to the server.
// Only single-node and failover clients.
DB int
// Common options.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
OnConnect func(ctx context.Context, cn *Conn) error
Username string
Password string
SentinelUsername string
SentinelPassword string
MaxRetries int
MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration
DialTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
PoolFIFO bool
PoolSize int
MinIdleConns int
MaxConnAge time.Duration
PoolTimeout time.Duration
IdleTimeout time.Duration
IdleCheckFrequency time.Duration
TLSConfig *tls.Config
// Only cluster clients.
MaxRedirects int
ReadOnly bool
RouteByLatency bool
RouteRandomly bool
// The sentinel master name.
// Only failover clients.
MasterName string
}
// Cluster returns cluster options created from the universal options.
func (o *UniversalOptions) Cluster() *ClusterOptions {
if len(o.Addrs) == 0 {
o.Addrs = []string{"127.0.0.1:6379"}
}
return &ClusterOptions{
Addrs: o.Addrs,
Dialer: o.Dialer,
OnConnect: o.OnConnect,
Username: o.Username,
Password: o.Password,
MaxRedirects: o.MaxRedirects,
ReadOnly: o.ReadOnly,
RouteByLatency: o.RouteByLatency,
RouteRandomly: o.RouteRandomly,
MaxRetries: o.MaxRetries,
MinRetryBackoff: o.MinRetryBackoff,
MaxRetryBackoff: o.MaxRetryBackoff,
DialTimeout: o.DialTimeout,
ReadTimeout: o.ReadTimeout,
WriteTimeout: o.WriteTimeout,
PoolFIFO: o.PoolFIFO,
PoolSize: o.PoolSize,
MinIdleConns: o.MinIdleConns,
MaxConnAge: o.MaxConnAge,
PoolTimeout: o.PoolTimeout,
IdleTimeout: o.IdleTimeout,
IdleCheckFrequency: o.IdleCheckFrequency,
TLSConfig: o.TLSConfig,
}
}
// Failover returns failover options created from the universal options.
func (o *UniversalOptions) Failover() *FailoverOptions {
if len(o.Addrs) == 0 {
o.Addrs = []string{"127.0.0.1:26379"}
}
return &FailoverOptions{
SentinelAddrs: o.Addrs,
MasterName: o.MasterName,
Dialer: o.Dialer,
OnConnect: o.OnConnect,
DB: o.DB,
Username: o.Username,
Password: o.Password,
SentinelUsername: o.SentinelUsername,
SentinelPassword: o.SentinelPassword,
MaxRetries: o.MaxRetries,
MinRetryBackoff: o.MinRetryBackoff,
MaxRetryBackoff: o.MaxRetryBackoff,
DialTimeout: o.DialTimeout,
ReadTimeout: o.ReadTimeout,
WriteTimeout: o.WriteTimeout,
PoolFIFO: o.PoolFIFO,
PoolSize: o.PoolSize,
MinIdleConns: o.MinIdleConns,
MaxConnAge: o.MaxConnAge,
PoolTimeout: o.PoolTimeout,
IdleTimeout: o.IdleTimeout,
IdleCheckFrequency: o.IdleCheckFrequency,
TLSConfig: o.TLSConfig,
}
}
// Simple returns basic options created from the universal options.
func (o *UniversalOptions) Simple() *Options {
addr := "127.0.0.1:6379"
if len(o.Addrs) > 0 {
addr = o.Addrs[0]
}
return &Options{
Addr: addr,
Dialer: o.Dialer,
OnConnect: o.OnConnect,
DB: o.DB,
Username: o.Username,
Password: o.Password,
MaxRetries: o.MaxRetries,
MinRetryBackoff: o.MinRetryBackoff,
MaxRetryBackoff: o.MaxRetryBackoff,
DialTimeout: o.DialTimeout,
ReadTimeout: o.ReadTimeout,
WriteTimeout: o.WriteTimeout,
PoolFIFO: o.PoolFIFO,
PoolSize: o.PoolSize,
MinIdleConns: o.MinIdleConns,
MaxConnAge: o.MaxConnAge,
PoolTimeout: o.PoolTimeout,
IdleTimeout: o.IdleTimeout,
IdleCheckFrequency: o.IdleCheckFrequency,
TLSConfig: o.TLSConfig,
}
}
// --------------------------------------------------------------------
// UniversalClient is an abstract client which - based on the provided options -
// represents either a ClusterClient, a FailoverClient, or a single-node Client.
// This can be useful for testing cluster-specific applications locally or having different
// clients in different environments.
type UniversalClient interface {
Cmdable
Context() context.Context
AddHook(Hook)
Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error
Do(ctx context.Context, args ...interface{}) *Cmd
Process(ctx context.Context, cmd Cmder) error
Subscribe(ctx context.Context, channels ...string) *PubSub
PSubscribe(ctx context.Context, channels ...string) *PubSub
Close() error
PoolStats() *PoolStats
}
var (
_ UniversalClient = (*Client)(nil)
_ UniversalClient = (*ClusterClient)(nil)
_ UniversalClient = (*Ring)(nil)
)
// NewUniversalClient returns a new multi client. The type of the returned client depends
// on the following conditions:
//
// 1. If the MasterName option is specified, a sentinel-backed FailoverClient is returned.
// 2. if the number of Addrs is two or more, a ClusterClient is returned.
// 3. Otherwise, a single-node Client is returned.
func NewUniversalClient(opts *UniversalOptions) UniversalClient {
if opts.MasterName != "" {
return NewFailoverClient(opts.Failover())
} else if len(opts.Addrs) > 1 {
return NewClusterClient(opts.Cluster())
}
return NewClient(opts.Simple())
}

@ -0,0 +1,6 @@
package redis
// Version is the current release version.
func Version() string {
return "9.0.0-beta.1"
}

@ -0,0 +1,9 @@
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
Icon?
ehthumbs.db
Thumbs.db
.idea

@ -0,0 +1,117 @@
# This is the official list of Go-MySQL-Driver authors for copyright purposes.
# If you are submitting a patch, please add your name or the name of the
# organization which holds the copyright to this list in alphabetical order.
# Names should be added to this file as
# Name <email address>
# The email address is not required for organizations.
# Please keep the list sorted.
# Individual Persons
Aaron Hopkins <go-sql-driver at die.net>
Achille Roussel <achille.roussel at gmail.com>
Alex Snast <alexsn at fb.com>
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
Andrew Reid <andrew.reid at tixtrack.com>
Animesh Ray <mail.rayanimesh at gmail.com>
Arne Hormann <arnehormann at gmail.com>
Ariel Mashraki <ariel at mashraki.co.il>
Asta Xie <xiemengjun at gmail.com>
Bulat Gaifullin <gaifullinbf at gmail.com>
Caine Jette <jette at alum.mit.edu>
Carlos Nieto <jose.carlos at menteslibres.net>
Chris Moos <chris at tech9computers.com>
Craig Wilson <craiggwilson at gmail.com>
Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com>
Erwan Martin <hello at erwan.io>
Evan Shaw <evan at vendhq.com>
Frederick Mayle <frederickmayle at gmail.com>
Gustavo Kristic <gkristic at gmail.com>
Hajime Nakagami <nakagami at gmail.com>
Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Hirotaka Yamamoto <ymmt2005 at gmail.com>
Huyiguang <hyg at webterren.com>
ICHINOSE Shogo <shogo82148 at gmail.com>
Ilia Cimpoes <ichimpoesh at gmail.com>
INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com>
Jeff Hodges <jeff at somethingsimilar.com>
Jeffrey Charles <jeffreycharles at gmail.com>
Jerome Meyer <jxmeyer at gmail.com>
Jiajia Zhong <zhong2plus at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Joshua Prunier <joshua.prunier at gmail.com>
Julien Lefevre <julien.lefevr at gmail.com>
Julien Schmidt <go-sql-driver at julienschmidt.com>
Justin Li <jli at j-li.net>
Justin Nuß <nuss.justin at gmail.com>
Kamil Dziedzic <kamil at klecza.pl>
Kei Kamikawa <x00.x7f.x86 at gmail.com>
Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lennart Rudolph <lrudolph at hmc.edu>
Leonardo YongUk Kim <dalinaum at gmail.com>
Linh Tran Tuan <linhduonggnu at gmail.com>
Lion Yang <lion at aosc.xyz>
Luca Looz <luca.looz92 at gmail.com>
Lucas Liu <extrafliu at gmail.com>
Luke Scott <luke at webconnex.com>
Maciej Zimnoch <maciej.zimnoch at codilime.com>
Michael Woolnough <michael.woolnough at gmail.com>
Nathanial Murphy <nathanial.murphy at gmail.com>
Nicola Peduzzi <thenikso at gmail.com>
Olivier Mengué <dolmen at cpan.org>
oscarzhao <oscarzhaosl at gmail.com>
Paul Bonser <misterpib at gmail.com>
Peter Schultz <peter.schultz at classmarkets.com>
Rebecca Chin <rchin at pivotal.io>
Reed Allman <rdallman10 at gmail.com>
Richard Wilkes <wilkes at me.com>
Robert Russell <robert at rrbrussell.com>
Runrioter Wung <runrioter at gmail.com>
Sho Iizuka <sho.i518 at gmail.com>
Sho Ikeda <suicaicoca at gmail.com>
Shuode Li <elemount at qq.com>
Simon J Mudd <sjmudd at pobox.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk>
Tan Jinhua <312841925 at qq.com>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tim Ruffles <timruffles at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Vladimir Kovpak <cn007b at gmail.com>
Vladyslav Zhelezniak <zhvladi at gmail.com>
Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
Xuehong Chan <chanxuehong at gmail.com>
Zhenye Xie <xiezhenye at gmail.com>
Zhixin Wen <john.wenzhixin at gmail.com>
# Organizations
Barracuda Networks, Inc.
Counting Ltd.
DigitalOcean Inc.
Facebook Inc.
GitHub Inc.
Google Inc.
InfoSum Ltd.
Keybase Inc.
Multiplay Ltd.
Percona LLC
Pivotal Inc.
Stripe Inc.
Zendesk Inc.

@ -0,0 +1,232 @@
## Version 1.6 (2021-04-01)
Changes:
- Migrate the CI service from travis-ci to GitHub Actions (#1176, #1183, #1190)
- `NullTime` is deprecated (#960, #1144)
- Reduce allocations when building SET command (#1111)
- Performance improvement for time formatting (#1118)
- Performance improvement for time parsing (#1098, #1113)
New Features:
- Implement `driver.Validator` interface (#1106, #1174)
- Support returning `uint64` from `Valuer` in `ConvertValue` (#1143)
- Add `json.RawMessage` for converter and prepared statement (#1059)
- Interpolate `json.RawMessage` as `string` (#1058)
- Implements `CheckNamedValue` (#1090)
Bugfixes:
- Stop rounding times (#1121, #1172)
- Put zero filler into the SSL handshake packet (#1066)
- Fix checking cancelled connections back into the connection pool (#1095)
- Fix remove last 0 byte for mysql_old_password when password is empty (#1133)
## Version 1.5 (2020-01-07)
Changes:
- Dropped support Go 1.9 and lower (#823, #829, #886, #1016, #1017)
- Improve buffer handling (#890)
- Document potentially insecure TLS configs (#901)
- Use a double-buffering scheme to prevent data races (#943)
- Pass uint64 values without converting them to string (#838, #955)
- Update collations and make utf8mb4 default (#877, #1054)
- Make NullTime compatible with sql.NullTime in Go 1.13+ (#995)
- Removed CloudSQL support (#993, #1007)
- Add Go Module support (#1003)
New Features:
- Implement support of optional TLS (#900)
- Check connection liveness (#934, #964, #997, #1048, #1051, #1052)
- Implement Connector Interface (#941, #958, #1020, #1035)
Bugfixes:
- Mark connections as bad on error during ping (#875)
- Mark connections as bad on error during dial (#867)
- Fix connection leak caused by rapid context cancellation (#1024)
- Mark connections as bad on error during Conn.Prepare (#1030)
## Version 1.4.1 (2018-11-14)
Bugfixes:
- Fix TIME format for binary columns (#818)
- Fix handling of empty auth plugin names (#835)
- Fix caching_sha2_password with empty password (#826)
- Fix canceled context broke mysqlConn (#862)
- Fix OldAuthSwitchRequest support (#870)
- Fix Auth Response packet for cleartext password (#887)
## Version 1.4 (2018-06-03)
Changes:
- Documentation fixes (#530, #535, #567)
- Refactoring (#575, #579, #580, #581, #603, #615, #704)
- Cache column names (#444)
- Sort the DSN parameters in DSNs generated from a config (#637)
- Allow native password authentication by default (#644)
- Use the default port if it is missing in the DSN (#668)
- Removed the `strict` mode (#676)
- Do not query `max_allowed_packet` by default (#680)
- Dropped support Go 1.6 and lower (#696)
- Updated `ConvertValue()` to match the database/sql/driver implementation (#760)
- Document the usage of `0000-00-00T00:00:00` as the time.Time zero value (#783)
- Improved the compatibility of the authentication system (#807)
New Features:
- Multi-Results support (#537)
- `rejectReadOnly` DSN option (#604)
- `context.Context` support (#608, #612, #627, #761)
- Transaction isolation level support (#619, #744)
- Read-Only transactions support (#618, #634)
- `NewConfig` function which initializes a config with default values (#679)
- Implemented the `ColumnType` interfaces (#667, #724)
- Support for custom string types in `ConvertValue` (#623)
- Implemented `NamedValueChecker`, improving support for uint64 with high bit set (#690, #709, #710)
- `caching_sha2_password` authentication plugin support (#794, #800, #801, #802)
- Implemented `driver.SessionResetter` (#779)
- `sha256_password` authentication plugin support (#808)
Bugfixes:
- Use the DSN hostname as TLS default ServerName if `tls=true` (#564, #718)
- Fixed LOAD LOCAL DATA INFILE for empty files (#590)
- Removed columns definition cache since it sometimes cached invalid data (#592)
- Don't mutate registered TLS configs (#600)
- Make RegisterTLSConfig concurrency-safe (#613)
- Handle missing auth data in the handshake packet correctly (#646)
- Do not retry queries when data was written to avoid data corruption (#302, #736)
- Cache the connection pointer for error handling before invalidating it (#678)
- Fixed imports for appengine/cloudsql (#700)
- Fix sending STMT_LONG_DATA for 0 byte data (#734)
- Set correct capacity for []bytes read from length-encoded strings (#766)
- Make RegisterDial concurrency-safe (#773)
## Version 1.3 (2016-12-01)
Changes:
- Go 1.1 is no longer supported
- Use decimals fields in MySQL to format time types (#249)
- Buffer optimizations (#269)
- TLS ServerName defaults to the host (#283)
- Refactoring (#400, #410, #437)
- Adjusted documentation for second generation CloudSQL (#485)
- Documented DSN system var quoting rules (#502)
- Made statement.Close() calls idempotent to avoid errors in Go 1.6+ (#512)
New Features:
- Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249)
- Support for returning table alias on Columns() (#289, #359, #382)
- Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490)
- Support for uint64 parameters with high bit set (#332, #345)
- Cleartext authentication plugin support (#327)
- Exported ParseDSN function and the Config struct (#403, #419, #429)
- Read / Write timeouts (#401)
- Support for JSON field type (#414)
- Support for multi-statements and multi-results (#411, #431)
- DSN parameter to set the driver-side max_allowed_packet value manually (#489)
- Native password authentication plugin support (#494, #524)
Bugfixes:
- Fixed handling of queries without columns and rows (#255)
- Fixed a panic when SetKeepAlive() failed (#298)
- Handle ERR packets while reading rows (#321)
- Fixed reading NULL length-encoded integers in MySQL 5.6+ (#349)
- Fixed absolute paths support in LOAD LOCAL DATA INFILE (#356)
- Actually zero out bytes in handshake response (#378)
- Fixed race condition in registering LOAD DATA INFILE handler (#383)
- Fixed tests with MySQL 5.7.9+ (#380)
- QueryUnescape TLS config names (#397)
- Fixed "broken pipe" error by writing to closed socket (#390)
- Fixed LOAD LOCAL DATA INFILE buffering (#424)
- Fixed parsing of floats into float64 when placeholders are used (#434)
- Fixed DSN tests with Go 1.7+ (#459)
- Handle ERR packets while waiting for EOF (#473)
- Invalidate connection on error while discarding additional results (#513)
- Allow terminating packets of length 0 (#516)
## Version 1.2 (2014-06-03)
Changes:
- We switched back to a "rolling release". `go get` installs the current master branch again
- Version v1 of the driver will not be maintained anymore. Go 1.0 is no longer supported by this driver
- Exported errors to allow easy checking from application code
- Enabled TCP Keepalives on TCP connections
- Optimized INFILE handling (better buffer size calculation, lazy init, ...)
- The DSN parser also checks for a missing separating slash
- Faster binary date / datetime to string formatting
- Also exported the MySQLWarning type
- mysqlConn.Close returns the first error encountered instead of ignoring all errors
- writePacket() automatically writes the packet size to the header
- readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets
New Features:
- `RegisterDial` allows the usage of a custom dial function to establish the network connection
- Setting the connection collation is possible with the `collation` DSN parameter. This parameter should be preferred over the `charset` parameter
- Logging of critical errors is configurable with `SetLogger`
- Google CloudSQL support
Bugfixes:
- Allow more than 32 parameters in prepared statements
- Various old_password fixes
- Fixed TestConcurrent test to pass Go's race detection
- Fixed appendLengthEncodedInteger for large numbers
- Renamed readLengthEnodedString to readLengthEncodedString and skipLengthEnodedString to skipLengthEncodedString (fixed typo)
## Version 1.1 (2013-11-02)
Changes:
- Go-MySQL-Driver now requires Go 1.1
- Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore
- Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors
- `[]byte(nil)` is now treated as a NULL value. Before, it was treated like an empty string / `[]byte("")`
- DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'.
- Use the IO buffer also for writing. This results in zero allocations (by the driver) for most queries
- Optimized the buffer for reading
- stmt.Query now caches column metadata
- New Logo
- Changed the copyright header to include all contributors
- Improved the LOAD INFILE documentation
- The driver struct is now exported to make the driver directly accessible
- Refactored the driver tests
- Added more benchmarks and moved all to a separate file
- Other small refactoring
New Features:
- Added *old_passwords* support: Required in some cases, but must be enabled by adding `allowOldPasswords=true` to the DSN since it is insecure
- Added a `clientFoundRows` parameter: Return the number of matching rows instead of the number of rows changed on UPDATEs
- Added TLS/SSL support: Use a TLS/SSL encrypted connection to the server. Custom TLS configs can be registered and used
Bugfixes:
- Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification
- Convert to DB timezone when inserting `time.Time`
- Splitted packets (more than 16MB) are now merged correctly
- Fixed false positive `io.EOF` errors when the data was fully read
- Avoid panics on reuse of closed connections
- Fixed empty string producing false nil values
- Fixed sign byte for positive TIME fields
## Version 1.0 (2013-05-14)
Initial Release

@ -0,0 +1,373 @@
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.

@ -0,0 +1,520 @@
# Go-MySQL-Driver
A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package
![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin")
---------------------------------------
* [Features](#features)
* [Requirements](#requirements)
* [Installation](#installation)
* [Usage](#usage)
* [DSN (Data Source Name)](#dsn-data-source-name)
* [Password](#password)
* [Protocol](#protocol)
* [Address](#address)
* [Parameters](#parameters)
* [Examples](#examples)
* [Connection pool and timeouts](#connection-pool-and-timeouts)
* [context.Context Support](#contextcontext-support)
* [ColumnType Support](#columntype-support)
* [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support)
* [time.Time support](#timetime-support)
* [Unicode support](#unicode-support)
* [Testing / Development](#testing--development)
* [License](#license)
---------------------------------------
## Features
* Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance")
* Native Go implementation. No C-bindings, just pure Go
* Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc)
* Automatic handling of broken connections
* Automatic Connection Pooling *(by database/sql package)*
* Supports queries larger than 16MB
* Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support.
* Intelligent `LONG DATA` handling in prepared statements
* Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support
* Optional `time.Time` parsing
* Optional placeholder interpolation
## Requirements
* Go 1.10 or higher. We aim to support the 3 latest versions of Go.
* MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
---------------------------------------
## Installation
Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell:
```bash
$ go get -u github.com/go-sql-driver/mysql
```
Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`.
## Usage
_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then.
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
```go
import (
"database/sql"
"time"
_ "github.com/go-sql-driver/mysql"
)
// ...
db, err := sql.Open("mysql", "user:password@/dbname")
if err != nil {
panic(err)
}
// See "Important settings" section.
db.SetConnMaxLifetime(time.Minute * 3)
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(10)
```
[Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples").
### Important settings
`db.SetConnMaxLifetime()` is required to ensure connections are closed by the driver safely before connection is closed by MySQL server, OS, or other middlewares. Since some middlewares close idle connections by 5 minutes, we recommend timeout shorter than 5 minutes. This setting helps load balancing and changing system variables too.
`db.SetMaxOpenConns()` is highly recommended to limit the number of connection used by the application. There is no recommended limit number because it depends on application and MySQL server.
`db.SetMaxIdleConns()` is recommended to be set same to (or greater than) `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed very frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15.
### DSN (Data Source Name)
The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets):
```
[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]
```
A DSN in its fullest form:
```
username:password@protocol(address)/dbname?param=value
```
Except for the databasename, all values are optional. So the minimal DSN is:
```
/dbname
```
If you do not want to preselect a database, leave `dbname` empty:
```
/
```
This has the same effect as an empty DSN string:
```
```
Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct.
#### Password
Passwords can consist of any character. Escaping is **not** necessary.
#### Protocol
See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available.
In general you should use an Unix domain socket if available and TCP otherwise for best performance.
#### Address
For TCP and UDP networks, addresses have the form `host[:port]`.
If `port` is omitted, the default port will be used.
If `host` is a literal IPv6 address, it must be enclosed in square brackets.
The functions [net.JoinHostPort](https://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](https://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form.
For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`.
#### Parameters
*Parameters are case-sensitive!*
Notice that any of `true`, `TRUE`, `True` or `1` is accepted to stand for a true boolean value. Not surprisingly, false can be specified as any of: `false`, `FALSE`, `False` or `0`.
##### `allowAllFiles`
```
Type: bool
Valid Values: true, false
Default: false
```
`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files.
[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
##### `allowCleartextPasswords`
```
Type: bool
Valid Values: true, false
Default: false
```
`allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network.
##### `allowNativePasswords`
```
Type: bool
Valid Values: true, false
Default: true
```
`allowNativePasswords=false` disallows the usage of MySQL native password method.
##### `allowOldPasswords`
```
Type: bool
Valid Values: true, false
Default: false
```
`allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords).
##### `charset`
```
Type: string
Valid Values: <name>
Default: none
```
Sets the charset used for client-server interaction (`"SET NAMES <value>"`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
Usage of the `charset` parameter is discouraged because it issues additional queries to the server.
Unless you need the fallback behavior, please use `collation` instead.
##### `checkConnLiveness`
```
Type: bool
Valid Values: true, false
Default: true
```
On supported platforms connections retrieved from the connection pool are checked for liveness before using them. If the check fails, the respective connection is marked as bad and the query retried with another connection.
`checkConnLiveness=false` disables this liveness check of connections.
##### `collation`
```
Type: string
Valid Values: <name>
Default: utf8mb4_general_ci
```
Sets the collation used for client-server interaction on connection. In contrast to `charset`, `collation` does not issue additional queries. If the specified collation is unavailable on the target server, the connection will fail.
A list of valid charsets for a server is retrievable with `SHOW COLLATION`.
The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL.
Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)).
##### `clientFoundRows`
```
Type: bool
Valid Values: true, false
Default: false
```
`clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
##### `columnsWithAlias`
```
Type: bool
Valid Values: true, false
Default: false
```
When `columnsWithAlias` is true, calls to `sql.Rows.Columns()` will return the table alias and the column name separated by a dot. For example:
```
SELECT u.id FROM users as u
```
will return `u.id` instead of just `id` if `columnsWithAlias=true`.
##### `interpolateParams`
```
Type: bool
Valid Values: true, false
Default: false
```
If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`.
*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!*
##### `loc`
```
Type: string
Valid Values: <escaped name>
Default: UTC
```
Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details.
Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html). For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter.
Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
##### `maxAllowedPacket`
```
Type: decimal number
Default: 4194304
```
Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*.
##### `multiStatements`
```
Type: bool
Valid Values: true, false
Default: false
```
Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded.
When `multiStatements` is used, `?` parameters must only be used in the first statement.
##### `parseTime`
```
Type: bool
Valid Values: true, false
Default: false
```
`parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
The date or datetime like `0000-00-00 00:00:00` is converted into zero value of `time.Time`.
##### `readTimeout`
```
Type: duration
Default: 0
```
I/O read timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
##### `rejectReadOnly`
```
Type: bool
Valid Values: true, false
Default: false
```
`rejectReadOnly=true` causes the driver to reject read-only connections. This
is for a possible race condition during an automatic failover, where the mysql
client gets connected to a read-only replica after the failover.
Note that this should be a fairly rare case, as an automatic failover normally
happens when the primary is down, and the race condition shouldn't happen
unless it comes back up online as soon as the failover is kicked off. On the
other hand, when this happens, a MySQL application can get stuck on a
read-only connection until restarted. It is however fairly easy to reproduce,
for example, using a manual failover on AWS Aurora's MySQL-compatible cluster.
If you are not relying on read-only transactions to reject writes that aren't
supposed to happen, setting this on some MySQL providers (such as AWS Aurora)
is safer for failovers.
Note that ERROR 1290 can be returned for a `read-only` server and this option will
cause a retry for that error. However the same error number is used for some
other cases. You should ensure your application will never cause an ERROR 1290
except for `read-only` mode when enabling this option.
##### `serverPubKey`
```
Type: string
Valid Values: <name>
Default: none
```
Server public keys can be registered with [`mysql.RegisterServerPubKey`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterServerPubKey), which can then be used by the assigned name in the DSN.
Public keys are used to transmit encrypted data, e.g. for authentication.
If the server's public key is known, it should be set manually to avoid expensive and potentially insecure transmissions of the public key from the server to the client each time it is required.
##### `timeout`
```
Type: duration
Default: OS default
```
Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
##### `tls`
```
Type: bool / string
Valid Values: true, false, skip-verify, preferred, <name>
Default: false
```
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
##### `writeTimeout`
```
Type: duration
Default: 0
```
I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
##### System Variables
Any other parameters are interpreted as system variables:
* `<boolean_var>=<value>`: `SET <boolean_var>=<value>`
* `<enum_var>=<value>`: `SET <enum_var>=<value>`
* `<string_var>=%27<value>%27`: `SET <string_var>='<value>'`
Rules:
* The values for string variables must be quoted with `'`.
* The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!
(which implies values of string variables must be wrapped with `%27`).
Examples:
* `autocommit=1`: `SET autocommit=1`
* [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'`
* [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'`
#### Examples
```
user@unix(/path/to/socket)/dbname
```
```
root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local
```
```
user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true
```
Treat warnings as errors by setting the system variable [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html):
```
user:password@/dbname?sql_mode=TRADITIONAL
```
TCP via IPv6:
```
user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci
```
TCP on a remote host, e.g. Amazon RDS:
```
id:password@tcp(your-amazonaws-uri.com:3306)/dbname
```
Google Cloud SQL on App Engine:
```
user:password@unix(/cloudsql/project-id:region-name:instance-name)/dbname
```
TCP using default port (3306) on localhost:
```
user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped
```
Use the default protocol (tcp) and host (localhost:3306):
```
user:password@/dbname
```
No Database preselected:
```
user:password@/
```
### Connection pool and timeouts
The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively.
## `ColumnType` Support
This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported.
## `context.Context` Support
Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts.
See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details.
### `LOAD DATA LOCAL INFILE` support
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
```go
import "github.com/go-sql-driver/mysql"
```
Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore.
See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details.
### `time.Time` support
The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program.
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical equivalent in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
### Unicode support
Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default.
Other collations / charsets can be set using the [`collation`](#collation) DSN parameter.
Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default.
See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support.
## Testing / Development
To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details.
Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated.
If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls).
See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/.github/CONTRIBUTING.md) for details.
---------------------------------------
## License
Go-MySQL-Driver is licensed under the [Mozilla Public License Version 2.0](https://raw.github.com/go-sql-driver/mysql/master/LICENSE)
Mozilla summarizes the license scope as follows:
> MPL: The copyleft applies to any files containing MPLed code.
That means:
* You can **use** the **unchanged** source code both in private and commercially.
* When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0).
* You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**.
Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you have further questions regarding the license.
You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE).
![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow")

@ -0,0 +1,425 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/x509"
"encoding/pem"
"fmt"
"sync"
)
// server pub keys registry
var (
serverPubKeyLock sync.RWMutex
serverPubKeyRegistry map[string]*rsa.PublicKey
)
// RegisterServerPubKey registers a server RSA public key which can be used to
// send data in a secure manner to the server without receiving the public key
// in a potentially insecure way from the server first.
// Registered keys can afterwards be used adding serverPubKey=<name> to the DSN.
//
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver
// after registering it and may not be modified.
//
// data, err := ioutil.ReadFile("mykey.pem")
// if err != nil {
// log.Fatal(err)
// }
//
// block, _ := pem.Decode(data)
// if block == nil || block.Type != "PUBLIC KEY" {
// log.Fatal("failed to decode PEM block containing public key")
// }
//
// pub, err := x509.ParsePKIXPublicKey(block.Bytes)
// if err != nil {
// log.Fatal(err)
// }
//
// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
// mysql.RegisterServerPubKey("mykey", rsaPubKey)
// } else {
// log.Fatal("not a RSA public key")
// }
//
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
serverPubKeyLock.Lock()
if serverPubKeyRegistry == nil {
serverPubKeyRegistry = make(map[string]*rsa.PublicKey)
}
serverPubKeyRegistry[name] = pubKey
serverPubKeyLock.Unlock()
}
// DeregisterServerPubKey removes the public key registered with the given name.
func DeregisterServerPubKey(name string) {
serverPubKeyLock.Lock()
if serverPubKeyRegistry != nil {
delete(serverPubKeyRegistry, name)
}
serverPubKeyLock.Unlock()
}
func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
serverPubKeyLock.RLock()
if v, ok := serverPubKeyRegistry[name]; ok {
pubKey = v
}
serverPubKeyLock.RUnlock()
return
}
// Hash password using pre 4.1 (old password) method
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
seed1, seed2 uint32
}
const myRndMaxVal = 0x3FFFFFFF
// Pseudo random number generator
func newMyRnd(seed1, seed2 uint32) *myRnd {
return &myRnd{
seed1: seed1 % myRndMaxVal,
seed2: seed2 % myRndMaxVal,
}
}
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
}
// Generate binary hash from byte string using insecure pre 4.1 method
func pwHash(password []byte) (result [2]uint32) {
var add uint32 = 7
var tmp uint32
result[0] = 1345345333
result[1] = 0x12345671
for _, c := range password {
// skip spaces and tabs in password
if c == ' ' || c == '\t' {
continue
}
tmp = uint32(c)
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
result[1] += (result[1] << 8) ^ result[0]
add += tmp
}
// Remove sign bit (1<<31)-1)
result[0] &= 0x7FFFFFFF
result[1] &= 0x7FFFFFFF
return
}
// Hash password using insecure pre 4.1 method
func scrambleOldPassword(scramble []byte, password string) []byte {
scramble = scramble[:8]
hashPw := pwHash([]byte(password))
hashSc := pwHash(scramble)
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
var out [8]byte
for i := range out {
out[i] = r.NextByte() + 64
}
mask := r.NextByte()
for i := range out {
out[i] ^= mask
}
return out[:]
}
// Hash password using 4.1+ method (SHA1)
func scramblePassword(scramble []byte, password string) []byte {
if len(password) == 0 {
return nil
}
// stage1Hash = SHA1(password)
crypt := sha1.New()
crypt.Write([]byte(password))
stage1 := crypt.Sum(nil)
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
// inner Hash
crypt.Reset()
crypt.Write(stage1)
hash := crypt.Sum(nil)
// outer Hash
crypt.Reset()
crypt.Write(scramble)
crypt.Write(hash)
scramble = crypt.Sum(nil)
// token = scrambleHash XOR stage1Hash
for i := range scramble {
scramble[i] ^= stage1[i]
}
return scramble
}
// Hash password using MySQL 8+ method (SHA256)
func scrambleSHA256Password(scramble []byte, password string) []byte {
if len(password) == 0 {
return nil
}
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
crypt := sha256.New()
crypt.Write([]byte(password))
message1 := crypt.Sum(nil)
crypt.Reset()
crypt.Write(message1)
message1Hash := crypt.Sum(nil)
crypt.Reset()
crypt.Write(message1Hash)
crypt.Write(scramble)
message2 := crypt.Sum(nil)
for i := range message1 {
message1[i] ^= message2[i]
}
return message1
}
func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
plain := make([]byte, len(password)+1)
copy(plain, password)
for i := range plain {
j := i % len(seed)
plain[i] ^= seed[j]
}
sha1 := sha1.New()
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
}
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
if err != nil {
return err
}
return mc.writeAuthSwitchPacket(enc)
}
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
switch plugin {
case "caching_sha2_password":
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
return authResp, nil
case "mysql_old_password":
if !mc.cfg.AllowOldPasswords {
return nil, ErrOldPassword
}
if len(mc.cfg.Passwd) == 0 {
return nil, nil
}
// Note: there are edge cases where this should work but doesn't;
// this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184
authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
return authResp, nil
case "mysql_clear_password":
if !mc.cfg.AllowCleartextPasswords {
return nil, ErrCleartextPassword
}
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
return append([]byte(mc.cfg.Passwd), 0), nil
case "mysql_native_password":
if !mc.cfg.AllowNativePasswords {
return nil, ErrNativePassword
}
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
// Native password authentication only need and will need 20-byte challenge.
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
return authResp, nil
case "sha256_password":
if len(mc.cfg.Passwd) == 0 {
return []byte{0}, nil
}
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
return append([]byte(mc.cfg.Passwd), 0), nil
}
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
return []byte{1}, nil
}
// encrypted password
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
return enc, err
default:
errLog.Print("unknown auth plugin:", plugin)
return nil, ErrUnknownPlugin
}
}
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
// Read Result Packet
authData, newPlugin, err := mc.readAuthResult()
if err != nil {
return err
}
// handle auth plugin switch, if requested
if newPlugin != "" {
// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
// sent and we have to keep using the cipher sent in the init packet.
if authData == nil {
authData = oldAuthData
} else {
// copy data from read buffer to owned slice
copy(oldAuthData, authData)
}
plugin = newPlugin
authResp, err := mc.auth(authData, plugin)
if err != nil {
return err
}
if err = mc.writeAuthSwitchPacket(authResp); err != nil {
return err
}
// Read Result Packet
authData, newPlugin, err = mc.readAuthResult()
if err != nil {
return err
}
// Do not allow to change the auth plugin more than once
if newPlugin != "" {
return ErrMalformPkt
}
}
switch plugin {
// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
case "caching_sha2_password":
switch len(authData) {
case 0:
return nil // auth successful
case 1:
switch authData[0] {
case cachingSha2PasswordFastAuthSuccess:
if err = mc.readResultOK(); err == nil {
return nil // auth successful
}
case cachingSha2PasswordPerformFullAuthentication:
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
if err != nil {
return err
}
} else {
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
return err
}
data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data)
// parse public key
if data, err = mc.readPacket(); err != nil {
return err
}
block, rest := pem.Decode(data[1:])
if block == nil {
return fmt.Errorf("No Pem data found, data: %s", rest)
}
pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
}
pubKey = pkix.(*rsa.PublicKey)
}
// send encrypted password
err = mc.sendEncryptedPassword(oldAuthData, pubKey)
if err != nil {
return err
}
}
return mc.readResultOK()
default:
return ErrMalformPkt
}
default:
return ErrMalformPkt
}
case "sha256_password":
switch len(authData) {
case 0:
return nil // auth successful
default:
block, _ := pem.Decode(authData)
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
}
// send encrypted password
err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey))
if err != nil {
return err
}
return mc.readResultOK()
}
default:
return nil // auth successful
}
return err
}

@ -0,0 +1,182 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"io"
"net"
"time"
)
const defaultBufSize = 4096
const maxCachedBufSize = 256 * 1024
// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
// This buffer is backed by two byte slices in a double-buffering scheme
type buffer struct {
buf []byte // buf is a byte buffer who's length and capacity are equal.
nc net.Conn
idx int
length int
timeout time.Duration
dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer
flipcnt uint // flipccnt is the current buffer counter for double-buffering
}
// newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer {
fg := make([]byte, defaultBufSize)
return buffer{
buf: fg,
nc: nc,
dbuf: [2][]byte{fg, nil},
}
}
// flip replaces the active buffer with the background buffer
// this is a delayed flip that simply increases the buffer counter;
// the actual flip will be performed the next time we call `buffer.fill`
func (b *buffer) flip() {
b.flipcnt += 1
}
// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
n := b.length
// fill data into its double-buffering target: if we've called
// flip on this buffer, we'll be copying to the background buffer,
// and then filling it with network data; otherwise we'll just move
// the contents of the current buffer to the front before filling it
dest := b.dbuf[b.flipcnt&1]
// grow buffer if necessary to fit the whole packet.
if need > len(dest) {
// Round up to the next multiple of the default size
dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
// if the allocated buffer is not too large, move it to backing storage
// to prevent extra allocations on applications that perform large reads
if len(dest) <= maxCachedBufSize {
b.dbuf[b.flipcnt&1] = dest
}
}
// if we're filling the fg buffer, move the existing data to the start of it.
// if we're filling the bg buffer, copy over the data
if n > 0 {
copy(dest[:n], b.buf[b.idx:])
}
b.buf = dest
b.idx = 0
for {
if b.timeout > 0 {
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
return err
}
}
nn, err := b.nc.Read(b.buf[n:])
n += nn
switch err {
case nil:
if n < need {
continue
}
b.length = n
return nil
case io.EOF:
if n >= need {
b.length = n
return nil
}
return io.ErrUnexpectedEOF
default:
return err
}
}
}
// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) ([]byte, error) {
if b.length < need {
// refill
if err := b.fill(need); err != nil {
return nil, err
}
}
offset := b.idx
b.idx += need
b.length -= need
return b.buf[offset:b.idx], nil
}
// takeBuffer returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil, ErrBusyBuffer
}
// test (cheap) general case first
if length <= cap(b.buf) {
return b.buf[:length], nil
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf, nil
}
// buffer is larger than we want to store.
return make([]byte, length), nil
}
// takeSmallBuffer is shortcut which can be used if length is
// known to be smaller than defaultBufSize.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil, ErrBusyBuffer
}
return b.buf[:length], nil
}
// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// cap and len of the returned buffer will be equal.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
if b.length > 0 {
return nil, ErrBusyBuffer
}
return b.buf, nil
}
// store stores buf, an updated buffer, if its suitable to do so.
func (b *buffer) store(buf []byte) error {
if b.length > 0 {
return ErrBusyBuffer
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
b.buf = buf[:cap(buf)]
}
return nil
}

@ -0,0 +1,265 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
const defaultCollation = "utf8mb4_general_ci"
const binaryCollation = "binary"
// A list of available collations mapped to the internal ID.
// To update this map use the following MySQL query:
// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID
//
// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255.
//
// ucs2, utf16, and utf32 can't be used for connection charset.
// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset
// They are commented out to reduce this map.
var collations = map[string]byte{
"big5_chinese_ci": 1,
"latin2_czech_cs": 2,
"dec8_swedish_ci": 3,
"cp850_general_ci": 4,
"latin1_german1_ci": 5,
"hp8_english_ci": 6,
"koi8r_general_ci": 7,
"latin1_swedish_ci": 8,
"latin2_general_ci": 9,
"swe7_swedish_ci": 10,
"ascii_general_ci": 11,
"ujis_japanese_ci": 12,
"sjis_japanese_ci": 13,
"cp1251_bulgarian_ci": 14,
"latin1_danish_ci": 15,
"hebrew_general_ci": 16,
"tis620_thai_ci": 18,
"euckr_korean_ci": 19,
"latin7_estonian_cs": 20,
"latin2_hungarian_ci": 21,
"koi8u_general_ci": 22,
"cp1251_ukrainian_ci": 23,
"gb2312_chinese_ci": 24,
"greek_general_ci": 25,
"cp1250_general_ci": 26,
"latin2_croatian_ci": 27,
"gbk_chinese_ci": 28,
"cp1257_lithuanian_ci": 29,
"latin5_turkish_ci": 30,
"latin1_german2_ci": 31,
"armscii8_general_ci": 32,
"utf8_general_ci": 33,
"cp1250_czech_cs": 34,
//"ucs2_general_ci": 35,
"cp866_general_ci": 36,
"keybcs2_general_ci": 37,
"macce_general_ci": 38,
"macroman_general_ci": 39,
"cp852_general_ci": 40,
"latin7_general_ci": 41,
"latin7_general_cs": 42,
"macce_bin": 43,
"cp1250_croatian_ci": 44,
"utf8mb4_general_ci": 45,
"utf8mb4_bin": 46,
"latin1_bin": 47,
"latin1_general_ci": 48,
"latin1_general_cs": 49,
"cp1251_bin": 50,
"cp1251_general_ci": 51,
"cp1251_general_cs": 52,
"macroman_bin": 53,
//"utf16_general_ci": 54,
//"utf16_bin": 55,
//"utf16le_general_ci": 56,
"cp1256_general_ci": 57,
"cp1257_bin": 58,
"cp1257_general_ci": 59,
//"utf32_general_ci": 60,
//"utf32_bin": 61,
//"utf16le_bin": 62,
"binary": 63,
"armscii8_bin": 64,
"ascii_bin": 65,
"cp1250_bin": 66,
"cp1256_bin": 67,
"cp866_bin": 68,
"dec8_bin": 69,
"greek_bin": 70,
"hebrew_bin": 71,
"hp8_bin": 72,
"keybcs2_bin": 73,
"koi8r_bin": 74,
"koi8u_bin": 75,
"utf8_tolower_ci": 76,
"latin2_bin": 77,
"latin5_bin": 78,
"latin7_bin": 79,
"cp850_bin": 80,
"cp852_bin": 81,
"swe7_bin": 82,
"utf8_bin": 83,
"big5_bin": 84,
"euckr_bin": 85,
"gb2312_bin": 86,
"gbk_bin": 87,
"sjis_bin": 88,
"tis620_bin": 89,
//"ucs2_bin": 90,
"ujis_bin": 91,
"geostd8_general_ci": 92,
"geostd8_bin": 93,
"latin1_spanish_ci": 94,
"cp932_japanese_ci": 95,
"cp932_bin": 96,
"eucjpms_japanese_ci": 97,
"eucjpms_bin": 98,
"cp1250_polish_ci": 99,
//"utf16_unicode_ci": 101,
//"utf16_icelandic_ci": 102,
//"utf16_latvian_ci": 103,
//"utf16_romanian_ci": 104,
//"utf16_slovenian_ci": 105,
//"utf16_polish_ci": 106,
//"utf16_estonian_ci": 107,
//"utf16_spanish_ci": 108,
//"utf16_swedish_ci": 109,
//"utf16_turkish_ci": 110,
//"utf16_czech_ci": 111,
//"utf16_danish_ci": 112,
//"utf16_lithuanian_ci": 113,
//"utf16_slovak_ci": 114,
//"utf16_spanish2_ci": 115,
//"utf16_roman_ci": 116,
//"utf16_persian_ci": 117,
//"utf16_esperanto_ci": 118,
//"utf16_hungarian_ci": 119,
//"utf16_sinhala_ci": 120,
//"utf16_german2_ci": 121,
//"utf16_croatian_ci": 122,
//"utf16_unicode_520_ci": 123,
//"utf16_vietnamese_ci": 124,
//"ucs2_unicode_ci": 128,
//"ucs2_icelandic_ci": 129,
//"ucs2_latvian_ci": 130,
//"ucs2_romanian_ci": 131,
//"ucs2_slovenian_ci": 132,
//"ucs2_polish_ci": 133,
//"ucs2_estonian_ci": 134,
//"ucs2_spanish_ci": 135,
//"ucs2_swedish_ci": 136,
//"ucs2_turkish_ci": 137,
//"ucs2_czech_ci": 138,
//"ucs2_danish_ci": 139,
//"ucs2_lithuanian_ci": 140,
//"ucs2_slovak_ci": 141,
//"ucs2_spanish2_ci": 142,
//"ucs2_roman_ci": 143,
//"ucs2_persian_ci": 144,
//"ucs2_esperanto_ci": 145,
//"ucs2_hungarian_ci": 146,
//"ucs2_sinhala_ci": 147,
//"ucs2_german2_ci": 148,
//"ucs2_croatian_ci": 149,
//"ucs2_unicode_520_ci": 150,
//"ucs2_vietnamese_ci": 151,
//"ucs2_general_mysql500_ci": 159,
//"utf32_unicode_ci": 160,
//"utf32_icelandic_ci": 161,
//"utf32_latvian_ci": 162,
//"utf32_romanian_ci": 163,
//"utf32_slovenian_ci": 164,
//"utf32_polish_ci": 165,
//"utf32_estonian_ci": 166,
//"utf32_spanish_ci": 167,
//"utf32_swedish_ci": 168,
//"utf32_turkish_ci": 169,
//"utf32_czech_ci": 170,
//"utf32_danish_ci": 171,
//"utf32_lithuanian_ci": 172,
//"utf32_slovak_ci": 173,
//"utf32_spanish2_ci": 174,
//"utf32_roman_ci": 175,
//"utf32_persian_ci": 176,
//"utf32_esperanto_ci": 177,
//"utf32_hungarian_ci": 178,
//"utf32_sinhala_ci": 179,
//"utf32_german2_ci": 180,
//"utf32_croatian_ci": 181,
//"utf32_unicode_520_ci": 182,
//"utf32_vietnamese_ci": 183,
"utf8_unicode_ci": 192,
"utf8_icelandic_ci": 193,
"utf8_latvian_ci": 194,
"utf8_romanian_ci": 195,
"utf8_slovenian_ci": 196,
"utf8_polish_ci": 197,
"utf8_estonian_ci": 198,
"utf8_spanish_ci": 199,
"utf8_swedish_ci": 200,
"utf8_turkish_ci": 201,
"utf8_czech_ci": 202,
"utf8_danish_ci": 203,
"utf8_lithuanian_ci": 204,
"utf8_slovak_ci": 205,
"utf8_spanish2_ci": 206,
"utf8_roman_ci": 207,
"utf8_persian_ci": 208,
"utf8_esperanto_ci": 209,
"utf8_hungarian_ci": 210,
"utf8_sinhala_ci": 211,
"utf8_german2_ci": 212,
"utf8_croatian_ci": 213,
"utf8_unicode_520_ci": 214,
"utf8_vietnamese_ci": 215,
"utf8_general_mysql500_ci": 223,
"utf8mb4_unicode_ci": 224,
"utf8mb4_icelandic_ci": 225,
"utf8mb4_latvian_ci": 226,
"utf8mb4_romanian_ci": 227,
"utf8mb4_slovenian_ci": 228,
"utf8mb4_polish_ci": 229,
"utf8mb4_estonian_ci": 230,
"utf8mb4_spanish_ci": 231,
"utf8mb4_swedish_ci": 232,
"utf8mb4_turkish_ci": 233,
"utf8mb4_czech_ci": 234,
"utf8mb4_danish_ci": 235,
"utf8mb4_lithuanian_ci": 236,
"utf8mb4_slovak_ci": 237,
"utf8mb4_spanish2_ci": 238,
"utf8mb4_roman_ci": 239,
"utf8mb4_persian_ci": 240,
"utf8mb4_esperanto_ci": 241,
"utf8mb4_hungarian_ci": 242,
"utf8mb4_sinhala_ci": 243,
"utf8mb4_german2_ci": 244,
"utf8mb4_croatian_ci": 245,
"utf8mb4_unicode_520_ci": 246,
"utf8mb4_vietnamese_ci": 247,
"gb18030_chinese_ci": 248,
"gb18030_bin": 249,
"gb18030_unicode_520_ci": 250,
"utf8mb4_0900_ai_ci": 255,
}
// A denylist of collations which is unsafe to interpolate parameters.
// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
var unsafeCollations = map[string]bool{
"big5_chinese_ci": true,
"sjis_japanese_ci": true,
"gbk_chinese_ci": true,
"big5_bin": true,
"gb2312_bin": true,
"gbk_bin": true,
"sjis_bin": true,
"cp932_japanese_ci": true,
"cp932_bin": true,
"gb18030_chinese_ci": true,
"gb18030_bin": true,
"gb18030_unicode_520_ci": true,
}

@ -0,0 +1,54 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos
package mysql
import (
"errors"
"io"
"net"
"syscall"
)
var errUnexpectedRead = errors.New("unexpected read from socket")
func connCheck(conn net.Conn) error {
var sysErr error
sysConn, ok := conn.(syscall.Conn)
if !ok {
return nil
}
rawConn, err := sysConn.SyscallConn()
if err != nil {
return err
}
err = rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
switch {
case n == 0 && err == nil:
sysErr = io.EOF
case n > 0:
sysErr = errUnexpectedRead
case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
sysErr = nil
default:
sysErr = err
}
return true
})
if err != nil {
return err
}
return sysErr
}

@ -0,0 +1,17 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos
package mysql
import "net"
func connCheck(conn net.Conn) error {
return nil
}

@ -0,0 +1,650 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"io"
"net"
"strconv"
"strings"
"time"
)
type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
cfg *Config
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
reset bool // set when the Go SQL package calls ResetSession
// for context support (Go 1.8+)
watching bool
watcher chan<- context.Context
closech chan struct{}
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}
// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
var cmdSet strings.Builder
for param, val := range mc.cfg.Params {
switch param {
// Charset: character_set_connection, character_set_client, character_set_results
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// Other system vars accumulated in a single SET command
default:
if cmdSet.Len() == 0 {
// Heuristic: 29 chars for each other key=value to reduce reallocations
cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1))
cmdSet.WriteString("SET ")
} else {
cmdSet.WriteByte(',')
}
cmdSet.WriteString(param)
cmdSet.WriteByte('=')
cmdSet.WriteString(val)
}
}
if cmdSet.Len() > 0 {
err = mc.exec(cmdSet.String())
if err != nil {
return
}
}
return
}
func (mc *mysqlConn) markBadConn(err error) error {
if mc == nil {
return err
}
if err != errBadConnNoWrite {
return err
}
return driver.ErrBadConn
}
func (mc *mysqlConn) Begin() (driver.Tx, error) {
return mc.begin(false)
}
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
var q string
if readOnly {
q = "START TRANSACTION READ ONLY"
} else {
q = "START TRANSACTION"
}
err := mc.exec(q)
if err == nil {
return &mysqlTx{mc}, err
}
return nil, mc.markBadConn(err)
}
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if !mc.closed.IsSet() {
err = mc.writeCommandPacket(comQuit)
}
mc.cleanup()
return
}
// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
if !mc.closed.TrySet(true) {
return
}
// Makes cleanup idempotent
close(mc.closech)
if mc.netConn == nil {
return
}
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
}
}
func (mc *mysqlConn) error() error {
if mc.closed.IsSet() {
if err := mc.canceled.Value(); err != nil {
return err
}
return ErrInvalidConn
}
return nil
}
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
errLog.Print(err)
return nil, driver.ErrBadConn
}
stmt := &mysqlStmt{
mc: mc,
}
// Read Result
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if columnCount > 0 {
err = mc.readUntilEOF()
}
}
return stmt, err
}
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
// Number of ? should be same to len(args)
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
}
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(err)
return "", ErrInvalidConn
}
buf = buf[:0]
argPos := 0
for i := 0; i < len(query); i++ {
q := strings.IndexByte(query[i:], '?')
if q == -1 {
buf = append(buf, query[i:]...)
break
}
buf = append(buf, query[i:i+q]...)
i += q
arg := args[argPos]
argPos++
if arg == nil {
buf = append(buf, "NULL"...)
continue
}
switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint64:
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
buf = strconv.AppendUint(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
if v {
buf = append(buf, '1')
} else {
buf = append(buf, '0')
}
case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc))
if err != nil {
return "", err
}
buf = append(buf, '\'')
}
case json.RawMessage:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
} else {
buf = append(buf, "_binary'"...)
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
}
case string:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeStringBackslash(buf, v)
} else {
buf = escapeStringQuotes(buf, v)
}
buf = append(buf, '\'')
default:
return "", driver.ErrSkip
}
if len(buf)+4 > mc.maxAllowedPacket {
return "", driver.ErrSkip
}
}
if argPos != len(args) {
return "", driver.ErrSkip
}
return string(buf), nil
}
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
}
mc.affectedRows = 0
mc.insertId = 0
err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, mc.markBadConn(err)
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return err
}
if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
return err
}
// rows
if err := mc.readUntilEOF(); err != nil {
return err
}
}
return mc.discardResults()
}
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return mc.query(query, args)
}
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce roundtrip
prepared, err := mc.interpolateParams(query, args)
if err != nil {
return nil, err
}
query = prepared
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen == 0 {
rows.rs.done = true
switch err := rows.NextResultSet(); err {
case nil, io.EOF:
return rows, nil
default:
return nil, err
}
}
// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
}
}
return nil, mc.markBadConn(err)
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
}
}
return nil, err
}
// finish is called when the query has canceled.
func (mc *mysqlConn) cancel(err error) {
mc.canceled.Set(err)
mc.cleanup()
}
// finish is called when the query has succeeded.
func (mc *mysqlConn) finish() {
if !mc.watching || mc.finished == nil {
return
}
select {
case mc.finished <- struct{}{}:
mc.watching = false
case <-mc.closech:
}
}
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
if err = mc.watchCancel(ctx); err != nil {
return
}
defer mc.finish()
if err = mc.writeCommandPacket(comPing); err != nil {
return mc.markBadConn(err)
}
return mc.readResultOK()
}
// BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if mc.closed.IsSet() {
return nil, driver.ErrBadConn
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
level, err := mapIsolationLevel(opts.Isolation)
if err != nil {
return nil, err
}
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
if err != nil {
return nil, err
}
}
return mc.begin(opts.ReadOnly)
}
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := mc.query(query, dargs)
if err != nil {
mc.finish()
return nil, err
}
rows.finish = mc.finish
return rows, err
}
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
return mc.Exec(query, dargs)
}
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
stmt, err := mc.Prepare(query)
mc.finish()
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
stmt.Close()
return nil, ctx.Err()
}
return stmt, nil
}
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := stmt.query(dargs)
if err != nil {
stmt.mc.finish()
return nil, err
}
rows.finish = stmt.mc.finish
return rows, err
}
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
defer stmt.mc.finish()
return stmt.Exec(dargs)
}
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
mc.cleanup()
return nil
}
// When ctx is already cancelled, don't watch it.
if err := ctx.Err(); err != nil {
return err
}
// When ctx is not cancellable, don't watch it.
if ctx.Done() == nil {
return nil
}
// When watcher is not alive, can't watch it.
if mc.watcher == nil {
return nil
}
mc.watching = true
mc.watcher <- ctx
return nil
}
func (mc *mysqlConn) startWatcher() {
watcher := make(chan context.Context, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx context.Context
select {
case ctx = <-watcher:
case <-mc.closech:
return
}
select {
case <-ctx.Done():
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
}
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.IsSet() {
return driver.ErrBadConn
}
mc.reset = true
return nil
}
// IsValid implements driver.Validator interface
// (From Go 1.15)
func (mc *mysqlConn) IsValid() bool {
return !mc.closed.IsSet()
}

@ -0,0 +1,146 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"context"
"database/sql/driver"
"net"
)
type connector struct {
cfg *Config // immutable private copy.
}
// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
}
mc.parseTime = mc.cfg.ParseTime
// Connect to Server
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
dctx := ctx
if mc.cfg.Timeout > 0 {
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
defer cancel()
}
mc.netConn, err = dial(dctx, mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
return nil, err
}
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close()
mc.netConn = nil
return nil, err
}
}
// Call startWatcher for context support (From Go 1.8)
mc.startWatcher()
if err := mc.watchCancel(ctx); err != nil {
mc.cleanup()
return nil, err
}
defer mc.finish()
mc.buf = newBuffer(mc.netConn)
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
}
if plugin == "" {
plugin = defaultAuthPlugin
}
// Send Client Authentication Packet
authResp, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, err = mc.auth(authData, plugin)
if err != nil {
mc.cleanup()
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
mc.cleanup()
return nil, err
}
// Handle response to auth packet, switch methods if possible
if err = mc.handleAuthResult(authData, plugin); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
}
if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else {
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxAllowedPacket = stringToInt(maxap) - 1
}
if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
// Driver implements driver.Connector interface.
// Driver returns &MySQLDriver{}.
func (c *connector) Driver() driver.Driver {
return &MySQLDriver{}
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save