master
李光春 2 years ago
parent d6aa01eb12
commit cd0cd1c79d

@ -34,7 +34,7 @@ require (
github.com/shopspring/decimal v1.3.1 github.com/shopspring/decimal v1.3.1
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0
github.com/tencentyun/cos-go-sdk-v5 v0.7.35 github.com/tencentyun/cos-go-sdk-v5 v0.7.35
github.com/upper/db/v4 v4.5.4 github.com/upper/db/v4 v4.6.0
github.com/uptrace/bun v1.1.6 github.com/uptrace/bun v1.1.6
github.com/uptrace/bun/dialect/mysqldialect v1.1.6 github.com/uptrace/bun/dialect/mysqldialect v1.1.6
github.com/uptrace/bun/dialect/pgdialect v1.1.6 github.com/uptrace/bun/dialect/pgdialect v1.1.6
@ -42,9 +42,9 @@ require (
github.com/upyun/go-sdk/v3 v3.0.2 github.com/upyun/go-sdk/v3 v3.0.2
go.mongodb.org/mongo-driver v1.10.0 go.mongodb.org/mongo-driver v1.10.0
go.uber.org/zap v1.21.0 go.uber.org/zap v1.21.0
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/crypto v0.3.0
golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c golang.org/x/oauth2 v0.0.0-20220722155238-128564f6959c
golang.org/x/text v0.3.7 golang.org/x/text v0.5.0
google.golang.org/grpc v1.48.0 google.golang.org/grpc v1.48.0
google.golang.org/protobuf v1.28.0 google.golang.org/protobuf v1.28.0
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
@ -112,13 +112,14 @@ require (
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.0.2 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/saracen/go7z-fixtures v0.0.0-20190623165746-aa6b8fba1d2f // indirect github.com/saracen/go7z-fixtures v0.0.0-20190623165746-aa6b8fba1d2f // indirect
github.com/saracen/solidblock v0.0.0-20190426153529-45df20abab6f // indirect github.com/saracen/solidblock v0.0.0-20190426153529-45df20abab6f // indirect
github.com/segmentio/fasthash v1.0.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/testify v1.8.0 // indirect github.com/stretchr/testify v1.8.1 // indirect
github.com/syndtr/goleveldb v1.0.0 // indirect github.com/syndtr/goleveldb v1.0.0 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
@ -132,13 +133,13 @@ require (
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a // indirect github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a // indirect
go.uber.org/atomic v1.9.0 // indirect go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.8.0 // indirect go.uber.org/multierr v1.8.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/mod v0.7.0 // indirect
golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect golang.org/x/net v0.3.0 // indirect
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f // indirect golang.org/x/sys v0.3.0 // indirect
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 // indirect golang.org/x/term v0.3.0 // indirect
golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 // indirect golang.org/x/time v0.0.0-20220722155302-e5dcc9cfc0b9 // indirect
golang.org/x/tools v0.1.11 // indirect golang.org/x/tools v0.4.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20220722212130-b98a9ff5e252 // indirect google.golang.org/genproto v0.0.0-20220722212130-b98a9ff5e252 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect

@ -591,6 +591,7 @@ github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceT
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI=
github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY=
github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
@ -673,8 +674,8 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE=
github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw= github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU=
github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI= github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek=
github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac=
github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc=
@ -742,6 +743,8 @@ github.com/saracen/solidblock v0.0.0-20190426153529-45df20abab6f/go.mod h1:LyBTu
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM=
github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
@ -774,8 +777,9 @@ github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5J
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
@ -783,9 +787,9 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.194/go.mod h1:7sCQWVkxcsR38nffDW057DRGk8mUjK1Ing/EFOK8s8Y= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.194/go.mod h1:7sCQWVkxcsR38nffDW057DRGk8mUjK1Ing/EFOK8s8Y=
@ -803,8 +807,8 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8= github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8=
github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ulikunitz/xz v0.5.10/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/upper/db/v4 v4.5.4 h1:Hxho4jSx4E+3fxlFgdH4wQTRKygtL0YQPDLQPCUu9wg= github.com/upper/db/v4 v4.6.0 h1:0VmASnqrl/XN8Ehoq++HBgZ4zRD5j3GXygW8FhP0C5I=
github.com/upper/db/v4 v4.5.4/go.mod h1:wyu5BM5Y2gowOt4i6C4LbxftH9QeUF338XVGH4uk+Eo= github.com/upper/db/v4 v4.6.0/go.mod h1:2mnRcPf+RcCXmVcD+o04LYlyu3UuF7ubamJia7CkN6s=
github.com/uptrace/bun v1.1.6 h1:vDJ1Qs6fXock5+q/PSOZZ7vZVZABmWkGlgZDUkJwbfc= github.com/uptrace/bun v1.1.6 h1:vDJ1Qs6fXock5+q/PSOZZ7vZVZABmWkGlgZDUkJwbfc=
github.com/uptrace/bun v1.1.6/go.mod h1:Z2Pd3cRvNKbrYuL6Gp1XGjA9QEYz+rDz5KkEi9MZLnQ= github.com/uptrace/bun v1.1.6/go.mod h1:Z2Pd3cRvNKbrYuL6Gp1XGjA9QEYz+rDz5KkEi9MZLnQ=
github.com/uptrace/bun/dialect/mysqldialect v1.1.6 h1:s0sOiXwszVLzXzsOqBKiL7A7g5GcNmrKOtaQFwQtqHc= github.com/uptrace/bun/dialect/mysqldialect v1.1.6 h1:s0sOiXwszVLzXzsOqBKiL7A7g5GcNmrKOtaQFwQtqHc=
@ -908,8 +912,8 @@ golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220307211146-efcb8507fb70/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220307211146-efcb8507fb70/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
golang.org/x/exp v0.0.0-20181106170214-d68db9428509/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20181106170214-d68db9428509/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -947,6 +951,7 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -996,8 +1001,9 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -1026,6 +1032,7 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -1099,13 +1106,15 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 h1:Q5284mrmYTpACcm+eAKjKJH48BBwSyfJqmmGDTtT8Vc= golang.org/x/term v0.2.0 h1:z85xZCsEl7bi/KwbNADeBYoOP0++7W1ipu+aGnpwzRM=
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -1114,8 +1123,10 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@ -1187,8 +1198,9 @@ golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU=
golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4=
golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

@ -1,5 +1,5 @@
package go_library package go_library
func Version() string { func Version() string {
return "1.0.51" return "1.0.52"
} }

@ -140,6 +140,17 @@ fmt.Println(string(b))
[marshal]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#Marshal [marshal]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#Marshal
## Unstable API
This API does not yet follow the backward compatibility guarantees of this
library. They provide early access to features that may have rough edges or an
API subject to change.
### Parser
Parser is the unstable API that allows iterative parsing of a TOML document at
the AST level. See https://pkg.go.dev/github.com/pelletier/go-toml/v2/unstable.
## Benchmarks ## Benchmarks
Execution time speedup compared to other Go TOML libraries: Execution time speedup compared to other Go TOML libraries:

@ -5,6 +5,8 @@ import (
"math" "math"
"strconv" "strconv"
"time" "time"
"github.com/pelletier/go-toml/v2/unstable"
) )
func parseInteger(b []byte) (int64, error) { func parseInteger(b []byte) (int64, error) {
@ -32,7 +34,7 @@ func parseLocalDate(b []byte) (LocalDate, error) {
var date LocalDate var date LocalDate
if len(b) != 10 || b[4] != '-' || b[7] != '-' { if len(b) != 10 || b[4] != '-' || b[7] != '-' {
return date, newDecodeError(b, "dates are expected to have the format YYYY-MM-DD") return date, unstable.NewParserError(b, "dates are expected to have the format YYYY-MM-DD")
} }
var err error var err error
@ -53,7 +55,7 @@ func parseLocalDate(b []byte) (LocalDate, error) {
} }
if !isValidDate(date.Year, date.Month, date.Day) { if !isValidDate(date.Year, date.Month, date.Day) {
return LocalDate{}, newDecodeError(b, "impossible date") return LocalDate{}, unstable.NewParserError(b, "impossible date")
} }
return date, nil return date, nil
@ -64,7 +66,7 @@ func parseDecimalDigits(b []byte) (int, error) {
for i, c := range b { for i, c := range b {
if c < '0' || c > '9' { if c < '0' || c > '9' {
return 0, newDecodeError(b[i:i+1], "expected digit (0-9)") return 0, unstable.NewParserError(b[i:i+1], "expected digit (0-9)")
} }
v *= 10 v *= 10
v += int(c - '0') v += int(c - '0')
@ -97,7 +99,7 @@ func parseDateTime(b []byte) (time.Time, error) {
} else { } else {
const dateTimeByteLen = 6 const dateTimeByteLen = 6
if len(b) != dateTimeByteLen { if len(b) != dateTimeByteLen {
return time.Time{}, newDecodeError(b, "invalid date-time timezone") return time.Time{}, unstable.NewParserError(b, "invalid date-time timezone")
} }
var direction int var direction int
switch b[0] { switch b[0] {
@ -106,11 +108,11 @@ func parseDateTime(b []byte) (time.Time, error) {
case '+': case '+':
direction = +1 direction = +1
default: default:
return time.Time{}, newDecodeError(b[:1], "invalid timezone offset character") return time.Time{}, unstable.NewParserError(b[:1], "invalid timezone offset character")
} }
if b[3] != ':' { if b[3] != ':' {
return time.Time{}, newDecodeError(b[3:4], "expected a : separator") return time.Time{}, unstable.NewParserError(b[3:4], "expected a : separator")
} }
hours, err := parseDecimalDigits(b[1:3]) hours, err := parseDecimalDigits(b[1:3])
@ -118,7 +120,7 @@ func parseDateTime(b []byte) (time.Time, error) {
return time.Time{}, err return time.Time{}, err
} }
if hours > 23 { if hours > 23 {
return time.Time{}, newDecodeError(b[:1], "invalid timezone offset hours") return time.Time{}, unstable.NewParserError(b[:1], "invalid timezone offset hours")
} }
minutes, err := parseDecimalDigits(b[4:6]) minutes, err := parseDecimalDigits(b[4:6])
@ -126,7 +128,7 @@ func parseDateTime(b []byte) (time.Time, error) {
return time.Time{}, err return time.Time{}, err
} }
if minutes > 59 { if minutes > 59 {
return time.Time{}, newDecodeError(b[:1], "invalid timezone offset minutes") return time.Time{}, unstable.NewParserError(b[:1], "invalid timezone offset minutes")
} }
seconds := direction * (hours*3600 + minutes*60) seconds := direction * (hours*3600 + minutes*60)
@ -139,7 +141,7 @@ func parseDateTime(b []byte) (time.Time, error) {
} }
if len(b) > 0 { if len(b) > 0 {
return time.Time{}, newDecodeError(b, "extra bytes at the end of the timezone") return time.Time{}, unstable.NewParserError(b, "extra bytes at the end of the timezone")
} }
t := time.Date( t := time.Date(
@ -160,7 +162,7 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
const localDateTimeByteMinLen = 11 const localDateTimeByteMinLen = 11
if len(b) < localDateTimeByteMinLen { if len(b) < localDateTimeByteMinLen {
return dt, nil, newDecodeError(b, "local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]") return dt, nil, unstable.NewParserError(b, "local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]")
} }
date, err := parseLocalDate(b[:10]) date, err := parseLocalDate(b[:10])
@ -171,7 +173,7 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
sep := b[10] sep := b[10]
if sep != 'T' && sep != ' ' && sep != 't' { if sep != 'T' && sep != ' ' && sep != 't' {
return dt, nil, newDecodeError(b[10:11], "datetime separator is expected to be T or a space") return dt, nil, unstable.NewParserError(b[10:11], "datetime separator is expected to be T or a space")
} }
t, rest, err := parseLocalTime(b[11:]) t, rest, err := parseLocalTime(b[11:])
@ -195,7 +197,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
// check if b matches to have expected format HH:MM:SS[.NNNNNN] // check if b matches to have expected format HH:MM:SS[.NNNNNN]
const localTimeByteLen = 8 const localTimeByteLen = 8
if len(b) < localTimeByteLen { if len(b) < localTimeByteLen {
return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]") return t, nil, unstable.NewParserError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]")
} }
var err error var err error
@ -206,10 +208,10 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
} }
if t.Hour > 23 { if t.Hour > 23 {
return t, nil, newDecodeError(b[0:2], "hour cannot be greater 23") return t, nil, unstable.NewParserError(b[0:2], "hour cannot be greater 23")
} }
if b[2] != ':' { if b[2] != ':' {
return t, nil, newDecodeError(b[2:3], "expecting colon between hours and minutes") return t, nil, unstable.NewParserError(b[2:3], "expecting colon between hours and minutes")
} }
t.Minute, err = parseDecimalDigits(b[3:5]) t.Minute, err = parseDecimalDigits(b[3:5])
@ -217,10 +219,10 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, nil, err return t, nil, err
} }
if t.Minute > 59 { if t.Minute > 59 {
return t, nil, newDecodeError(b[3:5], "minutes cannot be greater 59") return t, nil, unstable.NewParserError(b[3:5], "minutes cannot be greater 59")
} }
if b[5] != ':' { if b[5] != ':' {
return t, nil, newDecodeError(b[5:6], "expecting colon between minutes and seconds") return t, nil, unstable.NewParserError(b[5:6], "expecting colon between minutes and seconds")
} }
t.Second, err = parseDecimalDigits(b[6:8]) t.Second, err = parseDecimalDigits(b[6:8])
@ -229,7 +231,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
} }
if t.Second > 60 { if t.Second > 60 {
return t, nil, newDecodeError(b[6:8], "seconds cannot be greater 60") return t, nil, unstable.NewParserError(b[6:8], "seconds cannot be greater 60")
} }
b = b[8:] b = b[8:]
@ -242,7 +244,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
for i, c := range b[1:] { for i, c := range b[1:] {
if !isDigit(c) { if !isDigit(c) {
if i == 0 { if i == 0 {
return t, nil, newDecodeError(b[0:1], "need at least one digit after fraction point") return t, nil, unstable.NewParserError(b[0:1], "need at least one digit after fraction point")
} }
break break
} }
@ -266,7 +268,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
} }
if precision == 0 { if precision == 0 {
return t, nil, newDecodeError(b[:1], "nanoseconds need at least one digit") return t, nil, unstable.NewParserError(b[:1], "nanoseconds need at least one digit")
} }
t.Nanosecond = frac * nspow[precision] t.Nanosecond = frac * nspow[precision]
@ -289,24 +291,24 @@ func parseFloat(b []byte) (float64, error) {
} }
if cleaned[0] == '.' { if cleaned[0] == '.' {
return 0, newDecodeError(b, "float cannot start with a dot") return 0, unstable.NewParserError(b, "float cannot start with a dot")
} }
if cleaned[len(cleaned)-1] == '.' { if cleaned[len(cleaned)-1] == '.' {
return 0, newDecodeError(b, "float cannot end with a dot") return 0, unstable.NewParserError(b, "float cannot end with a dot")
} }
dotAlreadySeen := false dotAlreadySeen := false
for i, c := range cleaned { for i, c := range cleaned {
if c == '.' { if c == '.' {
if dotAlreadySeen { if dotAlreadySeen {
return 0, newDecodeError(b[i:i+1], "float can have at most one decimal point") return 0, unstable.NewParserError(b[i:i+1], "float can have at most one decimal point")
} }
if !isDigit(cleaned[i-1]) { if !isDigit(cleaned[i-1]) {
return 0, newDecodeError(b[i-1:i+1], "float decimal point must be preceded by a digit") return 0, unstable.NewParserError(b[i-1:i+1], "float decimal point must be preceded by a digit")
} }
if !isDigit(cleaned[i+1]) { if !isDigit(cleaned[i+1]) {
return 0, newDecodeError(b[i:i+2], "float decimal point must be followed by a digit") return 0, unstable.NewParserError(b[i:i+2], "float decimal point must be followed by a digit")
} }
dotAlreadySeen = true dotAlreadySeen = true
} }
@ -317,12 +319,12 @@ func parseFloat(b []byte) (float64, error) {
start = 1 start = 1
} }
if cleaned[start] == '0' && isDigit(cleaned[start+1]) { if cleaned[start] == '0' && isDigit(cleaned[start+1]) {
return 0, newDecodeError(b, "float integer part cannot have leading zeroes") return 0, unstable.NewParserError(b, "float integer part cannot have leading zeroes")
} }
f, err := strconv.ParseFloat(string(cleaned), 64) f, err := strconv.ParseFloat(string(cleaned), 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "unable to parse float: %w", err) return 0, unstable.NewParserError(b, "unable to parse float: %w", err)
} }
return f, nil return f, nil
@ -336,7 +338,7 @@ func parseIntHex(b []byte) (int64, error) {
i, err := strconv.ParseInt(string(cleaned), 16, 64) i, err := strconv.ParseInt(string(cleaned), 16, 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "couldn't parse hexadecimal number: %w", err) return 0, unstable.NewParserError(b, "couldn't parse hexadecimal number: %w", err)
} }
return i, nil return i, nil
@ -350,7 +352,7 @@ func parseIntOct(b []byte) (int64, error) {
i, err := strconv.ParseInt(string(cleaned), 8, 64) i, err := strconv.ParseInt(string(cleaned), 8, 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "couldn't parse octal number: %w", err) return 0, unstable.NewParserError(b, "couldn't parse octal number: %w", err)
} }
return i, nil return i, nil
@ -364,7 +366,7 @@ func parseIntBin(b []byte) (int64, error) {
i, err := strconv.ParseInt(string(cleaned), 2, 64) i, err := strconv.ParseInt(string(cleaned), 2, 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "couldn't parse binary number: %w", err) return 0, unstable.NewParserError(b, "couldn't parse binary number: %w", err)
} }
return i, nil return i, nil
@ -387,12 +389,12 @@ func parseIntDec(b []byte) (int64, error) {
} }
if len(cleaned) > startIdx+1 && cleaned[startIdx] == '0' { if len(cleaned) > startIdx+1 && cleaned[startIdx] == '0' {
return 0, newDecodeError(b, "leading zero not allowed on decimal number") return 0, unstable.NewParserError(b, "leading zero not allowed on decimal number")
} }
i, err := strconv.ParseInt(string(cleaned), 10, 64) i, err := strconv.ParseInt(string(cleaned), 10, 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "couldn't parse decimal number: %w", err) return 0, unstable.NewParserError(b, "couldn't parse decimal number: %w", err)
} }
return i, nil return i, nil
@ -409,11 +411,11 @@ func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) {
} }
if b[start] == '_' { if b[start] == '_' {
return nil, newDecodeError(b[start:start+1], "number cannot start with underscore") return nil, unstable.NewParserError(b[start:start+1], "number cannot start with underscore")
} }
if b[len(b)-1] == '_' { if b[len(b)-1] == '_' {
return nil, newDecodeError(b[len(b)-1:], "number cannot end with underscore") return nil, unstable.NewParserError(b[len(b)-1:], "number cannot end with underscore")
} }
// fast path // fast path
@ -435,7 +437,7 @@ func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) {
c := b[i] c := b[i]
if c == '_' { if c == '_' {
if !before { if !before {
return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores") return nil, unstable.NewParserError(b[i-1:i+1], "number must have at least one digit between underscores")
} }
before = false before = false
} else { } else {
@ -449,11 +451,11 @@ func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) {
func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) { func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) {
if b[0] == '_' { if b[0] == '_' {
return nil, newDecodeError(b[0:1], "number cannot start with underscore") return nil, unstable.NewParserError(b[0:1], "number cannot start with underscore")
} }
if b[len(b)-1] == '_' { if b[len(b)-1] == '_' {
return nil, newDecodeError(b[len(b)-1:], "number cannot end with underscore") return nil, unstable.NewParserError(b[len(b)-1:], "number cannot end with underscore")
} }
// fast path // fast path
@ -476,10 +478,10 @@ func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) {
switch c { switch c {
case '_': case '_':
if !before { if !before {
return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores") return nil, unstable.NewParserError(b[i-1:i+1], "number must have at least one digit between underscores")
} }
if i < len(b)-1 && (b[i+1] == 'e' || b[i+1] == 'E') { if i < len(b)-1 && (b[i+1] == 'e' || b[i+1] == 'E') {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore before exponent") return nil, unstable.NewParserError(b[i+1:i+2], "cannot have underscore before exponent")
} }
before = false before = false
case '+', '-': case '+', '-':
@ -488,15 +490,15 @@ func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) {
before = false before = false
case 'e', 'E': case 'e', 'E':
if i < len(b)-1 && b[i+1] == '_' { if i < len(b)-1 && b[i+1] == '_' {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore after exponent") return nil, unstable.NewParserError(b[i+1:i+2], "cannot have underscore after exponent")
} }
cleaned = append(cleaned, c) cleaned = append(cleaned, c)
case '.': case '.':
if i < len(b)-1 && b[i+1] == '_' { if i < len(b)-1 && b[i+1] == '_' {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore after decimal point") return nil, unstable.NewParserError(b[i+1:i+2], "cannot have underscore after decimal point")
} }
if i > 0 && b[i-1] == '_' { if i > 0 && b[i-1] == '_' {
return nil, newDecodeError(b[i-1:i], "cannot have underscore before decimal point") return nil, unstable.NewParserError(b[i-1:i], "cannot have underscore before decimal point")
} }
cleaned = append(cleaned, c) cleaned = append(cleaned, c)
default: default:
@ -542,3 +544,7 @@ func daysIn(m int, year int) int {
func isLeap(year int) bool { func isLeap(year int) bool {
return year%4 == 0 && (year%100 != 0 || year%400 == 0) return year%4 == 0 && (year%100 != 0 || year%400 == 0)
} }
func isDigit(r byte) bool {
return r >= '0' && r <= '9'
}

@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/unstable"
) )
// DecodeError represents an error encountered during the parsing or decoding // DecodeError represents an error encountered during the parsing or decoding
@ -55,25 +56,6 @@ func (s *StrictMissingError) String() string {
type Key []string type Key []string
// internal version of DecodeError that is used as the base to create a
// DecodeError with full context.
type decodeError struct {
highlight []byte
message string
key Key // optional
}
func (de *decodeError) Error() string {
return de.message
}
func newDecodeError(highlight []byte, format string, args ...interface{}) error {
return &decodeError{
highlight: highlight,
message: fmt.Errorf(format, args...).Error(),
}
}
// Error returns the error message contained in the DecodeError. // Error returns the error message contained in the DecodeError.
func (e *DecodeError) Error() string { func (e *DecodeError) Error() string {
return "toml: " + e.message return "toml: " + e.message
@ -103,13 +85,14 @@ func (e *DecodeError) Key() Key {
// //
// The function copies all bytes used in DecodeError, so that document and // The function copies all bytes used in DecodeError, so that document and
// highlight can be freely deallocated. // highlight can be freely deallocated.
//
//nolint:funlen //nolint:funlen
func wrapDecodeError(document []byte, de *decodeError) *DecodeError { func wrapDecodeError(document []byte, de *unstable.ParserError) *DecodeError {
offset := danger.SubsliceOffset(document, de.highlight) offset := danger.SubsliceOffset(document, de.Highlight)
errMessage := de.Error() errMessage := de.Error()
errLine, errColumn := positionAtEnd(document[:offset]) errLine, errColumn := positionAtEnd(document[:offset])
before, after := linesOfContext(document, de.highlight, offset, 3) before, after := linesOfContext(document, de.Highlight, offset, 3)
var buf strings.Builder var buf strings.Builder
@ -139,7 +122,7 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
buf.Write(before[0]) buf.Write(before[0])
} }
buf.Write(de.highlight) buf.Write(de.Highlight)
if len(after) > 0 { if len(after) > 0 {
buf.Write(after[0]) buf.Write(after[0])
@ -157,7 +140,7 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
buf.WriteString(strings.Repeat(" ", len(before[0]))) buf.WriteString(strings.Repeat(" ", len(before[0])))
} }
buf.WriteString(strings.Repeat("~", len(de.highlight))) buf.WriteString(strings.Repeat("~", len(de.Highlight)))
if len(errMessage) > 0 { if len(errMessage) > 0 {
buf.WriteString(" ") buf.WriteString(" ")
@ -182,7 +165,7 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
message: errMessage, message: errMessage,
line: errLine, line: errLine,
column: errColumn, column: errColumn,
key: de.key, key: de.Key,
human: buf.String(), human: buf.String(),
} }
} }

@ -1,51 +0,0 @@
package ast
type Reference int
const InvalidReference Reference = -1
func (r Reference) Valid() bool {
return r != InvalidReference
}
type Builder struct {
tree Root
lastIdx int
}
func (b *Builder) Tree() *Root {
return &b.tree
}
func (b *Builder) NodeAt(ref Reference) *Node {
return b.tree.at(ref)
}
func (b *Builder) Reset() {
b.tree.nodes = b.tree.nodes[:0]
b.lastIdx = 0
}
func (b *Builder) Push(n Node) Reference {
b.lastIdx = len(b.tree.nodes)
b.tree.nodes = append(b.tree.nodes, n)
return Reference(b.lastIdx)
}
func (b *Builder) PushAndChain(n Node) Reference {
newIdx := len(b.tree.nodes)
b.tree.nodes = append(b.tree.nodes, n)
if b.lastIdx >= 0 {
b.tree.nodes[b.lastIdx].next = newIdx - b.lastIdx
}
b.lastIdx = newIdx
return Reference(b.lastIdx)
}
func (b *Builder) AttachChild(parent Reference, child Reference) {
b.tree.nodes[parent].child = int(child) - int(parent)
}
func (b *Builder) Chain(from Reference, to Reference) {
b.tree.nodes[from].next = int(to) - int(from)
}

@ -0,0 +1,42 @@
package characters
var invalidAsciiTable = [256]bool{
0x00: true,
0x01: true,
0x02: true,
0x03: true,
0x04: true,
0x05: true,
0x06: true,
0x07: true,
0x08: true,
// 0x09 TAB
// 0x0A LF
0x0B: true,
0x0C: true,
// 0x0D CR
0x0E: true,
0x0F: true,
0x10: true,
0x11: true,
0x12: true,
0x13: true,
0x14: true,
0x15: true,
0x16: true,
0x17: true,
0x18: true,
0x19: true,
0x1A: true,
0x1B: true,
0x1C: true,
0x1D: true,
0x1E: true,
0x1F: true,
// 0x20 - 0x7E Printable ASCII characters
0x7F: true,
}
func InvalidAscii(b byte) bool {
return invalidAsciiTable[b]
}

@ -1,4 +1,4 @@
package toml package characters
import ( import (
"unicode/utf8" "unicode/utf8"
@ -32,7 +32,7 @@ func (u utf8Err) Zero() bool {
// 0x9 => tab, ok // 0x9 => tab, ok
// 0xA - 0x1F => invalid // 0xA - 0x1F => invalid
// 0x7F => invalid // 0x7F => invalid
func utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) { func Utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
// Fast path. Check for and skip 8 bytes of ASCII characters per iteration. // Fast path. Check for and skip 8 bytes of ASCII characters per iteration.
offset := 0 offset := 0
for len(p) >= 8 { for len(p) >= 8 {
@ -48,7 +48,7 @@ func utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
} }
for i, b := range p[:8] { for i, b := range p[:8] {
if invalidAscii(b) { if InvalidAscii(b) {
err.Index = offset + i err.Index = offset + i
err.Size = 1 err.Size = 1
return return
@ -62,7 +62,7 @@ func utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
for i := 0; i < n; { for i := 0; i < n; {
pi := p[i] pi := p[i]
if pi < utf8.RuneSelf { if pi < utf8.RuneSelf {
if invalidAscii(pi) { if InvalidAscii(pi) {
err.Index = offset + i err.Index = offset + i
err.Size = 1 err.Size = 1
return return
@ -106,11 +106,11 @@ func utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
} }
// Return the size of the next rune if valid, 0 otherwise. // Return the size of the next rune if valid, 0 otherwise.
func utf8ValidNext(p []byte) int { func Utf8ValidNext(p []byte) int {
c := p[0] c := p[0]
if c < utf8.RuneSelf { if c < utf8.RuneSelf {
if invalidAscii(c) { if InvalidAscii(c) {
return 0 return 0
} }
return 1 return 1
@ -140,47 +140,6 @@ func utf8ValidNext(p []byte) int {
return size return size
} }
var invalidAsciiTable = [256]bool{
0x00: true,
0x01: true,
0x02: true,
0x03: true,
0x04: true,
0x05: true,
0x06: true,
0x07: true,
0x08: true,
// 0x09 TAB
// 0x0A LF
0x0B: true,
0x0C: true,
// 0x0D CR
0x0E: true,
0x0F: true,
0x10: true,
0x11: true,
0x12: true,
0x13: true,
0x14: true,
0x15: true,
0x16: true,
0x17: true,
0x18: true,
0x19: true,
0x1A: true,
0x1B: true,
0x1C: true,
0x1D: true,
0x1E: true,
0x1F: true,
// 0x20 - 0x7E Printable ASCII characters
0x7F: true,
}
func invalidAscii(b byte) bool {
return invalidAsciiTable[b]
}
// acceptRange gives the range of valid values for the second byte in a UTF-8 // acceptRange gives the range of valid values for the second byte in a UTF-8
// sequence. // sequence.
type acceptRange struct { type acceptRange struct {

@ -1,8 +1,6 @@
package tracker package tracker
import ( import "github.com/pelletier/go-toml/v2/unstable"
"github.com/pelletier/go-toml/v2/internal/ast"
)
// KeyTracker is a tracker that keeps track of the current Key as the AST is // KeyTracker is a tracker that keeps track of the current Key as the AST is
// walked. // walked.
@ -11,19 +9,19 @@ type KeyTracker struct {
} }
// UpdateTable sets the state of the tracker with the AST table node. // UpdateTable sets the state of the tracker with the AST table node.
func (t *KeyTracker) UpdateTable(node *ast.Node) { func (t *KeyTracker) UpdateTable(node *unstable.Node) {
t.reset() t.reset()
t.Push(node) t.Push(node)
} }
// UpdateArrayTable sets the state of the tracker with the AST array table node. // UpdateArrayTable sets the state of the tracker with the AST array table node.
func (t *KeyTracker) UpdateArrayTable(node *ast.Node) { func (t *KeyTracker) UpdateArrayTable(node *unstable.Node) {
t.reset() t.reset()
t.Push(node) t.Push(node)
} }
// Push the given key on the stack. // Push the given key on the stack.
func (t *KeyTracker) Push(node *ast.Node) { func (t *KeyTracker) Push(node *unstable.Node) {
it := node.Key() it := node.Key()
for it.Next() { for it.Next() {
t.k = append(t.k, string(it.Node().Data)) t.k = append(t.k, string(it.Node().Data))
@ -31,7 +29,7 @@ func (t *KeyTracker) Push(node *ast.Node) {
} }
// Pop key from stack. // Pop key from stack.
func (t *KeyTracker) Pop(node *ast.Node) { func (t *KeyTracker) Pop(node *unstable.Node) {
it := node.Key() it := node.Key()
for it.Next() { for it.Next() {
t.k = t.k[:len(t.k)-1] t.k = t.k[:len(t.k)-1]

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/unstable"
) )
type keyKind uint8 type keyKind uint8
@ -150,23 +150,23 @@ func (s *SeenTracker) setExplicitFlag(parentIdx int) {
// CheckExpression takes a top-level node and checks that it does not contain // CheckExpression takes a top-level node and checks that it does not contain
// keys that have been seen in previous calls, and validates that types are // keys that have been seen in previous calls, and validates that types are
// consistent. // consistent.
func (s *SeenTracker) CheckExpression(node *ast.Node) error { func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
if s.entries == nil { if s.entries == nil {
s.reset() s.reset()
} }
switch node.Kind { switch node.Kind {
case ast.KeyValue: case unstable.KeyValue:
return s.checkKeyValue(node) return s.checkKeyValue(node)
case ast.Table: case unstable.Table:
return s.checkTable(node) return s.checkTable(node)
case ast.ArrayTable: case unstable.ArrayTable:
return s.checkArrayTable(node) return s.checkArrayTable(node)
default: default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
} }
} }
func (s *SeenTracker) checkTable(node *ast.Node) error { func (s *SeenTracker) checkTable(node *unstable.Node) error {
if s.currentIdx >= 0 { if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx) s.setExplicitFlag(s.currentIdx)
} }
@ -219,7 +219,7 @@ func (s *SeenTracker) checkTable(node *ast.Node) error {
return nil return nil
} }
func (s *SeenTracker) checkArrayTable(node *ast.Node) error { func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
if s.currentIdx >= 0 { if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx) s.setExplicitFlag(s.currentIdx)
} }
@ -267,7 +267,7 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
return nil return nil
} }
func (s *SeenTracker) checkKeyValue(node *ast.Node) error { func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
parentIdx := s.currentIdx parentIdx := s.currentIdx
it := node.Key() it := node.Key()
@ -297,26 +297,26 @@ func (s *SeenTracker) checkKeyValue(node *ast.Node) error {
value := node.Value() value := node.Value()
switch value.Kind { switch value.Kind {
case ast.InlineTable: case unstable.InlineTable:
return s.checkInlineTable(value) return s.checkInlineTable(value)
case ast.Array: case unstable.Array:
return s.checkArray(value) return s.checkArray(value)
} }
return nil return nil
} }
func (s *SeenTracker) checkArray(node *ast.Node) error { func (s *SeenTracker) checkArray(node *unstable.Node) error {
it := node.Children() it := node.Children()
for it.Next() { for it.Next() {
n := it.Node() n := it.Node()
switch n.Kind { switch n.Kind {
case ast.InlineTable: case unstable.InlineTable:
err := s.checkInlineTable(n) err := s.checkInlineTable(n)
if err != nil { if err != nil {
return err return err
} }
case ast.Array: case unstable.Array:
err := s.checkArray(n) err := s.checkArray(n)
if err != nil { if err != nil {
return err return err
@ -326,7 +326,7 @@ func (s *SeenTracker) checkArray(node *ast.Node) error {
return nil return nil
} }
func (s *SeenTracker) checkInlineTable(node *ast.Node) error { func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
if pool.New == nil { if pool.New == nil {
pool.New = func() interface{} { pool.New = func() interface{} {
return &SeenTracker{} return &SeenTracker{}

@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"strings" "strings"
"time" "time"
"github.com/pelletier/go-toml/v2/unstable"
) )
// LocalDate represents a calendar day in no specific timezone. // LocalDate represents a calendar day in no specific timezone.
@ -75,7 +77,7 @@ func (d LocalTime) MarshalText() ([]byte, error) {
func (d *LocalTime) UnmarshalText(b []byte) error { func (d *LocalTime) UnmarshalText(b []byte) error {
res, left, err := parseLocalTime(b) res, left, err := parseLocalTime(b)
if err == nil && len(left) != 0 { if err == nil && len(left) != 0 {
err = newDecodeError(left, "extra characters") err = unstable.NewParserError(left, "extra characters")
} }
if err != nil { if err != nil {
return err return err
@ -109,7 +111,7 @@ func (d LocalDateTime) MarshalText() ([]byte, error) {
func (d *LocalDateTime) UnmarshalText(data []byte) error { func (d *LocalDateTime) UnmarshalText(data []byte) error {
res, left, err := parseLocalDateTime(data) res, left, err := parseLocalDateTime(data)
if err == nil && len(left) != 0 { if err == nil && len(left) != 0 {
err = newDecodeError(left, "extra characters") err = unstable.NewParserError(left, "extra characters")
} }
if err != nil { if err != nil {
return err return err

@ -12,6 +12,8 @@ import (
"strings" "strings"
"time" "time"
"unicode" "unicode"
"github.com/pelletier/go-toml/v2/internal/characters"
) )
// Marshal serializes a Go value as a TOML document. // Marshal serializes a Go value as a TOML document.
@ -54,7 +56,7 @@ func NewEncoder(w io.Writer) *Encoder {
// This behavior can be controlled on an individual struct field basis with the // This behavior can be controlled on an individual struct field basis with the
// inline tag: // inline tag:
// //
// MyField `inline:"true"` // MyField `toml:",inline"`
func (enc *Encoder) SetTablesInline(inline bool) *Encoder { func (enc *Encoder) SetTablesInline(inline bool) *Encoder {
enc.tablesInline = inline enc.tablesInline = inline
return enc return enc
@ -89,7 +91,7 @@ func (enc *Encoder) SetIndentTables(indent bool) *Encoder {
// //
// If v cannot be represented to TOML it returns an error. // If v cannot be represented to TOML it returns an error.
// //
// Encoding rules // # Encoding rules
// //
// A top level slice containing only maps or structs is encoded as [[table // A top level slice containing only maps or structs is encoded as [[table
// array]]. // array]].
@ -117,7 +119,20 @@ func (enc *Encoder) SetIndentTables(indent bool) *Encoder {
// When encoding structs, fields are encoded in order of definition, with their // When encoding structs, fields are encoded in order of definition, with their
// exact name. // exact name.
// //
// Struct tags // Tables and array tables are separated by empty lines. However, consecutive
// subtables definitions are not. For example:
//
// [top1]
//
// [top2]
// [top2.child1]
//
// [[array]]
//
// [[array]]
// [array.child2]
//
// # Struct tags
// //
// The encoding of each public struct field can be customized by the format // The encoding of each public struct field can be customized by the format
// string in the "toml" key of the struct field's tag. This follows // string in the "toml" key of the struct field's tag. This follows
@ -333,13 +348,13 @@ func isNil(v reflect.Value) bool {
} }
} }
func shouldOmitEmpty(options valueOptions, v reflect.Value) bool {
return options.omitempty && isEmptyValue(v)
}
func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v reflect.Value) ([]byte, error) { func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v reflect.Value) ([]byte, error) {
var err error var err error
if (ctx.options.omitempty || options.omitempty) && isEmptyValue(v) {
return b, nil
}
if !ctx.inline { if !ctx.inline {
b = enc.encodeComment(ctx.indent, options.comment, b) b = enc.encodeComment(ctx.indent, options.comment, b)
} }
@ -365,6 +380,8 @@ func (enc *Encoder) encodeKv(b []byte, ctx encoderCtx, options valueOptions, v r
func isEmptyValue(v reflect.Value) bool { func isEmptyValue(v reflect.Value) bool {
switch v.Kind() { switch v.Kind() {
case reflect.Struct:
return isEmptyStruct(v)
case reflect.Array, reflect.Map, reflect.Slice, reflect.String: case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0 return v.Len() == 0
case reflect.Bool: case reflect.Bool:
@ -381,6 +398,34 @@ func isEmptyValue(v reflect.Value) bool {
return false return false
} }
func isEmptyStruct(v reflect.Value) bool {
// TODO: merge with walkStruct and cache.
typ := v.Type()
for i := 0; i < typ.NumField(); i++ {
fieldType := typ.Field(i)
// only consider exported fields
if fieldType.PkgPath != "" {
continue
}
tag := fieldType.Tag.Get("toml")
// special field name to skip field
if tag == "-" {
continue
}
f := v.Field(i)
if !isEmptyValue(f) {
return false
}
}
return true
}
const literalQuote = '\'' const literalQuote = '\''
func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byte { func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byte {
@ -394,7 +439,7 @@ func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byt
func needsQuoting(v string) bool { func needsQuoting(v string) bool {
// TODO: vectorize // TODO: vectorize
for _, b := range []byte(v) { for _, b := range []byte(v) {
if b == '\'' || b == '\r' || b == '\n' || invalidAscii(b) { if b == '\'' || b == '\r' || b == '\n' || characters.InvalidAscii(b) {
return true return true
} }
} }
@ -410,7 +455,6 @@ func (enc *Encoder) encodeLiteralString(b []byte, v string) []byte {
return b return b
} }
//nolint:cyclop
func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byte { func (enc *Encoder) encodeQuotedString(multiline bool, b []byte, v string) []byte {
stringQuote := `"` stringQuote := `"`
@ -757,7 +801,13 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
} }
ctx.skipTableHeader = false ctx.skipTableHeader = false
hasNonEmptyKV := false
for _, kv := range t.kvs { for _, kv := range t.kvs {
if shouldOmitEmpty(kv.Options, kv.Value) {
continue
}
hasNonEmptyKV = true
ctx.setKey(kv.Key) ctx.setKey(kv.Key)
b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value) b, err = enc.encodeKv(b, ctx, kv.Options, kv.Value)
@ -768,7 +818,20 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
b = append(b, '\n') b = append(b, '\n')
} }
first := true
for _, table := range t.tables { for _, table := range t.tables {
if shouldOmitEmpty(table.Options, table.Value) {
continue
}
if first {
first = false
if hasNonEmptyKV {
b = append(b, '\n')
}
} else {
b = append(b, "\n"...)
}
ctx.setKey(table.Key) ctx.setKey(table.Key)
ctx.options = table.Options ctx.options = table.Options
@ -777,8 +840,6 @@ func (enc *Encoder) encodeTable(b []byte, ctx encoderCtx, t table) ([]byte, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
b = append(b, '\n')
} }
return b, nil return b, nil
@ -791,6 +852,10 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte
first := true first := true
for _, kv := range t.kvs { for _, kv := range t.kvs {
if shouldOmitEmpty(kv.Options, kv.Value) {
continue
}
if first { if first {
first = false first = false
} else { } else {
@ -806,7 +871,7 @@ func (enc *Encoder) encodeTableInline(b []byte, ctx encoderCtx, t table) ([]byte
} }
if len(t.tables) > 0 { if len(t.tables) > 0 {
panic("inline table cannot contain nested tables, online key-values") panic("inline table cannot contain nested tables, only key-values")
} }
b = append(b, "}"...) b = append(b, "}"...)
@ -905,6 +970,10 @@ func (enc *Encoder) encodeSliceAsArrayTable(b []byte, ctx encoderCtx, v reflect.
b = enc.encodeComment(ctx.indent, ctx.options.comment, b) b = enc.encodeComment(ctx.indent, ctx.options.comment, b)
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
if i != 0 {
b = append(b, "\n"...)
}
b = append(b, scratch...) b = append(b, scratch...)
var err error var err error

@ -1,9 +1,9 @@
package toml package toml
import ( import (
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/unstable"
) )
type strict struct { type strict struct {
@ -12,10 +12,10 @@ type strict struct {
// Tracks the current key being processed. // Tracks the current key being processed.
key tracker.KeyTracker key tracker.KeyTracker
missing []decodeError missing []unstable.ParserError
} }
func (s *strict) EnterTable(node *ast.Node) { func (s *strict) EnterTable(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
@ -23,7 +23,7 @@ func (s *strict) EnterTable(node *ast.Node) {
s.key.UpdateTable(node) s.key.UpdateTable(node)
} }
func (s *strict) EnterArrayTable(node *ast.Node) { func (s *strict) EnterArrayTable(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
@ -31,7 +31,7 @@ func (s *strict) EnterArrayTable(node *ast.Node) {
s.key.UpdateArrayTable(node) s.key.UpdateArrayTable(node)
} }
func (s *strict) EnterKeyValue(node *ast.Node) { func (s *strict) EnterKeyValue(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
@ -39,7 +39,7 @@ func (s *strict) EnterKeyValue(node *ast.Node) {
s.key.Push(node) s.key.Push(node)
} }
func (s *strict) ExitKeyValue(node *ast.Node) { func (s *strict) ExitKeyValue(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
@ -47,27 +47,27 @@ func (s *strict) ExitKeyValue(node *ast.Node) {
s.key.Pop(node) s.key.Pop(node)
} }
func (s *strict) MissingTable(node *ast.Node) { func (s *strict) MissingTable(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
s.missing = append(s.missing, decodeError{ s.missing = append(s.missing, unstable.ParserError{
highlight: keyLocation(node), Highlight: keyLocation(node),
message: "missing table", Message: "missing table",
key: s.key.Key(), Key: s.key.Key(),
}) })
} }
func (s *strict) MissingField(node *ast.Node) { func (s *strict) MissingField(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
s.missing = append(s.missing, decodeError{ s.missing = append(s.missing, unstable.ParserError{
highlight: keyLocation(node), Highlight: keyLocation(node),
message: "missing field", Message: "missing field",
key: s.key.Key(), Key: s.key.Key(),
}) })
} }
@ -88,7 +88,7 @@ func (s *strict) Error(doc []byte) error {
return err return err
} }
func keyLocation(node *ast.Node) []byte { func keyLocation(node *unstable.Node) []byte {
k := node.Key() k := node.Key()
hasOne := k.Next() hasOne := k.Next()

@ -6,9 +6,9 @@ import (
"time" "time"
) )
var timeType = reflect.TypeOf(time.Time{}) var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() var textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}(nil))
var sliceInterfaceType = reflect.TypeOf([]interface{}{}) var sliceInterfaceType = reflect.TypeOf([]interface{}(nil))
var stringType = reflect.TypeOf("") var stringType = reflect.TypeOf("")

@ -12,16 +12,16 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/unstable"
) )
// Unmarshal deserializes a TOML document into a Go value. // Unmarshal deserializes a TOML document into a Go value.
// //
// It is a shortcut for Decoder.Decode() with the default options. // It is a shortcut for Decoder.Decode() with the default options.
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
p := parser{} p := unstable.Parser{}
p.Reset(data) p.Reset(data)
d := decoder{p: &p} d := decoder{p: &p}
@ -79,7 +79,7 @@ func (d *Decoder) DisallowUnknownFields() *Decoder {
// strict mode and a field is missing, a `toml.StrictMissingError` is // strict mode and a field is missing, a `toml.StrictMissingError` is
// returned. In any other case, this function returns a standard Go error. // returned. In any other case, this function returns a standard Go error.
// //
// Type mapping // # Type mapping
// //
// List of supported TOML types and their associated accepted Go types: // List of supported TOML types and their associated accepted Go types:
// //
@ -101,7 +101,7 @@ func (d *Decoder) Decode(v interface{}) error {
return fmt.Errorf("toml: %w", err) return fmt.Errorf("toml: %w", err)
} }
p := parser{} p := unstable.Parser{}
p.Reset(b) p.Reset(b)
dec := decoder{ dec := decoder{
p: &p, p: &p,
@ -115,7 +115,7 @@ func (d *Decoder) Decode(v interface{}) error {
type decoder struct { type decoder struct {
// Which parser instance in use for this decoding session. // Which parser instance in use for this decoding session.
p *parser p *unstable.Parser
// Flag indicating that the current expression is stashed. // Flag indicating that the current expression is stashed.
// If set to true, calling nextExpr will not actually pull a new expression // If set to true, calling nextExpr will not actually pull a new expression
@ -123,7 +123,7 @@ type decoder struct {
stashedExpr bool stashedExpr bool
// Skip expressions until a table is found. This is set to true when a // Skip expressions until a table is found. This is set to true when a
// table could not be create (missing field in map), so all KV expressions // table could not be created (missing field in map), so all KV expressions
// need to be skipped. // need to be skipped.
skipUntilTable bool skipUntilTable bool
@ -157,7 +157,7 @@ func (d *decoder) typeMismatchError(toml string, target reflect.Type) error {
return fmt.Errorf("toml: cannot decode TOML %s into a Go value of type %s", toml, target) return fmt.Errorf("toml: cannot decode TOML %s into a Go value of type %s", toml, target)
} }
func (d *decoder) expr() *ast.Node { func (d *decoder) expr() *unstable.Node {
return d.p.Expression() return d.p.Expression()
} }
@ -208,12 +208,12 @@ func (d *decoder) FromParser(v interface{}) error {
err := d.fromParser(r) err := d.fromParser(r)
if err == nil { if err == nil {
return d.strict.Error(d.p.data) return d.strict.Error(d.p.Data())
} }
var e *decodeError var e *unstable.ParserError
if errors.As(err, &e) { if errors.As(err, &e) {
return wrapDecodeError(d.p.data, e) return wrapDecodeError(d.p.Data(), e)
} }
return err return err
@ -234,16 +234,16 @@ func (d *decoder) fromParser(root reflect.Value) error {
Rules for the unmarshal code: Rules for the unmarshal code:
- The stack is used to keep track of which values need to be set where. - The stack is used to keep track of which values need to be set where.
- handle* functions <=> switch on a given ast.Kind. - handle* functions <=> switch on a given unstable.Kind.
- unmarshalX* functions need to unmarshal a node of kind X. - unmarshalX* functions need to unmarshal a node of kind X.
- An "object" is either a struct or a map. - An "object" is either a struct or a map.
*/ */
func (d *decoder) handleRootExpression(expr *ast.Node, v reflect.Value) error { func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error {
var x reflect.Value var x reflect.Value
var err error var err error
if !(d.skipUntilTable && expr.Kind == ast.KeyValue) { if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) {
err = d.seen.CheckExpression(expr) err = d.seen.CheckExpression(expr)
if err != nil { if err != nil {
return err return err
@ -251,16 +251,16 @@ func (d *decoder) handleRootExpression(expr *ast.Node, v reflect.Value) error {
} }
switch expr.Kind { switch expr.Kind {
case ast.KeyValue: case unstable.KeyValue:
if d.skipUntilTable { if d.skipUntilTable {
return nil return nil
} }
x, err = d.handleKeyValue(expr, v) x, err = d.handleKeyValue(expr, v)
case ast.Table: case unstable.Table:
d.skipUntilTable = false d.skipUntilTable = false
d.strict.EnterTable(expr) d.strict.EnterTable(expr)
x, err = d.handleTable(expr.Key(), v) x, err = d.handleTable(expr.Key(), v)
case ast.ArrayTable: case unstable.ArrayTable:
d.skipUntilTable = false d.skipUntilTable = false
d.strict.EnterArrayTable(expr) d.strict.EnterArrayTable(expr)
x, err = d.handleArrayTable(expr.Key(), v) x, err = d.handleArrayTable(expr.Key(), v)
@ -269,7 +269,7 @@ func (d *decoder) handleRootExpression(expr *ast.Node, v reflect.Value) error {
} }
if d.skipUntilTable { if d.skipUntilTable {
if expr.Kind == ast.Table || expr.Kind == ast.ArrayTable { if expr.Kind == unstable.Table || expr.Kind == unstable.ArrayTable {
d.strict.MissingTable(expr) d.strict.MissingTable(expr)
} }
} else if err == nil && x.IsValid() { } else if err == nil && x.IsValid() {
@ -279,14 +279,14 @@ func (d *decoder) handleRootExpression(expr *ast.Node, v reflect.Value) error {
return err return err
} }
func (d *decoder) handleArrayTable(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleArrayTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
if key.Next() { if key.Next() {
return d.handleArrayTablePart(key, v) return d.handleArrayTablePart(key, v)
} }
return d.handleKeyValues(v) return d.handleKeyValues(v)
} }
func (d *decoder) handleArrayTableCollectionLast(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
switch v.Kind() { switch v.Kind() {
case reflect.Interface: case reflect.Interface:
elem := v.Elem() elem := v.Elem()
@ -339,21 +339,21 @@ func (d *decoder) handleArrayTableCollectionLast(key ast.Iterator, v reflect.Val
case reflect.Array: case reflect.Array:
idx := d.arrayIndex(true, v) idx := d.arrayIndex(true, v)
if idx >= v.Len() { if idx >= v.Len() {
return v, fmt.Errorf("toml: cannot decode array table into %s at position %d", v.Type(), idx) return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx)
} }
elem := v.Index(idx) elem := v.Index(idx)
_, err := d.handleArrayTable(key, elem) _, err := d.handleArrayTable(key, elem)
return v, err return v, err
default:
return reflect.Value{}, d.typeMismatchError("array table", v.Type())
} }
return d.handleArrayTable(key, v)
} }
// When parsing an array table expression, each part of the key needs to be // When parsing an array table expression, each part of the key needs to be
// evaluated like a normal key, but if it returns a collection, it also needs to // evaluated like a normal key, but if it returns a collection, it also needs to
// point to the last element of the collection. Unless it is the last part of // point to the last element of the collection. Unless it is the last part of
// the key, then it needs to create a new element at the end. // the key, then it needs to create a new element at the end.
func (d *decoder) handleArrayTableCollection(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleArrayTableCollection(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
if key.IsLast() { if key.IsLast() {
return d.handleArrayTableCollectionLast(key, v) return d.handleArrayTableCollectionLast(key, v)
} }
@ -390,7 +390,7 @@ func (d *decoder) handleArrayTableCollection(key ast.Iterator, v reflect.Value)
case reflect.Array: case reflect.Array:
idx := d.arrayIndex(false, v) idx := d.arrayIndex(false, v)
if idx >= v.Len() { if idx >= v.Len() {
return v, fmt.Errorf("toml: cannot decode array table into %s at position %d", v.Type(), idx) return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx)
} }
elem := v.Index(idx) elem := v.Index(idx)
_, err := d.handleArrayTable(key, elem) _, err := d.handleArrayTable(key, elem)
@ -400,7 +400,7 @@ func (d *decoder) handleArrayTableCollection(key ast.Iterator, v reflect.Value)
return d.handleArrayTable(key, v) return d.handleArrayTable(key, v)
} }
func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handlerFn, makeFn valueMakerFn) (reflect.Value, error) { func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn handlerFn, makeFn valueMakerFn) (reflect.Value, error) {
var rv reflect.Value var rv reflect.Value
// First, dispatch over v to make sure it is a valid object. // First, dispatch over v to make sure it is a valid object.
@ -483,7 +483,7 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
d.errorContext.Struct = t d.errorContext.Struct = t
d.errorContext.Field = path d.errorContext.Field = path
f := v.FieldByIndex(path) f := fieldByIndex(v, path)
x, err := nextFn(key, f) x, err := nextFn(key, f)
if err != nil || d.skipUntilTable { if err != nil || d.skipUntilTable {
return reflect.Value{}, err return reflect.Value{}, err
@ -518,7 +518,7 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
// HandleArrayTablePart navigates the Go structure v using the key v. It is // HandleArrayTablePart navigates the Go structure v using the key v. It is
// only used for the prefix (non-last) parts of an array-table. When // only used for the prefix (non-last) parts of an array-table. When
// encountering a collection, it should go to the last element. // encountering a collection, it should go to the last element.
func (d *decoder) handleArrayTablePart(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleArrayTablePart(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
var makeFn valueMakerFn var makeFn valueMakerFn
if key.IsLast() { if key.IsLast() {
makeFn = makeSliceInterface makeFn = makeSliceInterface
@ -530,10 +530,10 @@ func (d *decoder) handleArrayTablePart(key ast.Iterator, v reflect.Value) (refle
// HandleTable returns a reference when it has checked the next expression but // HandleTable returns a reference when it has checked the next expression but
// cannot handle it. // cannot handle it.
func (d *decoder) handleTable(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
if v.Len() == 0 { if v.Len() == 0 {
return reflect.Value{}, newDecodeError(key.Node().Data, "cannot store a table in a slice") return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
} }
elem := v.Index(v.Len() - 1) elem := v.Index(v.Len() - 1)
x, err := d.handleTable(key, elem) x, err := d.handleTable(key, elem)
@ -560,7 +560,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
var rv reflect.Value var rv reflect.Value
for d.nextExpr() { for d.nextExpr() {
expr := d.expr() expr := d.expr()
if expr.Kind != ast.KeyValue { if expr.Kind != unstable.KeyValue {
// Stash the expression so that fromParser can just loop and use // Stash the expression so that fromParser can just loop and use
// the right handler. // the right handler.
// We could just recurse ourselves here, but at least this gives a // We could just recurse ourselves here, but at least this gives a
@ -587,7 +587,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
} }
type ( type (
handlerFn func(key ast.Iterator, v reflect.Value) (reflect.Value, error) handlerFn func(key unstable.Iterator, v reflect.Value) (reflect.Value, error)
valueMakerFn func() reflect.Value valueMakerFn func() reflect.Value
) )
@ -599,11 +599,11 @@ func makeSliceInterface() reflect.Value {
return reflect.MakeSlice(sliceInterfaceType, 0, 16) return reflect.MakeSlice(sliceInterfaceType, 0, 16)
} }
func (d *decoder) handleTablePart(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleTablePart(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
return d.handleKeyPart(key, v, d.handleTable, makeMapStringInterface) return d.handleKeyPart(key, v, d.handleTable, makeMapStringInterface)
} }
func (d *decoder) tryTextUnmarshaler(node *ast.Node, v reflect.Value) (bool, error) { func (d *decoder) tryTextUnmarshaler(node *unstable.Node, v reflect.Value) (bool, error) {
// Special case for time, because we allow to unmarshal to it from // Special case for time, because we allow to unmarshal to it from
// different kind of AST nodes. // different kind of AST nodes.
if v.Type() == timeType { if v.Type() == timeType {
@ -613,7 +613,7 @@ func (d *decoder) tryTextUnmarshaler(node *ast.Node, v reflect.Value) (bool, err
if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) { if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) {
err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
if err != nil { if err != nil {
return false, newDecodeError(d.p.Raw(node.Raw), "%w", err) return false, unstable.NewParserError(d.p.Raw(node.Raw), "%w", err)
} }
return true, nil return true, nil
@ -622,7 +622,7 @@ func (d *decoder) tryTextUnmarshaler(node *ast.Node, v reflect.Value) (bool, err
return false, nil return false, nil
} }
func (d *decoder) handleValue(value *ast.Node, v reflect.Value) error { func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error {
for v.Kind() == reflect.Ptr { for v.Kind() == reflect.Ptr {
v = initAndDereferencePointer(v) v = initAndDereferencePointer(v)
} }
@ -633,32 +633,32 @@ func (d *decoder) handleValue(value *ast.Node, v reflect.Value) error {
} }
switch value.Kind { switch value.Kind {
case ast.String: case unstable.String:
return d.unmarshalString(value, v) return d.unmarshalString(value, v)
case ast.Integer: case unstable.Integer:
return d.unmarshalInteger(value, v) return d.unmarshalInteger(value, v)
case ast.Float: case unstable.Float:
return d.unmarshalFloat(value, v) return d.unmarshalFloat(value, v)
case ast.Bool: case unstable.Bool:
return d.unmarshalBool(value, v) return d.unmarshalBool(value, v)
case ast.DateTime: case unstable.DateTime:
return d.unmarshalDateTime(value, v) return d.unmarshalDateTime(value, v)
case ast.LocalDate: case unstable.LocalDate:
return d.unmarshalLocalDate(value, v) return d.unmarshalLocalDate(value, v)
case ast.LocalTime: case unstable.LocalTime:
return d.unmarshalLocalTime(value, v) return d.unmarshalLocalTime(value, v)
case ast.LocalDateTime: case unstable.LocalDateTime:
return d.unmarshalLocalDateTime(value, v) return d.unmarshalLocalDateTime(value, v)
case ast.InlineTable: case unstable.InlineTable:
return d.unmarshalInlineTable(value, v) return d.unmarshalInlineTable(value, v)
case ast.Array: case unstable.Array:
return d.unmarshalArray(value, v) return d.unmarshalArray(value, v)
default: default:
panic(fmt.Errorf("handleValue not implemented for %s", value.Kind)) panic(fmt.Errorf("handleValue not implemented for %s", value.Kind))
} }
} }
func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalArray(array *unstable.Node, v reflect.Value) error {
switch v.Kind() { switch v.Kind() {
case reflect.Slice: case reflect.Slice:
if v.IsNil() { if v.IsNil() {
@ -729,7 +729,7 @@ func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error {
return nil return nil
} }
func (d *decoder) unmarshalInlineTable(itable *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalInlineTable(itable *unstable.Node, v reflect.Value) error {
// Make sure v is an initialized object. // Make sure v is an initialized object.
switch v.Kind() { switch v.Kind() {
case reflect.Map: case reflect.Map:
@ -746,7 +746,7 @@ func (d *decoder) unmarshalInlineTable(itable *ast.Node, v reflect.Value) error
} }
return d.unmarshalInlineTable(itable, elem) return d.unmarshalInlineTable(itable, elem)
default: default:
return newDecodeError(itable.Data, "cannot store inline table in Go type %s", v.Kind()) return unstable.NewParserError(itable.Data, "cannot store inline table in Go type %s", v.Kind())
} }
it := itable.Children() it := itable.Children()
@ -765,7 +765,7 @@ func (d *decoder) unmarshalInlineTable(itable *ast.Node, v reflect.Value) error
return nil return nil
} }
func (d *decoder) unmarshalDateTime(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalDateTime(value *unstable.Node, v reflect.Value) error {
dt, err := parseDateTime(value.Data) dt, err := parseDateTime(value.Data)
if err != nil { if err != nil {
return err return err
@ -775,7 +775,7 @@ func (d *decoder) unmarshalDateTime(value *ast.Node, v reflect.Value) error {
return nil return nil
} }
func (d *decoder) unmarshalLocalDate(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalLocalDate(value *unstable.Node, v reflect.Value) error {
ld, err := parseLocalDate(value.Data) ld, err := parseLocalDate(value.Data)
if err != nil { if err != nil {
return err return err
@ -792,28 +792,28 @@ func (d *decoder) unmarshalLocalDate(value *ast.Node, v reflect.Value) error {
return nil return nil
} }
func (d *decoder) unmarshalLocalTime(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalLocalTime(value *unstable.Node, v reflect.Value) error {
lt, rest, err := parseLocalTime(value.Data) lt, rest, err := parseLocalTime(value.Data)
if err != nil { if err != nil {
return err return err
} }
if len(rest) > 0 { if len(rest) > 0 {
return newDecodeError(rest, "extra characters at the end of a local time") return unstable.NewParserError(rest, "extra characters at the end of a local time")
} }
v.Set(reflect.ValueOf(lt)) v.Set(reflect.ValueOf(lt))
return nil return nil
} }
func (d *decoder) unmarshalLocalDateTime(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalLocalDateTime(value *unstable.Node, v reflect.Value) error {
ldt, rest, err := parseLocalDateTime(value.Data) ldt, rest, err := parseLocalDateTime(value.Data)
if err != nil { if err != nil {
return err return err
} }
if len(rest) > 0 { if len(rest) > 0 {
return newDecodeError(rest, "extra characters at the end of a local date time") return unstable.NewParserError(rest, "extra characters at the end of a local date time")
} }
if v.Type() == timeType { if v.Type() == timeType {
@ -828,7 +828,7 @@ func (d *decoder) unmarshalLocalDateTime(value *ast.Node, v reflect.Value) error
return nil return nil
} }
func (d *decoder) unmarshalBool(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalBool(value *unstable.Node, v reflect.Value) error {
b := value.Data[0] == 't' b := value.Data[0] == 't'
switch v.Kind() { switch v.Kind() {
@ -837,13 +837,13 @@ func (d *decoder) unmarshalBool(value *ast.Node, v reflect.Value) error {
case reflect.Interface: case reflect.Interface:
v.Set(reflect.ValueOf(b)) v.Set(reflect.ValueOf(b))
default: default:
return newDecodeError(value.Data, "cannot assign boolean to a %t", b) return unstable.NewParserError(value.Data, "cannot assign boolean to a %t", b)
} }
return nil return nil
} }
func (d *decoder) unmarshalFloat(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalFloat(value *unstable.Node, v reflect.Value) error {
f, err := parseFloat(value.Data) f, err := parseFloat(value.Data)
if err != nil { if err != nil {
return err return err
@ -854,13 +854,13 @@ func (d *decoder) unmarshalFloat(value *ast.Node, v reflect.Value) error {
v.SetFloat(f) v.SetFloat(f)
case reflect.Float32: case reflect.Float32:
if f > math.MaxFloat32 { if f > math.MaxFloat32 {
return newDecodeError(value.Data, "number %f does not fit in a float32", f) return unstable.NewParserError(value.Data, "number %f does not fit in a float32", f)
} }
v.SetFloat(f) v.SetFloat(f)
case reflect.Interface: case reflect.Interface:
v.Set(reflect.ValueOf(f)) v.Set(reflect.ValueOf(f))
default: default:
return newDecodeError(value.Data, "float cannot be assigned to %s", v.Kind()) return unstable.NewParserError(value.Data, "float cannot be assigned to %s", v.Kind())
} }
return nil return nil
@ -886,7 +886,7 @@ func init() {
} }
} }
func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalInteger(value *unstable.Node, v reflect.Value) error {
i, err := parseInteger(value.Data) i, err := parseInteger(value.Data)
if err != nil { if err != nil {
return err return err
@ -967,20 +967,20 @@ func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error {
return nil return nil
} }
func (d *decoder) unmarshalString(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalString(value *unstable.Node, v reflect.Value) error {
switch v.Kind() { switch v.Kind() {
case reflect.String: case reflect.String:
v.SetString(string(value.Data)) v.SetString(string(value.Data))
case reflect.Interface: case reflect.Interface:
v.Set(reflect.ValueOf(string(value.Data))) v.Set(reflect.ValueOf(string(value.Data)))
default: default:
return newDecodeError(d.p.Raw(value.Raw), "cannot store TOML string into a Go %s", v.Kind()) return unstable.NewParserError(d.p.Raw(value.Raw), "cannot store TOML string into a Go %s", v.Kind())
} }
return nil return nil
} }
func (d *decoder) handleKeyValue(expr *ast.Node, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleKeyValue(expr *unstable.Node, v reflect.Value) (reflect.Value, error) {
d.strict.EnterKeyValue(expr) d.strict.EnterKeyValue(expr)
v, err := d.handleKeyValueInner(expr.Key(), expr.Value(), v) v, err := d.handleKeyValueInner(expr.Key(), expr.Value(), v)
@ -994,7 +994,7 @@ func (d *decoder) handleKeyValue(expr *ast.Node, v reflect.Value) (reflect.Value
return v, err return v, err
} }
func (d *decoder) handleKeyValueInner(key ast.Iterator, value *ast.Node, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleKeyValueInner(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
if key.Next() { if key.Next() {
// Still scoping the key // Still scoping the key
return d.handleKeyValuePart(key, value, v) return d.handleKeyValuePart(key, value, v)
@ -1004,7 +1004,7 @@ func (d *decoder) handleKeyValueInner(key ast.Iterator, value *ast.Node, v refle
return reflect.Value{}, d.handleValue(value, v) return reflect.Value{}, d.handleValue(value, v)
} }
func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
// contains the replacement for v // contains the replacement for v
var rv reflect.Value var rv reflect.Value
@ -1071,7 +1071,7 @@ func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflec
d.errorContext.Struct = t d.errorContext.Struct = t
d.errorContext.Field = path d.errorContext.Field = path
f := v.FieldByIndex(path) f := fieldByIndex(v, path)
x, err := d.handleKeyValueInner(key, value, f) x, err := d.handleKeyValueInner(key, value, f)
if err != nil { if err != nil {
return reflect.Value{}, err return reflect.Value{}, err
@ -1135,6 +1135,21 @@ func initAndDereferencePointer(v reflect.Value) reflect.Value {
return elem return elem
} }
// Same as reflect.Value.FieldByIndex, but creates pointers if needed.
func fieldByIndex(v reflect.Value, path []int) reflect.Value {
for i, x := range path {
v = v.Field(x)
if i < len(path)-1 && v.Kind() == reflect.Ptr {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
}
return v
}
type fieldPathsMap = map[string][]int type fieldPathsMap = map[string][]int
var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap
@ -1192,7 +1207,14 @@ func forEachField(t reflect.Type, path []int, do func(name string, path []int))
} }
if f.Anonymous && name == "" { if f.Anonymous && name == "" {
forEachField(f.Type, fieldPath, do) t2 := f.Type
if t2.Kind() == reflect.Ptr {
t2 = t2.Elem()
}
if t2.Kind() == reflect.Struct {
forEachField(t2, fieldPath, do)
}
continue continue
} }

@ -1,4 +1,4 @@
package ast package unstable
import ( import (
"fmt" "fmt"
@ -7,13 +7,16 @@ import (
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
) )
// Iterator starts uninitialized, you need to call Next() first. // Iterator over a sequence of nodes.
//
// Starts uninitialized, you need to call Next() first.
// //
// For example: // For example:
// //
// it := n.Children() // it := n.Children()
// for it.Next() { // for it.Next() {
// it.Node() // n := it.Node()
// // do something with n
// } // }
type Iterator struct { type Iterator struct {
started bool started bool
@ -32,42 +35,31 @@ func (c *Iterator) Next() bool {
} }
// IsLast returns true if the current node of the iterator is the last // IsLast returns true if the current node of the iterator is the last
// one. Subsequent call to Next() will return false. // one. Subsequent calls to Next() will return false.
func (c *Iterator) IsLast() bool { func (c *Iterator) IsLast() bool {
return c.node.next == 0 return c.node.next == 0
} }
// Node returns a copy of the node pointed at by the iterator. // Node returns a pointer to the node pointed at by the iterator.
func (c *Iterator) Node() *Node { func (c *Iterator) Node() *Node {
return c.node return c.node
} }
// Root contains a full AST. // Node in a TOML expression AST.
// //
// It is immutable once constructed with Builder. // Depending on Kind, its sequence of children should be interpreted
type Root struct { // differently.
nodes []Node //
} // - Array have one child per element in the array.
// - InlineTable have one child per key-value in the table (each of kind
// Iterator over the top level nodes. // InlineTable).
func (r *Root) Iterator() Iterator { // - KeyValue have at least two children. The first one is the value. The rest
it := Iterator{} // make a potentially dotted key.
if len(r.nodes) > 0 { // - Table and ArrayTable's children represent a dotted key (same as
it.node = &r.nodes[0] // KeyValue, but without the first node being the value).
} //
return it // When relevant, Raw describes the range of bytes this node is refering to in
} // the input document. Use Parser.Raw() to retrieve the actual bytes.
func (r *Root) at(idx Reference) *Node {
return &r.nodes[idx]
}
// Arrays have one child per element in the array. InlineTables have
// one child per key-value pair in the table. KeyValues have at least
// two children. The first one is the value. The rest make a
// potentially dotted key. Table and Array table have one child per
// element of the key they represent (same as KeyValue, but without
// the last node being the value).
type Node struct { type Node struct {
Kind Kind Kind Kind
Raw Range // Raw bytes from the input. Raw Range // Raw bytes from the input.
@ -80,13 +72,13 @@ type Node struct {
child int // 0 if no child child int // 0 if no child
} }
// Range of bytes in the document.
type Range struct { type Range struct {
Offset uint32 Offset uint32
Length uint32 Length uint32
} }
// Next returns a copy of the next node, or an invalid Node if there // Next returns a pointer to the next node, or nil if there is no next node.
// is no next node.
func (n *Node) Next() *Node { func (n *Node) Next() *Node {
if n.next == 0 { if n.next == 0 {
return nil return nil
@ -96,9 +88,9 @@ func (n *Node) Next() *Node {
return (*Node)(danger.Stride(ptr, size, n.next)) return (*Node)(danger.Stride(ptr, size, n.next))
} }
// Child returns a copy of the first child node of this node. Other // Child returns a pointer to the first child node of this node. Other children
// children can be accessed calling Next on the first child. Returns // can be accessed calling Next on the first child. Returns an nil if this Node
// an invalid Node if there is none. // has no child.
func (n *Node) Child() *Node { func (n *Node) Child() *Node {
if n.child == 0 { if n.child == 0 {
return nil return nil
@ -113,9 +105,9 @@ func (n *Node) Valid() bool {
return n != nil return n != nil
} }
// Key returns the child nodes making the Key on a supported // Key returns the children nodes making the Key on a supported node. Panics
// node. Panics otherwise. They are guaranteed to be all be of the // otherwise. They are guaranteed to be all be of the Kind Key. A simple key
// Kind Key. A simple key would return just one element. // would return just one element.
func (n *Node) Key() Iterator { func (n *Node) Key() Iterator {
switch n.Kind { switch n.Kind {
case KeyValue: case KeyValue:

@ -0,0 +1,71 @@
package unstable
// root contains a full AST.
//
// It is immutable once constructed with Builder.
type root struct {
nodes []Node
}
// Iterator over the top level nodes.
func (r *root) Iterator() Iterator {
it := Iterator{}
if len(r.nodes) > 0 {
it.node = &r.nodes[0]
}
return it
}
func (r *root) at(idx reference) *Node {
return &r.nodes[idx]
}
type reference int
const invalidReference reference = -1
func (r reference) Valid() bool {
return r != invalidReference
}
type builder struct {
tree root
lastIdx int
}
func (b *builder) Tree() *root {
return &b.tree
}
func (b *builder) NodeAt(ref reference) *Node {
return b.tree.at(ref)
}
func (b *builder) Reset() {
b.tree.nodes = b.tree.nodes[:0]
b.lastIdx = 0
}
// Push appends n to the tree and returns a reference to it. The pushed node
// becomes the "last" node that PushAndChain links from.
func (b *builder) Push(n Node) reference {
	idx := len(b.tree.nodes)
	b.tree.nodes = append(b.tree.nodes, n)
	b.lastIdx = idx
	return reference(idx)
}
// PushAndChain appends n to the tree and links the previously pushed node to
// it as its next sibling. It returns a reference to the new node.
func (b *builder) PushAndChain(n Node) reference {
	idx := len(b.tree.nodes)
	b.tree.nodes = append(b.tree.nodes, n)
	prev := b.lastIdx
	if prev >= 0 {
		// next is stored as an offset relative to the previous node.
		b.tree.nodes[prev].next = idx - prev
	}
	b.lastIdx = idx
	return reference(idx)
}
// AttachChild records child as the first child of parent. The relationship is
// stored as an offset relative to the parent's position in the nodes slice.
func (b *builder) AttachChild(parent reference, child reference) {
	offset := int(child) - int(parent)
	b.tree.nodes[parent].child = offset
}
// Chain links node from to node to as next siblings. The link is stored as an
// offset relative to from's position in the nodes slice.
func (b *builder) Chain(from reference, to reference) {
	offset := int(to) - int(from)
	b.tree.nodes[from].next = offset
}

@ -0,0 +1,3 @@
// Package unstable provides APIs that do not meet the backward compatibility
// guarantees yet.
package unstable

@ -1,25 +1,26 @@
package ast package unstable
import "fmt" import "fmt"
// Kind represents the type of TOML structure contained in a given Node.
type Kind int type Kind int
const ( const (
// meta // Meta
Invalid Kind = iota Invalid Kind = iota
Comment Comment
Key Key
// top level structures // Top level structures
Table Table
ArrayTable ArrayTable
KeyValue KeyValue
// containers values // Containers values
Array Array
InlineTable InlineTable
// values // Values
String String
Bool Bool
Float Float
@ -30,6 +31,7 @@ const (
DateTime DateTime
) )
// String implementation of fmt.Stringer.
func (k Kind) String() string { func (k Kind) String() string {
switch k { switch k {
case Invalid: case Invalid:

@ -1,50 +1,108 @@
package toml package unstable
import ( import (
"bytes" "bytes"
"fmt"
"unicode" "unicode"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/characters"
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
) )
type parser struct { // ParserError describes an error relative to the content of the document.
builder ast.Builder //
ref ast.Reference // It cannot outlive the instance of Parser it refers to, and may cause panics
// if the parser is reset.
type ParserError struct {
Highlight []byte
Message string
Key []string // optional
}
// Error is the implementation of the error interface.
func (e *ParserError) Error() string {
return e.Message
}
// NewParserError is a convenience function to create a ParserError
//
// Warning: Highlight needs to be a subslice of Parser.data, so only slices
// returned by Parser.Raw are valid candidates.
func NewParserError(highlight []byte, format string, args ...interface{}) error {
return &ParserError{
Highlight: highlight,
Message: fmt.Errorf(format, args...).Error(),
}
}
// Parser scans over a TOML-encoded document and generates an iterative AST.
//
// To prime the Parser, first reset it with the contents of a TOML document.
// Then, process all top-level expressions sequentially. See Example.
//
// Don't forget to check Error() after you're done parsing.
//
// Each top-level expression needs to be fully processed before calling
// NextExpression() again. Otherwise, calls to various Node methods may panic if
// the parser has moved on the next expression.
//
// For performance reasons, go-toml doesn't make a copy of the input bytes to
// the parser. Make sure to copy all the bytes you need to outlive the slice
// given to the parser.
//
// The parser doesn't provide nodes for comments yet, nor for whitespace.
type Parser struct {
data []byte data []byte
builder builder
ref reference
left []byte left []byte
err error err error
first bool first bool
} }
func (p *parser) Range(b []byte) ast.Range { // Data returns the slice provided to the last call to Reset.
return ast.Range{ func (p *Parser) Data() []byte {
return p.data
}
// Range returns a range description that corresponds to a given slice of the
// input. If the argument is not a subslice of the parser input, this function
// panics.
func (p *Parser) Range(b []byte) Range {
return Range{
Offset: uint32(danger.SubsliceOffset(p.data, b)), Offset: uint32(danger.SubsliceOffset(p.data, b)),
Length: uint32(len(b)), Length: uint32(len(b)),
} }
} }
func (p *parser) Raw(raw ast.Range) []byte { // Raw returns the slice corresponding to the bytes in the given range.
func (p *Parser) Raw(raw Range) []byte {
return p.data[raw.Offset : raw.Offset+raw.Length] return p.data[raw.Offset : raw.Offset+raw.Length]
} }
func (p *parser) Reset(b []byte) { // Reset brings the parser to its initial state for a given input. It wipes an
// reuses internal storage to reduce allocation.
func (p *Parser) Reset(b []byte) {
p.builder.Reset() p.builder.Reset()
p.ref = ast.InvalidReference p.ref = invalidReference
p.data = b p.data = b
p.left = b p.left = b
p.err = nil p.err = nil
p.first = true p.first = true
} }
//nolint:cyclop // NextExpression parses the next top-level expression. If an expression was
func (p *parser) NextExpression() bool { // successfully parsed, it returns true. If the parser is at the end of the
// document or an error occurred, it returns false.
//
// Retrieve the parsed expression with Expression().
func (p *Parser) NextExpression() bool {
if len(p.left) == 0 || p.err != nil { if len(p.left) == 0 || p.err != nil {
return false return false
} }
p.builder.Reset() p.builder.Reset()
p.ref = ast.InvalidReference p.ref = invalidReference
for { for {
if len(p.left) == 0 || p.err != nil { if len(p.left) == 0 || p.err != nil {
@ -73,15 +131,18 @@ func (p *parser) NextExpression() bool {
} }
} }
func (p *parser) Expression() *ast.Node { // Expression returns a pointer to the node representing the last successfully
// parsed expresion.
func (p *Parser) Expression() *Node {
return p.builder.NodeAt(p.ref) return p.builder.NodeAt(p.ref)
} }
func (p *parser) Error() error { // Error returns any error that has occured during parsing.
func (p *Parser) Error() error {
return p.err return p.err
} }
func (p *parser) parseNewline(b []byte) ([]byte, error) { func (p *Parser) parseNewline(b []byte) ([]byte, error) {
if b[0] == '\n' { if b[0] == '\n' {
return b[1:], nil return b[1:], nil
} }
@ -91,14 +152,14 @@ func (p *parser) parseNewline(b []byte) ([]byte, error) {
return rest, err return rest, err
} }
return nil, newDecodeError(b[0:1], "expected newline but got %#U", b[0]) return nil, NewParserError(b[0:1], "expected newline but got %#U", b[0])
} }
func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseExpression(b []byte) (reference, []byte, error) {
// expression = ws [ comment ] // expression = ws [ comment ]
// expression =/ ws keyval ws [ comment ] // expression =/ ws keyval ws [ comment ]
// expression =/ ws table ws [ comment ] // expression =/ ws table ws [ comment ]
ref := ast.InvalidReference ref := invalidReference
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
@ -136,7 +197,7 @@ func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) {
return ref, b, nil return ref, b, nil
} }
func (p *parser) parseTable(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseTable(b []byte) (reference, []byte, error) {
// table = std-table / array-table // table = std-table / array-table
if len(b) > 1 && b[1] == '[' { if len(b) > 1 && b[1] == '[' {
return p.parseArrayTable(b) return p.parseArrayTable(b)
@ -145,12 +206,12 @@ func (p *parser) parseTable(b []byte) (ast.Reference, []byte, error) {
return p.parseStdTable(b) return p.parseStdTable(b)
} }
func (p *parser) parseArrayTable(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseArrayTable(b []byte) (reference, []byte, error) {
// array-table = array-table-open key array-table-close // array-table = array-table-open key array-table-close
// array-table-open = %x5B.5B ws ; [[ Double left square bracket // array-table-open = %x5B.5B ws ; [[ Double left square bracket
// array-table-close = ws %x5D.5D ; ]] Double right square bracket // array-table-close = ws %x5D.5D ; ]] Double right square bracket
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(Node{
Kind: ast.ArrayTable, Kind: ArrayTable,
}) })
b = b[2:] b = b[2:]
@ -174,12 +235,12 @@ func (p *parser) parseArrayTable(b []byte) (ast.Reference, []byte, error) {
return ref, b, err return ref, b, err
} }
func (p *parser) parseStdTable(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseStdTable(b []byte) (reference, []byte, error) {
// std-table = std-table-open key std-table-close // std-table = std-table-open key std-table-close
// std-table-open = %x5B ws ; [ Left square bracket // std-table-open = %x5B ws ; [ Left square bracket
// std-table-close = ws %x5D ; ] Right square bracket // std-table-close = ws %x5D ; ] Right square bracket
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(Node{
Kind: ast.Table, Kind: Table,
}) })
b = b[1:] b = b[1:]
@ -199,15 +260,15 @@ func (p *parser) parseStdTable(b []byte) (ast.Reference, []byte, error) {
return ref, b, err return ref, b, err
} }
func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseKeyval(b []byte) (reference, []byte, error) {
// keyval = key keyval-sep val // keyval = key keyval-sep val
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(Node{
Kind: ast.KeyValue, Kind: KeyValue,
}) })
key, b, err := p.parseKey(b) key, b, err := p.parseKey(b)
if err != nil { if err != nil {
return ast.InvalidReference, nil, err return invalidReference, nil, err
} }
// keyval-sep = ws %x3D ws ; = // keyval-sep = ws %x3D ws ; =
@ -215,12 +276,12 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) == 0 { if len(b) == 0 {
return ast.InvalidReference, nil, newDecodeError(b, "expected = after a key, but the document ends there") return invalidReference, nil, NewParserError(b, "expected = after a key, but the document ends there")
} }
b, err = expect('=', b) b, err = expect('=', b)
if err != nil { if err != nil {
return ast.InvalidReference, nil, err return invalidReference, nil, err
} }
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
@ -237,12 +298,12 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) {
} }
//nolint:cyclop,funlen //nolint:cyclop,funlen
func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseVal(b []byte) (reference, []byte, error) {
// val = string / boolean / array / inline-table / date-time / float / integer // val = string / boolean / array / inline-table / date-time / float / integer
ref := ast.InvalidReference ref := invalidReference
if len(b) == 0 { if len(b) == 0 {
return ref, nil, newDecodeError(b, "expected value, not eof") return ref, nil, NewParserError(b, "expected value, not eof")
} }
var err error var err error
@ -259,8 +320,8 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
} }
if err == nil { if err == nil {
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(Node{
Kind: ast.String, Kind: String,
Raw: p.Range(raw), Raw: p.Range(raw),
Data: v, Data: v,
}) })
@ -277,8 +338,8 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
} }
if err == nil { if err == nil {
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(Node{
Kind: ast.String, Kind: String,
Raw: p.Range(raw), Raw: p.Range(raw),
Data: v, Data: v,
}) })
@ -287,22 +348,22 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
return ref, b, err return ref, b, err
case 't': case 't':
if !scanFollowsTrue(b) { if !scanFollowsTrue(b) {
return ref, nil, newDecodeError(atmost(b, 4), "expected 'true'") return ref, nil, NewParserError(atmost(b, 4), "expected 'true'")
} }
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(Node{
Kind: ast.Bool, Kind: Bool,
Data: b[:4], Data: b[:4],
}) })
return ref, b[4:], nil return ref, b[4:], nil
case 'f': case 'f':
if !scanFollowsFalse(b) { if !scanFollowsFalse(b) {
return ref, nil, newDecodeError(atmost(b, 5), "expected 'false'") return ref, nil, NewParserError(atmost(b, 5), "expected 'false'")
} }
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(Node{
Kind: ast.Bool, Kind: Bool,
Data: b[:5], Data: b[:5],
}) })
@ -324,7 +385,7 @@ func atmost(b []byte, n int) []byte {
return b[:n] return b[:n]
} }
func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, []byte, error) { func (p *Parser) parseLiteralString(b []byte) ([]byte, []byte, []byte, error) {
v, rest, err := scanLiteralString(b) v, rest, err := scanLiteralString(b)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
@ -333,19 +394,19 @@ func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, []byte, error) {
return v, v[1 : len(v)-1], rest, nil return v, v[1 : len(v)-1], rest, nil
} }
func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseInlineTable(b []byte) (reference, []byte, error) {
// inline-table = inline-table-open [ inline-table-keyvals ] inline-table-close // inline-table = inline-table-open [ inline-table-keyvals ] inline-table-close
// inline-table-open = %x7B ws ; { // inline-table-open = %x7B ws ; {
// inline-table-close = ws %x7D ; } // inline-table-close = ws %x7D ; }
// inline-table-sep = ws %x2C ws ; , Comma // inline-table-sep = ws %x2C ws ; , Comma
// inline-table-keyvals = keyval [ inline-table-sep inline-table-keyvals ] // inline-table-keyvals = keyval [ inline-table-sep inline-table-keyvals ]
parent := p.builder.Push(ast.Node{ parent := p.builder.Push(Node{
Kind: ast.InlineTable, Kind: InlineTable,
}) })
first := true first := true
var child ast.Reference var child reference
b = b[1:] b = b[1:]
@ -356,7 +417,7 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) == 0 { if len(b) == 0 {
return parent, nil, newDecodeError(previousB[:1], "inline table is incomplete") return parent, nil, NewParserError(previousB[:1], "inline table is incomplete")
} }
if b[0] == '}' { if b[0] == '}' {
@ -371,7 +432,7 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
} }
var kv ast.Reference var kv reference
kv, b, err = p.parseKeyval(b) kv, b, err = p.parseKeyval(b)
if err != nil { if err != nil {
@ -394,7 +455,7 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
} }
//nolint:funlen,cyclop //nolint:funlen,cyclop
func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseValArray(b []byte) (reference, []byte, error) {
// array = array-open [ array-values ] ws-comment-newline array-close // array = array-open [ array-values ] ws-comment-newline array-close
// array-open = %x5B ; [ // array-open = %x5B ; [
// array-close = %x5D ; ] // array-close = %x5D ; ]
@ -405,13 +466,13 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
arrayStart := b arrayStart := b
b = b[1:] b = b[1:]
parent := p.builder.Push(ast.Node{ parent := p.builder.Push(Node{
Kind: ast.Array, Kind: Array,
}) })
first := true first := true
var lastChild ast.Reference var lastChild reference
var err error var err error
for len(b) > 0 { for len(b) > 0 {
@ -421,7 +482,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
} }
if len(b) == 0 { if len(b) == 0 {
return parent, nil, newDecodeError(arrayStart[:1], "array is incomplete") return parent, nil, NewParserError(arrayStart[:1], "array is incomplete")
} }
if b[0] == ']' { if b[0] == ']' {
@ -430,7 +491,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
if b[0] == ',' { if b[0] == ',' {
if first { if first {
return parent, nil, newDecodeError(b[0:1], "array cannot start with comma") return parent, nil, NewParserError(b[0:1], "array cannot start with comma")
} }
b = b[1:] b = b[1:]
@ -439,7 +500,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
return parent, nil, err return parent, nil, err
} }
} else if !first { } else if !first {
return parent, nil, newDecodeError(b[0:1], "array elements must be separated by commas") return parent, nil, NewParserError(b[0:1], "array elements must be separated by commas")
} }
// TOML allows trailing commas in arrays. // TOML allows trailing commas in arrays.
@ -447,7 +508,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
break break
} }
var valueRef ast.Reference var valueRef reference
valueRef, b, err = p.parseVal(b) valueRef, b, err = p.parseVal(b)
if err != nil { if err != nil {
return parent, nil, err return parent, nil, err
@ -472,7 +533,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
return parent, rest, err return parent, rest, err
} }
func (p *parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) { func (p *Parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) {
for len(b) > 0 { for len(b) > 0 {
var err error var err error
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
@ -501,7 +562,7 @@ func (p *parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error)
return b, nil return b, nil
} }
func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, []byte, error) { func (p *Parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, []byte, error) {
token, rest, err := scanMultilineLiteralString(b) token, rest, err := scanMultilineLiteralString(b)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
@ -520,7 +581,7 @@ func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, []byte,
} }
//nolint:funlen,gocognit,cyclop //nolint:funlen,gocognit,cyclop
func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, error) { func (p *Parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, error) {
// ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
// ml-basic-string-delim // ml-basic-string-delim
// ml-basic-string-delim = 3quotation-mark // ml-basic-string-delim = 3quotation-mark
@ -551,11 +612,11 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
if !escaped { if !escaped {
str := token[startIdx:endIdx] str := token[startIdx:endIdx]
verr := utf8TomlValidAlreadyEscaped(str) verr := characters.Utf8TomlValidAlreadyEscaped(str)
if verr.Zero() { if verr.Zero() {
return token, str, rest, nil return token, str, rest, nil
} }
return nil, nil, nil, newDecodeError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8") return nil, nil, nil, NewParserError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8")
} }
var builder bytes.Buffer var builder bytes.Buffer
@ -635,13 +696,13 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
builder.WriteRune(x) builder.WriteRune(x)
i += 8 i += 8
default: default:
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) return nil, nil, nil, NewParserError(token[i:i+1], "invalid escaped character %#U", c)
} }
i++ i++
} else { } else {
size := utf8ValidNext(token[i:]) size := characters.Utf8ValidNext(token[i:])
if size == 0 { if size == 0 {
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid character %#U", c) return nil, nil, nil, NewParserError(token[i:i+1], "invalid character %#U", c)
} }
builder.Write(token[i : i+size]) builder.Write(token[i : i+size])
i += size i += size
@ -651,7 +712,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
return token, builder.Bytes(), rest, nil return token, builder.Bytes(), rest, nil
} }
func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseKey(b []byte) (reference, []byte, error) {
// key = simple-key / dotted-key // key = simple-key / dotted-key
// simple-key = quoted-key / unquoted-key // simple-key = quoted-key / unquoted-key
// //
@ -662,11 +723,11 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
// dot-sep = ws %x2E ws ; . Period // dot-sep = ws %x2E ws ; . Period
raw, key, b, err := p.parseSimpleKey(b) raw, key, b, err := p.parseSimpleKey(b)
if err != nil { if err != nil {
return ast.InvalidReference, nil, err return invalidReference, nil, err
} }
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(Node{
Kind: ast.Key, Kind: Key,
Raw: p.Range(raw), Raw: p.Range(raw),
Data: key, Data: key,
}) })
@ -681,8 +742,8 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
return ref, nil, err return ref, nil, err
} }
p.builder.PushAndChain(ast.Node{ p.builder.PushAndChain(Node{
Kind: ast.Key, Kind: Key,
Raw: p.Range(raw), Raw: p.Range(raw),
Data: key, Data: key,
}) })
@ -694,9 +755,9 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
return ref, b, nil return ref, b, nil
} }
func (p *parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) { func (p *Parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) {
if len(b) == 0 { if len(b) == 0 {
return nil, nil, nil, newDecodeError(b, "expected key but found none") return nil, nil, nil, NewParserError(b, "expected key but found none")
} }
// simple-key = quoted-key / unquoted-key // simple-key = quoted-key / unquoted-key
@ -711,12 +772,12 @@ func (p *parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) {
key, rest = scanUnquotedKey(b) key, rest = scanUnquotedKey(b)
return key, key, rest, nil return key, key, rest, nil
default: default:
return nil, nil, nil, newDecodeError(b[0:1], "invalid character at start of key: %c", b[0]) return nil, nil, nil, NewParserError(b[0:1], "invalid character at start of key: %c", b[0])
} }
} }
//nolint:funlen,cyclop //nolint:funlen,cyclop
func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) { func (p *Parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
// basic-string = quotation-mark *basic-char quotation-mark // basic-string = quotation-mark *basic-char quotation-mark
// quotation-mark = %x22 ; " // quotation-mark = %x22 ; "
// basic-char = basic-unescaped / escaped // basic-char = basic-unescaped / escaped
@ -744,11 +805,11 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
// validate the string and return a direct reference to the buffer. // validate the string and return a direct reference to the buffer.
if !escaped { if !escaped {
str := token[startIdx:endIdx] str := token[startIdx:endIdx]
verr := utf8TomlValidAlreadyEscaped(str) verr := characters.Utf8TomlValidAlreadyEscaped(str)
if verr.Zero() { if verr.Zero() {
return token, str, rest, nil return token, str, rest, nil
} }
return nil, nil, nil, newDecodeError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8") return nil, nil, nil, NewParserError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8")
} }
i := startIdx i := startIdx
@ -795,13 +856,13 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
builder.WriteRune(x) builder.WriteRune(x)
i += 8 i += 8
default: default:
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) return nil, nil, nil, NewParserError(token[i:i+1], "invalid escaped character %#U", c)
} }
i++ i++
} else { } else {
size := utf8ValidNext(token[i:]) size := characters.Utf8ValidNext(token[i:])
if size == 0 { if size == 0 {
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid character %#U", c) return nil, nil, nil, NewParserError(token[i:i+1], "invalid character %#U", c)
} }
builder.Write(token[i : i+size]) builder.Write(token[i : i+size])
i += size i += size
@ -813,7 +874,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
func hexToRune(b []byte, length int) (rune, error) { func hexToRune(b []byte, length int) (rune, error) {
if len(b) < length { if len(b) < length {
return -1, newDecodeError(b, "unicode point needs %d character, not %d", length, len(b)) return -1, NewParserError(b, "unicode point needs %d character, not %d", length, len(b))
} }
b = b[:length] b = b[:length]
@ -828,19 +889,19 @@ func hexToRune(b []byte, length int) (rune, error) {
case 'A' <= c && c <= 'F': case 'A' <= c && c <= 'F':
d = uint32(c - 'A' + 10) d = uint32(c - 'A' + 10)
default: default:
return -1, newDecodeError(b[i:i+1], "non-hex character") return -1, NewParserError(b[i:i+1], "non-hex character")
} }
r = r*16 + d r = r*16 + d
} }
if r > unicode.MaxRune || 0xD800 <= r && r < 0xE000 { if r > unicode.MaxRune || 0xD800 <= r && r < 0xE000 {
return -1, newDecodeError(b, "escape sequence is invalid Unicode code point") return -1, NewParserError(b, "escape sequence is invalid Unicode code point")
} }
return rune(r), nil return rune(r), nil
} }
func (p *parser) parseWhitespace(b []byte) []byte { func (p *Parser) parseWhitespace(b []byte) []byte {
// ws = *wschar // ws = *wschar
// wschar = %x20 ; Space // wschar = %x20 ; Space
// wschar =/ %x09 ; Horizontal tab // wschar =/ %x09 ; Horizontal tab
@ -850,24 +911,24 @@ func (p *parser) parseWhitespace(b []byte) []byte {
} }
//nolint:cyclop //nolint:cyclop
func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseIntOrFloatOrDateTime(b []byte) (reference, []byte, error) {
switch b[0] { switch b[0] {
case 'i': case 'i':
if !scanFollowsInf(b) { if !scanFollowsInf(b) {
return ast.InvalidReference, nil, newDecodeError(atmost(b, 3), "expected 'inf'") return invalidReference, nil, NewParserError(atmost(b, 3), "expected 'inf'")
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Float, Kind: Float,
Data: b[:3], Data: b[:3],
}), b[3:], nil }), b[3:], nil
case 'n': case 'n':
if !scanFollowsNan(b) { if !scanFollowsNan(b) {
return ast.InvalidReference, nil, newDecodeError(atmost(b, 3), "expected 'nan'") return invalidReference, nil, NewParserError(atmost(b, 3), "expected 'nan'")
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Float, Kind: Float,
Data: b[:3], Data: b[:3],
}), b[3:], nil }), b[3:], nil
case '+', '-': case '+', '-':
@ -898,7 +959,7 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
} }
func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) { func (p *Parser) scanDateTime(b []byte) (reference, []byte, error) {
// scans for contiguous characters in [0-9T:Z.+-], and up to one space if // scans for contiguous characters in [0-9T:Z.+-], and up to one space if
// followed by a digit. // followed by a digit.
hasDate := false hasDate := false
@ -941,30 +1002,30 @@ byteLoop:
} }
} }
var kind ast.Kind var kind Kind
if hasTime { if hasTime {
if hasDate { if hasDate {
if hasTz { if hasTz {
kind = ast.DateTime kind = DateTime
} else { } else {
kind = ast.LocalDateTime kind = LocalDateTime
} }
} else { } else {
kind = ast.LocalTime kind = LocalTime
} }
} else { } else {
kind = ast.LocalDate kind = LocalDate
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: kind, Kind: kind,
Data: b[:i], Data: b[:i],
}), b[i:], nil }), b[i:], nil
} }
//nolint:funlen,gocognit,cyclop //nolint:funlen,gocognit,cyclop
func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { func (p *Parser) scanIntOrFloat(b []byte) (reference, []byte, error) {
i := 0 i := 0
if len(b) > 2 && b[0] == '0' && b[1] != '.' && b[1] != 'e' && b[1] != 'E' { if len(b) > 2 && b[0] == '0' && b[1] != '.' && b[1] != 'e' && b[1] != 'E' {
@ -990,8 +1051,8 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
} }
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Integer, Kind: Integer,
Data: b[:i], Data: b[:i],
}), b[i:], nil }), b[i:], nil
} }
@ -1013,40 +1074,40 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
if c == 'i' { if c == 'i' {
if scanFollowsInf(b[i:]) { if scanFollowsInf(b[i:]) {
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Float, Kind: Float,
Data: b[:i+3], Data: b[:i+3],
}), b[i+3:], nil }), b[i+3:], nil
} }
return ast.InvalidReference, nil, newDecodeError(b[i:i+1], "unexpected character 'i' while scanning for a number") return invalidReference, nil, NewParserError(b[i:i+1], "unexpected character 'i' while scanning for a number")
} }
if c == 'n' { if c == 'n' {
if scanFollowsNan(b[i:]) { if scanFollowsNan(b[i:]) {
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Float, Kind: Float,
Data: b[:i+3], Data: b[:i+3],
}), b[i+3:], nil }), b[i+3:], nil
} }
return ast.InvalidReference, nil, newDecodeError(b[i:i+1], "unexpected character 'n' while scanning for a number") return invalidReference, nil, NewParserError(b[i:i+1], "unexpected character 'n' while scanning for a number")
} }
break break
} }
if i == 0 { if i == 0 {
return ast.InvalidReference, b, newDecodeError(b, "incomplete number") return invalidReference, b, NewParserError(b, "incomplete number")
} }
kind := ast.Integer kind := Integer
if isFloat { if isFloat {
kind = ast.Float kind = Float
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: kind, Kind: kind,
Data: b[:i], Data: b[:i],
}), b[i:], nil }), b[i:], nil
@ -1075,11 +1136,11 @@ func isValidBinaryRune(r byte) bool {
func expect(x byte, b []byte) ([]byte, error) { func expect(x byte, b []byte) ([]byte, error) {
if len(b) == 0 { if len(b) == 0 {
return nil, newDecodeError(b, "expected character %c but the document ended here", x) return nil, NewParserError(b, "expected character %c but the document ended here", x)
} }
if b[0] != x { if b[0] != x {
return nil, newDecodeError(b[0:1], "expected character %c", x) return nil, NewParserError(b[0:1], "expected character %c", x)
} }
return b[1:], nil return b[1:], nil

@ -1,4 +1,6 @@
package toml package unstable
import "github.com/pelletier/go-toml/v2/internal/characters"
func scanFollows(b []byte, pattern string) bool { func scanFollows(b []byte, pattern string) bool {
n := len(pattern) n := len(pattern)
@ -54,16 +56,16 @@ func scanLiteralString(b []byte) ([]byte, []byte, error) {
case '\'': case '\'':
return b[:i+1], b[i+1:], nil return b[:i+1], b[i+1:], nil
case '\n', '\r': case '\n', '\r':
return nil, nil, newDecodeError(b[i:i+1], "literal strings cannot have new lines") return nil, nil, NewParserError(b[i:i+1], "literal strings cannot have new lines")
} }
size := utf8ValidNext(b[i:]) size := characters.Utf8ValidNext(b[i:])
if size == 0 { if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character") return nil, nil, NewParserError(b[i:i+1], "invalid character")
} }
i += size i += size
} }
return nil, nil, newDecodeError(b[len(b):], "unterminated literal string") return nil, nil, NewParserError(b[len(b):], "unterminated literal string")
} }
func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) { func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
@ -98,39 +100,39 @@ func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
i++ i++
if i < len(b) && b[i] == '\'' { if i < len(b) && b[i] == '\'' {
return nil, nil, newDecodeError(b[i-3:i+1], "''' not allowed in multiline literal string") return nil, nil, NewParserError(b[i-3:i+1], "''' not allowed in multiline literal string")
} }
return b[:i], b[i:], nil return b[:i], b[i:], nil
} }
case '\r': case '\r':
if len(b) < i+2 { if len(b) < i+2 {
return nil, nil, newDecodeError(b[len(b):], `need a \n after \r`) return nil, nil, NewParserError(b[len(b):], `need a \n after \r`)
} }
if b[i+1] != '\n' { if b[i+1] != '\n' {
return nil, nil, newDecodeError(b[i:i+2], `need a \n after \r`) return nil, nil, NewParserError(b[i:i+2], `need a \n after \r`)
} }
i += 2 // skip the \n i += 2 // skip the \n
continue continue
} }
size := utf8ValidNext(b[i:]) size := characters.Utf8ValidNext(b[i:])
if size == 0 { if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character") return nil, nil, NewParserError(b[i:i+1], "invalid character")
} }
i += size i += size
} }
return nil, nil, newDecodeError(b[len(b):], `multiline literal string not terminated by '''`) return nil, nil, NewParserError(b[len(b):], `multiline literal string not terminated by '''`)
} }
func scanWindowsNewline(b []byte) ([]byte, []byte, error) { func scanWindowsNewline(b []byte) ([]byte, []byte, error) {
const lenCRLF = 2 const lenCRLF = 2
if len(b) < lenCRLF { if len(b) < lenCRLF {
return nil, nil, newDecodeError(b, "windows new line expected") return nil, nil, NewParserError(b, "windows new line expected")
} }
if b[1] != '\n' { if b[1] != '\n' {
return nil, nil, newDecodeError(b, `windows new line should be \r\n`) return nil, nil, NewParserError(b, `windows new line should be \r\n`)
} }
return b[:lenCRLF], b[lenCRLF:], nil return b[:lenCRLF], b[lenCRLF:], nil
@ -165,11 +167,11 @@ func scanComment(b []byte) ([]byte, []byte, error) {
if i+1 < len(b) && b[i+1] == '\n' { if i+1 < len(b) && b[i+1] == '\n' {
return b[:i+1], b[i+1:], nil return b[:i+1], b[i+1:], nil
} }
return nil, nil, newDecodeError(b[i:i+1], "invalid character in comment") return nil, nil, NewParserError(b[i:i+1], "invalid character in comment")
} }
size := utf8ValidNext(b[i:]) size := characters.Utf8ValidNext(b[i:])
if size == 0 { if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character in comment") return nil, nil, NewParserError(b[i:i+1], "invalid character in comment")
} }
i += size i += size
@ -192,17 +194,17 @@ func scanBasicString(b []byte) ([]byte, bool, []byte, error) {
case '"': case '"':
return b[:i+1], escaped, b[i+1:], nil return b[:i+1], escaped, b[i+1:], nil
case '\n', '\r': case '\n', '\r':
return nil, escaped, nil, newDecodeError(b[i:i+1], "basic strings cannot have new lines") return nil, escaped, nil, NewParserError(b[i:i+1], "basic strings cannot have new lines")
case '\\': case '\\':
if len(b) < i+2 { if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[i:i+1], "need a character after \\") return nil, escaped, nil, NewParserError(b[i:i+1], "need a character after \\")
} }
escaped = true escaped = true
i++ // skip the next character i++ // skip the next character
} }
} }
return nil, escaped, nil, newDecodeError(b[len(b):], `basic string not terminated by "`) return nil, escaped, nil, NewParserError(b[len(b):], `basic string not terminated by "`)
} }
func scanMultilineBasicString(b []byte) ([]byte, bool, []byte, error) { func scanMultilineBasicString(b []byte) ([]byte, bool, []byte, error) {
@ -243,27 +245,27 @@ func scanMultilineBasicString(b []byte) ([]byte, bool, []byte, error) {
i++ i++
if i < len(b) && b[i] == '"' { if i < len(b) && b[i] == '"' {
return nil, escaped, nil, newDecodeError(b[i-3:i+1], `""" not allowed in multiline basic string`) return nil, escaped, nil, NewParserError(b[i-3:i+1], `""" not allowed in multiline basic string`)
} }
return b[:i], escaped, b[i:], nil return b[:i], escaped, b[i:], nil
} }
case '\\': case '\\':
if len(b) < i+2 { if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[len(b):], "need a character after \\") return nil, escaped, nil, NewParserError(b[len(b):], "need a character after \\")
} }
escaped = true escaped = true
i++ // skip the next character i++ // skip the next character
case '\r': case '\r':
if len(b) < i+2 { if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[len(b):], `need a \n after \r`) return nil, escaped, nil, NewParserError(b[len(b):], `need a \n after \r`)
} }
if b[i+1] != '\n' { if b[i+1] != '\n' {
return nil, escaped, nil, newDecodeError(b[i:i+2], `need a \n after \r`) return nil, escaped, nil, NewParserError(b[i:i+2], `need a \n after \r`)
} }
i++ // skip the \n i++ // skip the \n
} }
} }
return nil, escaped, nil, newDecodeError(b[len(b):], `multiline basic string not terminated by """`) return nil, escaped, nil, NewParserError(b[len(b):], `multiline basic string not terminated by """`)
} }

@ -1,6 +1,6 @@
The MIT License (MIT) MIT License
Copyright (c) 2016 Mitchell Hashimoto Copyright (c) 2017 Segment
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal
@ -9,13 +9,13 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions: furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in The above copyright notice and this permission notice shall be included in all
all copies or substantial portions of the Software. copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
THE SOFTWARE. SOFTWARE.

@ -0,0 +1,121 @@
package fnv1a
const (
	// offset64 and prime64 are the standard FNV-1a 64-bit parameters.
	offset64 = uint64(14695981039346656037)
	prime64  = uint64(1099511628211)

	// Init64 is what 64 bits hash values should be initialized with.
	Init64 = offset64
)

// HashString64 returns the FNV-1a hash of s.
func HashString64(s string) uint64 {
	return AddString64(Init64, s)
}

// HashBytes64 returns the FNV-1a hash of b.
func HashBytes64(b []byte) uint64 {
	return AddBytes64(Init64, b)
}

// HashUint64 returns the FNV-1a hash of the 8 bytes of u, most
// significant byte first.
func HashUint64(u uint64) uint64 {
	return AddUint64(Init64, u)
}

// AddString64 folds every byte of s into the running hash h and returns
// the result. The bulk of the input is consumed eight bytes per
// iteration; unrolling this way benchmarks noticeably faster than a
// plain one-byte loop while producing the exact same hash.
func AddString64(h uint64, s string) uint64 {
	for ; len(s) >= 8; s = s[8:] {
		h = (h ^ uint64(s[0])) * prime64
		h = (h ^ uint64(s[1])) * prime64
		h = (h ^ uint64(s[2])) * prime64
		h = (h ^ uint64(s[3])) * prime64
		h = (h ^ uint64(s[4])) * prime64
		h = (h ^ uint64(s[5])) * prime64
		h = (h ^ uint64(s[6])) * prime64
		h = (h ^ uint64(s[7])) * prime64
	}
	// Fewer than eight bytes remain; finish one byte at a time.
	for i := 0; i < len(s); i++ {
		h = (h ^ uint64(s[i])) * prime64
	}
	return h
}

// AddBytes64 folds every byte of b into the running hash h and returns
// the result. Same unrolled scheme as AddString64.
func AddBytes64(h uint64, b []byte) uint64 {
	for ; len(b) >= 8; b = b[8:] {
		h = (h ^ uint64(b[0])) * prime64
		h = (h ^ uint64(b[1])) * prime64
		h = (h ^ uint64(b[2])) * prime64
		h = (h ^ uint64(b[3])) * prime64
		h = (h ^ uint64(b[4])) * prime64
		h = (h ^ uint64(b[5])) * prime64
		h = (h ^ uint64(b[6])) * prime64
		h = (h ^ uint64(b[7])) * prime64
	}
	// Fewer than eight bytes remain; finish one byte at a time.
	for i := 0; i < len(b); i++ {
		h = (h ^ uint64(b[i])) * prime64
	}
	return h
}

// AddUint64 folds the 8 bytes of u (most significant first) into h and
// returns the result.
func AddUint64(h uint64, u uint64) uint64 {
	for shift := 56; shift >= 0; shift -= 8 {
		h = (h ^ ((u >> uint(shift)) & 0xFF)) * prime64
	}
	return h
}

@ -0,0 +1,104 @@
package fnv1a
const (
	// offset32 and prime32 are the standard FNV-1a 32-bit parameters.
	offset32 = uint32(2166136261)
	prime32  = uint32(16777619)

	// Init32 is what 32 bits hash values should be initialized with.
	Init32 = offset32
)

// HashString32 returns the FNV-1a hash of s.
func HashString32(s string) uint32 {
	return AddString32(Init32, s)
}

// HashBytes32 returns the FNV-1a hash of b.
func HashBytes32(b []byte) uint32 {
	return AddBytes32(Init32, b)
}

// HashUint32 returns the FNV-1a hash of the 4 bytes of u, most
// significant byte first.
func HashUint32(u uint32) uint32 {
	return AddUint32(Init32, u)
}

// AddString32 folds every byte of s into the running hash h and returns
// the result. The bulk of the input is consumed eight bytes per
// iteration for speed; the result is identical to a one-byte loop.
func AddString32(h uint32, s string) uint32 {
	for ; len(s) >= 8; s = s[8:] {
		h = (h ^ uint32(s[0])) * prime32
		h = (h ^ uint32(s[1])) * prime32
		h = (h ^ uint32(s[2])) * prime32
		h = (h ^ uint32(s[3])) * prime32
		h = (h ^ uint32(s[4])) * prime32
		h = (h ^ uint32(s[5])) * prime32
		h = (h ^ uint32(s[6])) * prime32
		h = (h ^ uint32(s[7])) * prime32
	}
	// Fewer than eight bytes remain; finish one byte at a time.
	for i := 0; i < len(s); i++ {
		h = (h ^ uint32(s[i])) * prime32
	}
	return h
}

// AddBytes32 folds every byte of b into the running hash h and returns
// the result. Same unrolled scheme as AddString32.
func AddBytes32(h uint32, b []byte) uint32 {
	for ; len(b) >= 8; b = b[8:] {
		h = (h ^ uint32(b[0])) * prime32
		h = (h ^ uint32(b[1])) * prime32
		h = (h ^ uint32(b[2])) * prime32
		h = (h ^ uint32(b[3])) * prime32
		h = (h ^ uint32(b[4])) * prime32
		h = (h ^ uint32(b[5])) * prime32
		h = (h ^ uint32(b[6])) * prime32
		h = (h ^ uint32(b[7])) * prime32
	}
	// Fewer than eight bytes remain; finish one byte at a time.
	for i := 0; i < len(b); i++ {
		h = (h ^ uint32(b[i])) * prime32
	}
	return h
}

// AddUint32 folds the 4 bytes of u (most significant first) into h and
// returns the result.
func AddUint32(h, u uint32) uint32 {
	for shift := 24; shift >= 0; shift -= 8 {
		h = (h ^ ((u >> uint(shift)) & 0xFF)) * prime32
	}
	return h
}

@ -99,7 +99,7 @@ func (*database) CompileStatement(sess sqladapter.Session, stmt *exql.Statement,
} }
query, args := sqlbuilder.Preprocess(compiled, args) query, args := sqlbuilder.Preprocess(compiled, args)
query = sqladapter.ReplaceWithDollarSign(query) query = string(sqladapter.ReplaceWithDollarSign([]byte(query)))
return query, args, nil return query, args, nil
} }

@ -24,25 +24,21 @@ package cache
import ( import (
"container/list" "container/list"
"errors" "errors"
"fmt"
"strconv"
"sync" "sync"
"github.com/upper/db/v4/internal/cache/hashstructure"
) )
const defaultCapacity = 128 const defaultCapacity = 128
// Cache holds a map of volatile key -> values. // Cache holds a map of volatile key -> values.
type Cache struct { type Cache struct {
cache map[string]*list.Element keys *list.List
li *list.List items map[uint64]*list.Element
capacity int
mu sync.RWMutex mu sync.RWMutex
capacity int
} }
type item struct { type cacheItem struct {
key string key uint64
value interface{} value interface{}
} }
@ -52,11 +48,11 @@ func NewCacheWithCapacity(capacity int) (*Cache, error) {
if capacity < 1 { if capacity < 1 {
return nil, errors.New("Capacity must be greater than zero.") return nil, errors.New("Capacity must be greater than zero.")
} }
return &Cache{ c := &Cache{
cache: make(map[string]*list.Element),
li: list.New(),
capacity: capacity, capacity: capacity,
}, nil }
c.init()
return c, nil
} }
// NewCache initializes a new caching space with default settings. // NewCache initializes a new caching space with default settings.
@ -68,6 +64,11 @@ func NewCache() *Cache {
return c return c
} }
func (c *Cache) init() {
c.items = make(map[uint64]*list.Element)
c.keys = list.New()
}
// Read attempts to retrieve a cached value as a string, if the value does not // Read attempts to retrieve a cached value as a string, if the value does not
// exists returns an empty string and false. // exists returns an empty string and false.
func (c *Cache) Read(h Hashable) (string, bool) { func (c *Cache) Read(h Hashable) (string, bool) {
@ -84,33 +85,35 @@ func (c *Cache) Read(h Hashable) (string, bool) {
func (c *Cache) ReadRaw(h Hashable) (interface{}, bool) { func (c *Cache) ReadRaw(h Hashable) (interface{}, bool) {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
data, ok := c.cache[h.Hash()]
item, ok := c.items[h.Hash()]
if ok { if ok {
return data.Value.(*item).value, true return item.Value.(*cacheItem).value, true
} }
return nil, false return nil, false
} }
// Write stores a value in memory. If the value already exists its overwritten. // Write stores a value in memory. If the value already exists its overwritten.
func (c *Cache) Write(h Hashable, value interface{}) { func (c *Cache) Write(h Hashable, value interface{}) {
key := h.Hash()
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if el, ok := c.cache[key]; ok { key := h.Hash()
el.Value.(*item).value = value
c.li.MoveToFront(el) if item, ok := c.items[key]; ok {
item.Value.(*cacheItem).value = value
c.keys.MoveToFront(item)
return return
} }
c.cache[key] = c.li.PushFront(&item{key, value}) c.items[key] = c.keys.PushFront(&cacheItem{key, value})
for c.li.Len() > c.capacity { for c.keys.Len() > c.capacity {
el := c.li.Remove(c.li.Back()) item := c.keys.Remove(c.keys.Back()).(*cacheItem)
delete(c.cache, el.(*item).key) delete(c.items, item.key)
if p, ok := el.(*item).value.(HasOnPurge); ok { if p, ok := item.value.(HasOnEvict); ok {
p.OnPurge() p.OnEvict()
} }
} }
} }
@ -120,33 +123,12 @@ func (c *Cache) Write(h Hashable, value interface{}) {
func (c *Cache) Clear() { func (c *Cache) Clear() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
for _, el := range c.cache {
if p, ok := el.Value.(*item).value.(HasOnPurge); ok {
p.OnPurge()
}
}
c.cache = make(map[string]*list.Element)
c.li.Init()
}
// Hash returns a hash of the given struct. for _, item := range c.items {
func Hash(v interface{}) string { if p, ok := item.Value.(*cacheItem).value.(HasOnEvict); ok {
q, err := hashstructure.Hash(v, nil) p.OnEvict()
if err != nil { }
panic(fmt.Sprintf("Could not hash struct: %v", err.Error()))
} }
return strconv.FormatUint(q, 10)
}
type hash struct {
name string
}
func (h *hash) Hash() string {
return h.name
}
// String returns a Hashable that produces a hash equal to the given string. c.init()
func String(s string) Hashable {
return &hash{s}
} }

@ -0,0 +1,109 @@
package cache
import (
"fmt"
"github.com/segmentio/fasthash/fnv1a"
)
// Type tags mixed into the hash stream ahead of each value so that
// values of different Go types never collide merely because they share
// a byte representation (e.g. uint64(1) vs true). Each tag is a
// distinct power of two.
const (
	hashTypeInt uint64 = 1 << iota // unsigned integers and non-negative signed integers
	hashTypeSignedInt              // negative signed integers (hashed by magnitude)
	hashTypeBool
	hashTypeString
	hashTypeHashable // values implementing the Hashable interface
	hashTypeNil
)
// hasher is a minimal Hashable adapter: it pairs a type tag with an
// arbitrary value and hashes them together on demand.
type hasher struct {
	t uint64      // one of the hashType* tags (or a fragment type tag)
	v interface{} // wrapped value; must be a type supported by addToHash
}

// Hash implements Hashable by hashing the tag together with the
// wrapped value.
func (h *hasher) Hash() uint64 {
	return NewHash(h.t, h.v)
}

// NewHashable wraps a type tag and a value into a Hashable.
func NewHashable(t uint64, v interface{}) Hashable {
	return &hasher{t: t, v: v}
}
// InitHash seeds a fresh FNV-1a stream with the type tag t.
func InitHash(t uint64) uint64 {
	return fnv1a.AddUint64(fnv1a.Init64, t)
}

// NewHash returns the hash of the values in, seeded with the type
// tag t.
func NewHash(t uint64, in ...interface{}) uint64 {
	return AddToHash(InitHash(t), in...)
}

// AddToHash folds each value of in into the running hash h and returns
// the result. Untyped nil entries are skipped entirely.
func AddToHash(h uint64, in ...interface{}) uint64 {
	for _, v := range in {
		if v == nil {
			continue
		}
		h = addToHash(h, v)
	}
	return h
}
// addToHash mixes a single supported value into the FNV-1a stream h.
// Every branch first folds in a type tag so that, e.g., uint64(1),
// int64(1) and true all hash differently. Negative signed integers are
// tagged hashTypeSignedInt and hashed by magnitude; non-negative signed
// integers share the unsigned tag, so int64(5) and uint64(5) hash
// alike. Unsupported types panic rather than silently colliding.
//
// The produced values key the statement cache, so any change to the
// mixing order here invalidates previously computed hashes.
func addToHash(h uint64, in interface{}) uint64 {
	switch v := in.(type) {
	case uint64:
		return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), v)
	case uint32:
		return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
	case uint16:
		return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
	case uint8:
		return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
	case uint:
		return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
	case int64:
		if v < 0 {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
		} else {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
		}
	case int32:
		if v < 0 {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
		} else {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
		}
	case int16:
		if v < 0 {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
		} else {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
		}
	case int8:
		if v < 0 {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
		} else {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
		}
	case int:
		if v < 0 {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
		} else {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
		}
	case bool:
		if v {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 1)
		} else {
			return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 2)
		}
	case string:
		return fnv1a.AddString64(fnv1a.AddUint64(h, hashTypeString), v)
	case Hashable:
		// NOTE(review): this nil check looks unreachable — an untyped
		// nil is skipped by AddToHash and would match `case nil` here,
		// while an interface holding a typed nil pointer compares
		// non-nil. Confirm whether it was meant to guard typed nils.
		if in == nil {
			panic(fmt.Sprintf("could not hash nil element %T", in))
		}
		return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeHashable), v.Hash())
	case nil:
		return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeNil), 0)
	default:
		panic(fmt.Sprintf("unsupported value type %T", in))
	}
}

@ -1,61 +0,0 @@
# hashstructure
hashstructure is a Go library for computing a deterministic hash value
for arbitrary Go values.
This can be used to key values in a hash (for use in a map, set, etc.)
that are complex. The most common use case is comparing two values without
sending data across the network, caching values locally (de-dup), and so on.
## Features
* Hash any arbitrary Go value, including complex types.
* Tag a struct field to ignore it and not affect the hash value.
* Tag a slice type struct field to treat it as a set where ordering
doesn't affect the hash code but the field itself is still taken into
account to create the hash value.
* Optionally specify a custom hash function to optimize for speed, collision
avoidance for your data set, etc.
## Installation
Standard `go get`:
```
$ go get github.com/mitchellh/hashstructure
```
## Usage & Example
For usage and examples see the [Godoc](http://godoc.org/github.com/mitchellh/hashstructure).
A quick code example is shown below:
type ComplexStruct struct {
Name string
Age uint
Metadata map[string]interface{}
}
v := ComplexStruct{
Name: "mitchellh",
Age: 64,
Metadata: map[string]interface{}{
"car": true,
"location": "California",
"siblings": []string{"Bob", "John"},
},
}
hash, err := hashstructure.Hash(v, nil)
if err != nil {
panic(err)
}
fmt.Printf("%d", hash)
// Output:
// 2307517237273902113

@ -1,325 +0,0 @@
package hashstructure
import (
"encoding/binary"
"fmt"
"hash"
"hash/fnv"
"reflect"
)
// HashOptions are options that are available for hashing.
type HashOptions struct {
	// Hasher is the hash function to use. If this isn't set, it will
	// default to FNV. The hasher is Reset before the walk begins, so a
	// shared instance may be reused across calls (though not
	// concurrently).
	Hasher hash.Hash64

	// TagName is the struct tag to look at when hashing the structure.
	// By default this is "hash".
	TagName string
}
// Hash returns the hash value of an arbitrary value.
//
// If opts is nil, then default options will be used. See HashOptions
// for the default values.
//
// Notes on the value:
//
//   - Unexported fields on structs are ignored and do not affect the
//     hash value.
//
//   - Adding an exported field to a struct, even with its zero value,
//     will change the hash value.
//
// For structs, hashing can be controlled per field via the struct tag
// (named by TagName, "hash" by default). For example:
//
//	struct {
//	    Name string
//	    UUID string `hash:"ignore"`
//	}
//
// The available tag values are:
//
//   - "ignore" - The field will be ignored and not affect the hash code.
//
//   - "set" - The field will be treated as a set, where ordering doesn't
//     affect the hash code. This only works for slices.
func Hash(v interface{}, opts *HashOptions) (uint64, error) {
	// Fill in defaults for every option the caller left unset.
	if opts == nil {
		opts = &HashOptions{}
	}
	if opts.Hasher == nil {
		opts.Hasher = fnv.New64()
	}
	if opts.TagName == "" {
		opts.TagName = "hash"
	}

	// Start from a clean hash state, then walk the value.
	opts.Hasher.Reset()
	w := &walker{h: opts.Hasher, tag: opts.TagName}
	return w.visit(reflect.ValueOf(v), nil)
}
// walker carries the state of one recursive hashing traversal.
type walker struct {
	h   hash.Hash64 // scratch hash; Reset and reused at each leaf value
	tag string      // struct tag name that controls per-field hashing
}

// visitOpts carries per-field context passed down a single visit call.
type visitOpts struct {
	// Flags are a bitmask of flags to affect behavior of this visit
	Flags visitFlag

	// Information about the struct containing this field
	Struct      interface{}
	StructField string
}
// visit recursively computes the hash of v. opts carries per-field
// context (set/ignore flags plus the enclosing struct) and may be nil.
// Scalars are hashed directly; arrays/slices combine element hashes in
// order (unless tagged "set"); maps and structs XOR their entry hashes
// so iteration order cannot affect the result.
func (w *walker) visit(v reflect.Value, opts *visitOpts) (uint64, error) {
	// Loop since these can be wrapped in multiple layers of pointers
	// and interfaces.
	for {
		// If we have an interface, dereference it. We have to do this up
		// here because it might be a nil in there and the check below must
		// catch that.
		if v.Kind() == reflect.Interface {
			v = v.Elem()
			continue
		}

		if v.Kind() == reflect.Ptr {
			v = reflect.Indirect(v)
			continue
		}

		break
	}

	// If it is nil, treat it like a zero.
	if !v.IsValid() {
		var tmp int8
		v = reflect.ValueOf(tmp)
	}

	// Binary writing can use raw ints, we have to convert to
	// a sized-int, we'll choose the largest...
	switch v.Kind() {
	case reflect.Int:
		v = reflect.ValueOf(int64(v.Int()))
	case reflect.Uint:
		v = reflect.ValueOf(uint64(v.Uint()))
	case reflect.Bool:
		var tmp int8
		if v.Bool() {
			tmp = 1
		}
		v = reflect.ValueOf(tmp)
	}

	k := v.Kind()

	// We can shortcut numeric values by directly binary writing them
	if k >= reflect.Int && k <= reflect.Complex64 {
		// A direct hash calculation
		w.h.Reset()
		err := binary.Write(w.h, binary.LittleEndian, v.Interface())
		return w.h.Sum64(), err
	}

	switch k {
	case reflect.Array:
		var h uint64
		l := v.Len()
		for i := 0; i < l; i++ {
			current, err := w.visit(v.Index(i), nil)
			if err != nil {
				return 0, err
			}

			h = hashUpdateOrdered(w.h, h, current)
		}

		return h, nil

	case reflect.Map:
		var includeMap IncludableMap
		if opts != nil && opts.Struct != nil {
			if v, ok := opts.Struct.(IncludableMap); ok {
				includeMap = v
			}
		}

		// Build the hash for the map. We do this by XOR-ing all the key
		// and value hashes. This makes it deterministic despite ordering.
		var h uint64
		for _, k := range v.MapKeys() {
			v := v.MapIndex(k)
			if includeMap != nil {
				incl, err := includeMap.HashIncludeMap(
					opts.StructField, k.Interface(), v.Interface())
				if err != nil {
					return 0, err
				}
				if !incl {
					continue
				}
			}

			// Each entry is hashed as an ordered (key, value) pair,
			// then XOR-ed into the map hash order-independently.
			kh, err := w.visit(k, nil)
			if err != nil {
				return 0, err
			}
			vh, err := w.visit(v, nil)
			if err != nil {
				return 0, err
			}

			fieldHash := hashUpdateOrdered(w.h, kh, vh)
			h = hashUpdateUnordered(h, fieldHash)
		}

		return h, nil

	case reflect.Struct:
		var include Includable
		parent := v.Interface()
		if impl, ok := parent.(Includable); ok {
			include = impl
		}

		// Seed the struct hash with the type name so two structurally
		// identical types still hash differently.
		t := v.Type()
		h, err := w.visit(reflect.ValueOf(t.Name()), nil)
		if err != nil {
			return 0, err
		}

		l := v.NumField()
		for i := 0; i < l; i++ {
			// NOTE(review): this guard admits any field whose name is
			// not "_"; the CanSet() half looks redundant for exported
			// fields — confirm against upstream intent.
			if v := v.Field(i); v.CanSet() || t.Field(i).Name != "_" {
				var f visitFlag
				fieldType := t.Field(i)
				if fieldType.PkgPath != "" {
					// Unexported
					continue
				}

				tag := fieldType.Tag.Get(w.tag)
				if tag == "ignore" {
					// Ignore this field
					continue
				}

				// Check if we implement includable and check it
				if include != nil {
					incl, err := include.HashInclude(fieldType.Name, v)
					if err != nil {
						return 0, err
					}
					if !incl {
						continue
					}
				}

				switch tag {
				case "set":
					f |= visitFlagSet
				}

				// Fields hash as ordered (name, value) pairs, XOR-ed
				// together so field order cannot matter.
				kh, err := w.visit(reflect.ValueOf(fieldType.Name), nil)
				if err != nil {
					return 0, err
				}

				vh, err := w.visit(v, &visitOpts{
					Flags:       f,
					Struct:      parent,
					StructField: fieldType.Name,
				})
				if err != nil {
					return 0, err
				}

				fieldHash := hashUpdateOrdered(w.h, kh, vh)
				h = hashUpdateUnordered(h, fieldHash)
			}
		}

		return h, nil

	case reflect.Slice:
		// We have two behaviors here. If it isn't a set, then we just
		// visit all the elements. If it is a set, then we do a deterministic
		// hash code.
		var h uint64
		var set bool
		if opts != nil {
			set = (opts.Flags & visitFlagSet) != 0
		}
		l := v.Len()
		for i := 0; i < l; i++ {
			current, err := w.visit(v.Index(i), nil)
			if err != nil {
				return 0, err
			}

			if set {
				h = hashUpdateUnordered(h, current)
			} else {
				h = hashUpdateOrdered(w.h, h, current)
			}
		}

		return h, nil

	case reflect.String:
		// Directly hash
		w.h.Reset()
		_, err := w.h.Write([]byte(v.String()))
		return w.h.Sum64(), err

	default:
		return 0, fmt.Errorf("unknown kind to hash: %s", k)
	}
}
func hashUpdateOrdered(h hash.Hash64, a, b uint64) uint64 {
// For ordered updates, use a real hash function
h.Reset()
// We just panic if the binary writes fail because we are writing
// an int64 which should never be fail-able.
e1 := binary.Write(h, binary.LittleEndian, a)
e2 := binary.Write(h, binary.LittleEndian, b)
if e1 != nil {
panic(e1)
}
if e2 != nil {
panic(e2)
}
return h.Sum64()
}
// hashUpdateUnordered combines two hashes order-independently (XOR).
// Used for map entries, struct fields and "set" slices, where the
// iteration order must not influence the final hash.
func hashUpdateUnordered(a, b uint64) uint64 {
	return a ^ b
}

// visitFlag is used as a bitmask for affecting visit behavior
type visitFlag uint

const (
	visitFlagInvalid visitFlag = iota
	visitFlagSet     = iota << 1 // hash the slice as an unordered set
)

var (
	// Reference visitFlagInvalid so it is not reported as unused.
	_ = visitFlagInvalid
)

@ -1,15 +0,0 @@
package hashstructure
// Includable is an interface that can optionally be implemented by a
// struct. It will be called for each field in the struct to check
// whether it should be included in the hash. Returning false skips the
// field; a non-nil error aborts the whole hash computation.
type Includable interface {
	HashInclude(field string, v interface{}) (bool, error)
}

// IncludableMap is an interface that can optionally be implemented by a
// struct. It will be called when a map-type field is found to ask the
// struct if the map item (key k, value v) should be included in the
// hash.
type IncludableMap interface {
	HashIncludeMap(field string, k, v interface{}) (bool, error)
}

@ -24,11 +24,11 @@ package cache
// Hashable types must implement a method that returns a key. This key will be // Hashable types must implement a method that returns a key. This key will be
// associated with a cached value. // associated with a cached value.
type Hashable interface { type Hashable interface {
Hash() string Hash() uint64
} }
// HasOnPurge type is (optionally) implemented by cache objects to clean after // HasOnEvict type is (optionally) implemented by cache objects to clean after
// themselves. // themselves.
type HasOnPurge interface { type HasOnEvict interface {
OnPurge() OnEvict()
} }

@ -3,9 +3,11 @@ package exql
import ( import (
"fmt" "fmt"
"strings" "strings"
"github.com/upper/db/v4/internal/cache"
) )
type columnT struct { type columnWithAlias struct {
Name string Name string
Alias string Alias string
} }
@ -13,8 +15,6 @@ type columnT struct {
// Column represents a SQL column. // Column represents a SQL column.
type Column struct { type Column struct {
Name interface{} Name interface{}
Alias string
hash hash
} }
var _ = Fragment(&Column{}) var _ = Fragment(&Column{})
@ -25,8 +25,11 @@ func ColumnWithName(name string) *Column {
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (c *Column) Hash() string { func (c *Column) Hash() uint64 {
return c.hash.Hash(c) if c == nil {
return cache.NewHash(FragmentType_Column, nil)
}
return cache.NewHash(FragmentType_Column, c.Name)
} }
// Compile transforms the ColumnValue into an equivalent SQL representation. // Compile transforms the ColumnValue into an equivalent SQL representation.
@ -35,20 +38,17 @@ func (c *Column) Compile(layout *Template) (compiled string, err error) {
return z, nil return z, nil
} }
alias := c.Alias var alias string
switch value := c.Name.(type) { switch value := c.Name.(type) {
case string: case string:
input := trimString(value) value = trimString(value)
chunks := separateByAS(input)
chunks := separateByAS(value)
if len(chunks) == 1 { if len(chunks) == 1 {
chunks = separateBySpace(input) chunks = separateBySpace(value)
} }
name := chunks[0] name := chunks[0]
nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2) nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2)
for i := range nameChunks { for i := range nameChunks {
@ -65,17 +65,19 @@ func (c *Column) Compile(layout *Template) (compiled string, err error) {
alias = trimString(chunks[1]) alias = trimString(chunks[1])
alias = layout.MustCompile(layout.IdentifierQuote, Raw{Value: alias}) alias = layout.MustCompile(layout.IdentifierQuote, Raw{Value: alias})
} }
case Raw: case compilable:
compiled = value.String() compiled, err = value.Compile(layout)
if err != nil {
return "", err
}
default: default:
compiled = fmt.Sprintf("%v", c.Name) return "", fmt.Errorf(errExpectingHashableFmt, c.Name)
} }
if alias != "" { if alias != "" {
compiled = layout.MustCompile(layout.ColumnAliasLayout, columnT{compiled, alias}) compiled = layout.MustCompile(layout.ColumnAliasLayout, columnWithAlias{compiled, alias})
} }
layout.Write(c, compiled) layout.Write(c, compiled)
return return
} }

@ -1,6 +1,7 @@
package exql package exql
import ( import (
"github.com/upper/db/v4/internal/cache"
"strings" "strings"
) )
@ -9,7 +10,6 @@ type ColumnValue struct {
Column Fragment Column Fragment
Operator string Operator string
Value Fragment Value Fragment
hash hash
} }
var _ = Fragment(&ColumnValue{}) var _ = Fragment(&ColumnValue{})
@ -21,8 +21,11 @@ type columnValueT struct {
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (c *ColumnValue) Hash() string { func (c *ColumnValue) Hash() uint64 {
return c.hash.Hash(c) if c == nil {
return cache.NewHash(FragmentType_ColumnValue, nil)
}
return cache.NewHash(FragmentType_ColumnValue, c.Column, c.Operator, c.Value)
} }
// Compile transforms the ColumnValue into an equivalent SQL representation. // Compile transforms the ColumnValue into an equivalent SQL representation.
@ -58,7 +61,6 @@ func (c *ColumnValue) Compile(layout *Template) (compiled string, err error) {
// ColumnValues represents an array of ColumnValue // ColumnValues represents an array of ColumnValue
type ColumnValues struct { type ColumnValues struct {
ColumnValues []Fragment ColumnValues []Fragment
hash hash
} }
var _ = Fragment(&ColumnValues{}) var _ = Fragment(&ColumnValues{})
@ -71,13 +73,16 @@ func JoinColumnValues(values ...Fragment) *ColumnValues {
// Insert adds a column to the columns array. // Insert adds a column to the columns array.
func (c *ColumnValues) Insert(values ...Fragment) *ColumnValues { func (c *ColumnValues) Insert(values ...Fragment) *ColumnValues {
c.ColumnValues = append(c.ColumnValues, values...) c.ColumnValues = append(c.ColumnValues, values...)
c.hash.Reset()
return c return c
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (c *ColumnValues) Hash() string { func (c *ColumnValues) Hash() uint64 {
return c.hash.Hash(c) h := cache.InitHash(FragmentType_ColumnValues)
for i := range c.ColumnValues {
h = cache.AddToHash(h, c.ColumnValues[i])
}
return h
} }
// Compile transforms the ColumnValues into its SQL representation. // Compile transforms the ColumnValues into its SQL representation.

@ -2,19 +2,27 @@ package exql
import ( import (
"strings" "strings"
"github.com/upper/db/v4/internal/cache"
) )
// Columns represents an array of Column. // Columns represents an array of Column.
type Columns struct { type Columns struct {
Columns []Fragment Columns []Fragment
hash hash
} }
var _ = Fragment(&Columns{}) var _ = Fragment(&Columns{})
// Hash returns a unique identifier. // Hash returns a unique identifier.
func (c *Columns) Hash() string { func (c *Columns) Hash() uint64 {
return c.hash.Hash(c) if c == nil {
return cache.NewHash(FragmentType_Columns, nil)
}
h := cache.InitHash(FragmentType_Columns)
for i := range c.Columns {
h = cache.AddToHash(h, c.Columns[i])
}
return h
} }
// JoinColumns creates and returns an array of Column. // JoinColumns creates and returns an array of Column.
@ -48,7 +56,6 @@ func (c *Columns) IsEmpty() bool {
// Compile transforms the Columns into an equivalent SQL representation. // Compile transforms the Columns into an equivalent SQL representation.
func (c *Columns) Compile(layout *Template) (compiled string, err error) { func (c *Columns) Compile(layout *Template) (compiled string, err error) {
if z, ok := layout.Read(c); ok { if z, ok := layout.Read(c); ok {
return z, nil return z, nil
} }

@ -1,9 +1,12 @@
package exql package exql
import (
"github.com/upper/db/v4/internal/cache"
)
// Database represents a SQL database. // Database represents a SQL database.
type Database struct { type Database struct {
Name string Name string
hash hash
} }
var _ = Fragment(&Database{}) var _ = Fragment(&Database{})
@ -14,8 +17,11 @@ func DatabaseWithName(name string) *Database {
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (d *Database) Hash() string { func (d *Database) Hash() uint64 {
return d.hash.Hash(d) if d == nil {
return cache.NewHash(FragmentType_Database, nil)
}
return cache.NewHash(FragmentType_Database, d.Name)
} }
// Compile transforms the Database into an equivalent SQL representation. // Compile transforms the Database into an equivalent SQL representation.

@ -0,0 +1,5 @@
package exql
const (
errExpectingHashableFmt = "expecting hashable value, got %T"
)

@ -1,9 +1,12 @@
package exql package exql
import (
"github.com/upper/db/v4/internal/cache"
)
// GroupBy represents a SQL's "group by" statement. // GroupBy represents a SQL's "group by" statement.
type GroupBy struct { type GroupBy struct {
Columns Fragment Columns Fragment
hash hash
} }
var _ = Fragment(&GroupBy{}) var _ = Fragment(&GroupBy{})
@ -13,8 +16,11 @@ type groupByT struct {
} }
// Hash returns a unique identifier. // Hash returns a unique identifier.
func (g *GroupBy) Hash() string { func (g *GroupBy) Hash() uint64 {
return g.hash.Hash(g) if g == nil {
return cache.NewHash(FragmentType_GroupBy, nil)
}
return cache.NewHash(FragmentType_GroupBy, g.Columns)
} }
// GroupByColumns creates and returns a GroupBy with the given column. // GroupByColumns creates and returns a GroupBy with the given column.

@ -1,26 +0,0 @@
package exql
import (
"reflect"
"sync/atomic"
"github.com/upper/db/v4/internal/cache"
)
type hash struct {
v atomic.Value
}
func (h *hash) Hash(i interface{}) string {
v := h.v.Load()
if r, ok := v.(string); ok && r != "" {
return r
}
s := reflect.TypeOf(i).String() + ":" + cache.Hash(i)
h.v.Store(s)
return s
}
func (h *hash) Reset() {
h.v.Store("")
}

@ -2,6 +2,8 @@ package exql
import ( import (
"strings" "strings"
"github.com/upper/db/v4/internal/cache"
) )
type innerJoinT struct { type innerJoinT struct {
@ -14,14 +16,20 @@ type innerJoinT struct {
// Joins represents the union of different join conditions. // Joins represents the union of different join conditions.
type Joins struct { type Joins struct {
Conditions []Fragment Conditions []Fragment
hash hash
} }
var _ = Fragment(&Joins{}) var _ = Fragment(&Joins{})
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (j *Joins) Hash() string { func (j *Joins) Hash() uint64 {
return j.hash.Hash(j) if j == nil {
return cache.NewHash(FragmentType_Joins, nil)
}
h := cache.InitHash(FragmentType_Joins)
for i := range j.Conditions {
h = cache.AddToHash(h, j.Conditions[i])
}
return h
} }
// Compile transforms the Where into an equivalent SQL representation. // Compile transforms the Where into an equivalent SQL representation.
@ -66,14 +74,16 @@ type Join struct {
Table Fragment Table Fragment
On Fragment On Fragment
Using Fragment Using Fragment
hash hash
} }
var _ = Fragment(&Join{}) var _ = Fragment(&Join{})
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (j *Join) Hash() string { func (j *Join) Hash() uint64 {
return j.hash.Hash(j) if j == nil {
return cache.NewHash(FragmentType_Join, nil)
}
return cache.NewHash(FragmentType_Join, j.Type, j.Table, j.On, j.Using)
} }
// Compile transforms the Join into its equivalent SQL representation. // Compile transforms the Join into its equivalent SQL representation.
@ -118,9 +128,11 @@ type On Where
var _ = Fragment(&On{}) var _ = Fragment(&On{})
// Hash returns a unique identifier. func (o *On) Hash() uint64 {
func (o *On) Hash() string { if o == nil {
return o.hash.Hash(o) return cache.NewHash(FragmentType_On, nil)
}
return cache.NewHash(FragmentType_On, (*Where)(o))
} }
// Compile transforms the On into an equivalent SQL representation. // Compile transforms the On into an equivalent SQL representation.
@ -151,9 +163,11 @@ type usingT struct {
Columns string Columns string
} }
// Hash returns a unique identifier. func (u *Using) Hash() uint64 {
func (u *Using) Hash() string { if u == nil {
return u.hash.Hash(u) return cache.NewHash(FragmentType_Using, nil)
}
return cache.NewHash(FragmentType_Using, (*Columns)(u))
} }
// Compile transforms the Using into an equivalent SQL representation. // Compile transforms the Using into an equivalent SQL representation.

@ -1,8 +1,9 @@
package exql package exql
import ( import (
"fmt"
"strings" "strings"
"github.com/upper/db/v4/internal/cache"
) )
// Order represents the order in which SQL results are sorted. // Order represents the order in which SQL results are sorted.
@ -10,16 +11,20 @@ type Order uint8
// Possible values for Order // Possible values for Order
const ( const (
DefaultOrder = Order(iota) Order_Default Order = iota
Ascendent
Descendent Order_Ascendent
Order_Descendent
) )
func (o Order) Hash() uint64 {
return cache.NewHash(FragmentType_Order, uint8(o))
}
// SortColumn represents the column-order relation in an ORDER BY clause. // SortColumn represents the column-order relation in an ORDER BY clause.
type SortColumn struct { type SortColumn struct {
Column Fragment Column Fragment
Order Order
hash hash
} }
var _ = Fragment(&SortColumn{}) var _ = Fragment(&SortColumn{})
@ -34,7 +39,6 @@ var _ = Fragment(&SortColumn{})
// SortColumns represents the columns in an ORDER BY clause. // SortColumns represents the columns in an ORDER BY clause.
type SortColumns struct { type SortColumns struct {
Columns []Fragment Columns []Fragment
hash hash
} }
var _ = Fragment(&SortColumns{}) var _ = Fragment(&SortColumns{})
@ -42,7 +46,6 @@ var _ = Fragment(&SortColumns{})
// OrderBy represents an ORDER BY clause. // OrderBy represents an ORDER BY clause.
type OrderBy struct { type OrderBy struct {
SortColumns Fragment SortColumns Fragment
hash hash
} }
var _ = Fragment(&OrderBy{}) var _ = Fragment(&OrderBy{})
@ -62,8 +65,11 @@ func JoinWithOrderBy(sc *SortColumns) *OrderBy {
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (s *SortColumn) Hash() string { func (s *SortColumn) Hash() uint64 {
return s.hash.Hash(s) if s == nil {
return cache.NewHash(FragmentType_SortColumn, nil)
}
return cache.NewHash(FragmentType_SortColumn, s.Column, s.Order)
} }
// Compile transforms the SortColumn into an equivalent SQL representation. // Compile transforms the SortColumn into an equivalent SQL representation.
@ -93,8 +99,15 @@ func (s *SortColumn) Compile(layout *Template) (compiled string, err error) {
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (s *SortColumns) Hash() string { func (s *SortColumns) Hash() uint64 {
return s.hash.Hash(s) if s == nil {
return cache.NewHash(FragmentType_SortColumns, nil)
}
h := cache.InitHash(FragmentType_SortColumns)
for i := range s.Columns {
h = cache.AddToHash(h, s.Columns[i])
}
return h
} }
// Compile transforms the SortColumns into an equivalent SQL representation. // Compile transforms the SortColumns into an equivalent SQL representation.
@ -120,8 +133,11 @@ func (s *SortColumns) Compile(layout *Template) (compiled string, err error) {
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (s *OrderBy) Hash() string { func (s *OrderBy) Hash() uint64 {
return s.hash.Hash(s) if s == nil {
return cache.NewHash(FragmentType_OrderBy, nil)
}
return cache.NewHash(FragmentType_OrderBy, s.SortColumns)
} }
// Compile transforms the SortColumn into an equivalent SQL representation. // Compile transforms the SortColumn into an equivalent SQL representation.
@ -147,17 +163,12 @@ func (s *OrderBy) Compile(layout *Template) (compiled string, err error) {
return return
} }
// Hash returns a unique identifier.
func (s *Order) Hash() string {
return fmt.Sprintf("%T.%d", s, uint8(*s))
}
// Compile transforms the SortColumn into an equivalent SQL representation. // Compile transforms the SortColumn into an equivalent SQL representation.
func (s Order) Compile(layout *Template) (string, error) { func (s Order) Compile(layout *Template) (string, error) {
switch s { switch s {
case Ascendent: case Order_Ascendent:
return layout.AscKeyword, nil return layout.AscKeyword, nil
case Descendent: case Order_Descendent:
return layout.DescKeyword, nil return layout.DescKeyword, nil
} }
return "", nil return "", nil

@ -2,7 +2,8 @@ package exql
import ( import (
"fmt" "fmt"
"strings"
"github.com/upper/db/v4/internal/cache"
) )
var ( var (
@ -11,18 +12,27 @@ var (
// Raw represents a value that is meant to be used in a query without escaping. // Raw represents a value that is meant to be used in a query without escaping.
type Raw struct { type Raw struct {
Value string // Value should not be modified after assigned. Value string
hash hash
} }
// RawValue creates and returns a new raw value. func NewRawValue(v interface{}) (*Raw, error) {
func RawValue(v string) *Raw { switch t := v.(type) {
return &Raw{Value: strings.TrimSpace(v)} case string:
return &Raw{Value: t}, nil
case int, uint, int64, uint64, int32, uint32, int16, uint16:
return &Raw{Value: fmt.Sprintf("%d", t)}, nil
case fmt.Stringer:
return &Raw{Value: t.String()}, nil
}
return nil, fmt.Errorf("unexpected type: %T", v)
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (r *Raw) Hash() string { func (r *Raw) Hash() uint64 {
return r.hash.Hash(r) if r == nil {
return cache.NewHash(FragmentType_Raw, nil)
}
return cache.NewHash(FragmentType_Raw, r.Value)
} }
// Compile returns the raw value. // Compile returns the raw value.

@ -1,14 +1,20 @@
package exql package exql
import (
"github.com/upper/db/v4/internal/cache"
)
// Returning represents a RETURNING clause. // Returning represents a RETURNING clause.
type Returning struct { type Returning struct {
*Columns *Columns
hash hash
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (r *Returning) Hash() string { func (r *Returning) Hash() uint64 {
return r.hash.Hash(r) if r == nil {
return cache.NewHash(FragmentType_Returning, nil)
}
return cache.NewHash(FragmentType_Returning, r.Columns)
} }
var _ = Fragment(&Returning{}) var _ = Fragment(&Returning{})

@ -4,6 +4,8 @@ import (
"errors" "errors"
"reflect" "reflect"
"strings" "strings"
"github.com/upper/db/v4/internal/cache"
) )
var errUnknownTemplateType = errors.New("Unknown template type") var errUnknownTemplateType = errors.New("Unknown template type")
@ -28,7 +30,6 @@ type Statement struct {
SQL string SQL string
hash hash
amendFn func(string) string amendFn func(string) string
} }
@ -40,8 +41,28 @@ func (layout *Template) doCompile(c Fragment) (string, error) {
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (s *Statement) Hash() string { func (s *Statement) Hash() uint64 {
return s.hash.Hash(s) if s == nil {
return cache.NewHash(FragmentType_Statement, nil)
}
return cache.NewHash(
FragmentType_Statement,
s.Type,
s.Table,
s.Database,
s.Columns,
s.Values,
s.Distinct,
s.ColumnValues,
s.OrderBy,
s.GroupBy,
s.Joins,
s.Where,
s.Returning,
s.Limit,
s.Offset,
s.SQL,
)
} }
func (s *Statement) SetAmendment(amendFn func(string) string) { func (s *Statement) SetAmendment(amendFn func(string) string) {

@ -2,6 +2,8 @@ package exql
import ( import (
"strings" "strings"
"github.com/upper/db/v4/internal/cache"
) )
type tableT struct { type tableT struct {
@ -12,7 +14,6 @@ type tableT struct {
// Table struct represents a SQL table. // Table struct represents a SQL table.
type Table struct { type Table struct {
Name interface{} Name interface{}
hash hash
} }
var _ = Fragment(&Table{}) var _ = Fragment(&Table{})
@ -57,8 +58,11 @@ func TableWithName(name string) *Table {
} }
// Hash returns a string hash of the table value. // Hash returns a string hash of the table value.
func (t *Table) Hash() string { func (t *Table) Hash() uint64 {
return t.hash.Hash(t) if t == nil {
return cache.NewHash(FragmentType_Table, nil)
}
return cache.NewHash(FragmentType_Table, t.Name)
} }
// Compile transforms a table struct into a SQL chunk. // Compile transforms a table struct into a SQL chunk.

@ -11,11 +11,11 @@ import (
) )
// Type is the type of SQL query the statement represents. // Type is the type of SQL query the statement represents.
type Type uint type Type uint8
// Values for Type. // Values for Type.
const ( const (
NoOp = Type(iota) NoOp Type = iota
Truncate Truncate
DropTable DropTable
@ -29,13 +29,25 @@ const (
SQL SQL
) )
func (t Type) Hash() uint64 {
return cache.NewHash(FragmentType_StatementType, uint8(t))
}
type ( type (
// Limit represents the SQL limit in a query. // Limit represents the SQL limit in a query.
Limit int Limit int64
// Offset represents the SQL offset in a query. // Offset represents the SQL offset in a query.
Offset int Offset int64
) )
func (t Limit) Hash() uint64 {
return cache.NewHash(FragmentType_Limit, uint64(t))
}
func (t Offset) Hash() uint64 {
return cache.NewHash(FragmentType_Offset, uint64(t))
}
// Template is an SQL template. // Template is an SQL template.
type Template struct { type Template struct {
AndKeyword string AndKeyword string

@ -0,0 +1,35 @@
package exql
const (
FragmentType_None uint64 = iota + 713910251627
FragmentType_And
FragmentType_Column
FragmentType_ColumnValue
FragmentType_ColumnValues
FragmentType_Columns
FragmentType_Database
FragmentType_GroupBy
FragmentType_Join
FragmentType_Joins
FragmentType_Nil
FragmentType_Or
FragmentType_Limit
FragmentType_Offset
FragmentType_OrderBy
FragmentType_Order
FragmentType_Raw
FragmentType_Returning
FragmentType_SortBy
FragmentType_SortColumn
FragmentType_SortColumns
FragmentType_Statement
FragmentType_StatementType
FragmentType_Table
FragmentType_Value
FragmentType_On
FragmentType_Using
FragmentType_ValueGroups
FragmentType_Values
FragmentType_Where
)

@ -1,14 +1,14 @@
package exql package exql
import ( import (
"fmt"
"strings" "strings"
"github.com/upper/db/v4/internal/cache"
) )
// ValueGroups represents an array of value groups. // ValueGroups represents an array of value groups.
type ValueGroups struct { type ValueGroups struct {
Values []*Values Values []*Values
hash hash
} }
func (vg *ValueGroups) IsEmpty() bool { func (vg *ValueGroups) IsEmpty() bool {
@ -28,7 +28,6 @@ var _ = Fragment(&ValueGroups{})
// Values represents an array of Value. // Values represents an array of Value.
type Values struct { type Values struct {
Values []Fragment Values []Fragment
hash hash
} }
func (vs *Values) IsEmpty() bool { func (vs *Values) IsEmpty() bool {
@ -38,12 +37,16 @@ func (vs *Values) IsEmpty() bool {
return false return false
} }
// NewValueGroup creates and returns an array of values.
func NewValueGroup(v ...Fragment) *Values {
return &Values{Values: v}
}
var _ = Fragment(&Values{}) var _ = Fragment(&Values{})
// Value represents an escaped SQL value. // Value represents an escaped SQL value.
type Value struct { type Value struct {
V interface{} V interface{}
hash hash
} }
var _ = Fragment(&Value{}) var _ = Fragment(&Value{})
@ -53,50 +56,51 @@ func NewValue(v interface{}) *Value {
return &Value{V: v} return &Value{V: v}
} }
// NewValueGroup creates and returns an array of values.
func NewValueGroup(v ...Fragment) *Values {
return &Values{Values: v}
}
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (v *Value) Hash() string { func (v *Value) Hash() uint64 {
return v.hash.Hash(v) if v == nil {
} return cache.NewHash(FragmentType_Value, nil)
}
func (v *Value) IsEmpty() bool { return cache.NewHash(FragmentType_Value, v.V)
return false
} }
// Compile transforms the Value into an equivalent SQL representation. // Compile transforms the Value into an equivalent SQL representation.
func (v *Value) Compile(layout *Template) (compiled string, err error) { func (v *Value) Compile(layout *Template) (compiled string, err error) {
if z, ok := layout.Read(v); ok { if z, ok := layout.Read(v); ok {
return z, nil return z, nil
} }
switch t := v.V.(type) { switch value := v.V.(type) {
case Raw: case compilable:
compiled, err = t.Compile(layout) compiled, err = value.Compile(layout)
if err != nil { if err != nil {
return "", err return "", err
} }
case Fragment: default:
compiled, err = t.Compile(layout) value, err := NewRawValue(v.V)
if err != nil { if err != nil {
return "", err return "", err
} }
default: compiled = layout.MustCompile(
compiled = layout.MustCompile(layout.ValueQuote, RawValue(fmt.Sprintf(`%v`, v.V))) layout.ValueQuote,
value,
)
} }
layout.Write(v, compiled) layout.Write(v, compiled)
return return
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (vs *Values) Hash() string { func (vs *Values) Hash() uint64 {
return vs.hash.Hash(vs) if vs == nil {
return cache.NewHash(FragmentType_Values, nil)
}
h := cache.InitHash(FragmentType_Values)
for i := range vs.Values {
h = cache.AddToHash(h, vs.Values[i])
}
return h
} }
// Compile transforms the Values into an equivalent SQL representation. // Compile transforms the Values into an equivalent SQL representation.
@ -122,8 +126,15 @@ func (vs *Values) Compile(layout *Template) (compiled string, err error) {
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (vg *ValueGroups) Hash() string { func (vg *ValueGroups) Hash() uint64 {
return vg.hash.Hash(vg) if vg == nil {
return cache.NewHash(FragmentType_ValueGroups, nil)
}
h := cache.InitHash(FragmentType_ValueGroups)
for i := range vg.Values {
h = cache.AddToHash(h, vg.Values[i])
}
return h
} }
// Compile transforms the ValueGroups into an equivalent SQL representation. // Compile transforms the ValueGroups into an equivalent SQL representation.

@ -2,6 +2,8 @@ package exql
import ( import (
"strings" "strings"
"github.com/upper/db/v4/internal/cache"
) )
// Or represents an SQL OR operator. // Or represents an SQL OR operator.
@ -13,7 +15,6 @@ type And Where
// Where represents an SQL WHERE clause. // Where represents an SQL WHERE clause.
type Where struct { type Where struct {
Conditions []Fragment Conditions []Fragment
hash hash
} }
var _ = Fragment(&Where{}) var _ = Fragment(&Where{})
@ -38,8 +39,15 @@ func JoinWithAnd(conditions ...Fragment) *And {
} }
// Hash returns a unique identifier for the struct. // Hash returns a unique identifier for the struct.
func (w *Where) Hash() string { func (w *Where) Hash() uint64 {
return w.hash.Hash(w) if w == nil {
return cache.NewHash(FragmentType_Where, nil)
}
h := cache.InitHash(FragmentType_Where)
for i := range w.Conditions {
h = cache.AddToHash(h, w.Conditions[i])
}
return h
} }
// Appends adds the conditions to the ones that already exist. // Appends adds the conditions to the ones that already exist.
@ -51,15 +59,19 @@ func (w *Where) Append(a *Where) *Where {
} }
// Hash returns a unique identifier. // Hash returns a unique identifier.
func (o *Or) Hash() string { func (o *Or) Hash() uint64 {
w := Where(*o) if o == nil {
return `Or(` + w.Hash() + `)` return cache.NewHash(FragmentType_Or, nil)
}
return cache.NewHash(FragmentType_Or, (*Where)(o))
} }
// Hash returns a unique identifier. // Hash returns a unique identifier.
func (a *And) Hash() string { func (a *And) Hash() uint64 {
w := Where(*a) if a == nil {
return `And(` + w.Hash() + `)` return cache.NewHash(FragmentType_And, nil)
}
return cache.NewHash(FragmentType_And, (*Where)(a))
} }
// Compile transforms the Or into an equivalent SQL representation. // Compile transforms the Or into an equivalent SQL representation.

@ -0,0 +1,8 @@
package sqladapter
const (
hashTypeNone = iota + 345065139389
hashTypeCollection
hashTypePrimaryKeys
)

@ -1,6 +1,7 @@
package sqladapter package sqladapter
import ( import (
"bytes"
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
@ -286,7 +287,8 @@ func (sess *sessionWithContext) Err(errIn error) (errOur error) {
} }
func (sess *sessionWithContext) PrimaryKeys(tableName string) ([]string, error) { func (sess *sessionWithContext) PrimaryKeys(tableName string) ([]string, error) {
h := cache.String(tableName) h := cache.NewHashable(hashTypePrimaryKeys, tableName)
cachedPK, ok := sess.cachedPKs.ReadRaw(h) cachedPK, ok := sess.cachedPKs.ReadRaw(h)
if ok { if ok {
return cachedPK.([]string), nil return cachedPK.([]string), nil
@ -652,7 +654,8 @@ func (sess *sessionWithContext) Collection(name string) db.Collection {
sess.cacheMu.Lock() sess.cacheMu.Lock()
defer sess.cacheMu.Unlock() defer sess.cacheMu.Unlock()
h := cache.String(name) h := cache.NewHashable(hashTypeCollection, name)
col, ok := sess.cachedCollections.ReadRaw(h) col, ok := sess.cachedCollections.ReadRaw(h)
if !ok { if !ok {
col = newCollection(name, sess.adapter.NewCollection()) col = newCollection(name, sess.adapter.NewCollection())
@ -1001,29 +1004,37 @@ func (sess *sessionWithContext) WaitForConnection(connectFn func() error) error
// ReplaceWithDollarSign turns a SQL statament with '?' placeholders into // ReplaceWithDollarSign turns a SQL statament with '?' placeholders into
// dollar placeholders, like $1, $2, ..., $n // dollar placeholders, like $1, $2, ..., $n
func ReplaceWithDollarSign(in string) string { func ReplaceWithDollarSign(buf []byte) []byte {
buf := []byte(in) z := bytes.Count(buf, []byte{'?'})
out := make([]byte, 0, len(buf)) // the capacity is a quick estimation of the total memory required, this
// reduces reallocations
i, j, k, t := 0, 1, 0, len(buf) out := make([]byte, 0, len(buf)+z*3)
for i < t { var i, k = 0, 1
for i < len(buf) {
if buf[i] == '?' { if buf[i] == '?' {
out = append(out, buf[k:i]...) out = append(out, buf[:i]...)
k = i + 1 buf = buf[i+1:]
i = 0
if k < t && buf[k] == '?' { if len(buf) > 0 && buf[0] == '?' {
i = k out = append(out, '?')
} else { buf = buf[1:]
out = append(out, []byte("$"+strconv.Itoa(j))...) continue
j++
} }
out = append(out, '$')
out = append(out, []byte(strconv.Itoa(k))...)
k = k + 1
continue
} }
i++ i = i + 1
} }
out = append(out, buf[k:i]...)
return string(out) out = append(out, buf[:len(buf)]...)
buf = nil
return out
} }
func copySettings(from Session, into Session) { func copySettings(from Session, into Session) {

@ -12,7 +12,7 @@ var (
) )
// Stmt represents a *sql.Stmt that is cached and provides the // Stmt represents a *sql.Stmt that is cached and provides the
// OnPurge method to allow it to clean after itself. // OnEvict method to allow it to clean after itself.
type Stmt struct { type Stmt struct {
*sql.Stmt *sql.Stmt
@ -69,8 +69,8 @@ func (c *Stmt) checkClose() error {
return nil return nil
} }
// OnPurge marks the statement as ready to be cleaned up. // OnEvict marks the statement as ready to be cleaned up.
func (c *Stmt) OnPurge() { func (c *Stmt) OnEvict() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()

@ -75,7 +75,7 @@ type fieldValue struct {
} }
var ( var (
sqlPlaceholder = exql.RawValue(`?`) sqlPlaceholder = &exql.Raw{Value: `?`}
) )
var ( var (
@ -358,7 +358,7 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err
q, a := Preprocess(p.String(), p.Arguments()) q, a := Preprocess(p.String(), p.Arguments())
f[i] = exql.RawValue("(" + q + ")") f[i] = &exql.Raw{Value: "(" + q + ")"}
args = append(args, a...) args = append(args, a...)
case isCompilable: case isCompilable:
c, err := v.Compile() c, err := v.Compile()
@ -369,7 +369,7 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err
if _, ok := v.(db.Selector); ok { if _, ok := v.(db.Selector); ok {
q = "(" + q + ")" q = "(" + q + ")"
} }
f[i] = exql.RawValue(q) f[i] = &exql.Raw{Value: q}
args = append(args, a...) args = append(args, a...)
case *adapter.FuncExpr: case *adapter.FuncExpr:
fnName, fnArgs := v.Name(), v.Arguments() fnName, fnArgs := v.Name(), v.Arguments()
@ -379,22 +379,24 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
} }
fnName, fnArgs = Preprocess(fnName, fnArgs) fnName, fnArgs = Preprocess(fnName, fnArgs)
f[i] = exql.RawValue(fnName) f[i] = &exql.Raw{Value: fnName}
args = append(args, fnArgs...) args = append(args, fnArgs...)
case *adapter.RawExpr: case *adapter.RawExpr:
q, a := Preprocess(v.Raw(), v.Arguments()) q, a := Preprocess(v.Raw(), v.Arguments())
f[i] = exql.RawValue(q) f[i] = &exql.Raw{Value: q}
args = append(args, a...) args = append(args, a...)
case exql.Fragment: case exql.Fragment:
f[i] = v f[i] = v
case string: case string:
f[i] = exql.ColumnWithName(v) f[i] = exql.ColumnWithName(v)
case int: case fmt.Stringer:
f[i] = exql.RawValue(fmt.Sprintf("%v", v)) f[i] = exql.ColumnWithName(v.String())
case interface{}:
f[i] = exql.ColumnWithName(fmt.Sprintf("%v", v))
default: default:
return nil, nil, fmt.Errorf("unexpected argument type %T for Select() argument", v) var err error
f[i], err = exql.NewRawValue(columns[i])
if err != nil {
return nil, nil, fmt.Errorf("unexpected argument type %T for Select() argument: %w", v, err)
}
} }
} }
return f, args, nil return f, args, nil

@ -1,43 +1,93 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"database/sql/driver" "database/sql/driver"
"reflect" "reflect"
"strings"
"github.com/upper/db/v4/internal/adapter" "github.com/upper/db/v4/internal/adapter"
"github.com/upper/db/v4/internal/sqladapter/exql" "github.com/upper/db/v4/internal/sqladapter/exql"
) )
var ( var (
sqlDefault = exql.RawValue(`DEFAULT`) sqlDefault = &exql.Raw{Value: "DEFAULT"}
) )
func expandQuery(in string, args []interface{}, fn func(interface{}) (string, []interface{})) (string, []interface{}) { func expandQuery(in []byte, inArgs []interface{}) ([]byte, []interface{}) {
argn := 0 out := make([]byte, 0, len(in))
argx := make([]interface{}, 0, len(args)) outArgs := make([]interface{}, 0, len(inArgs))
for i := 0; i < len(in); i++ {
if in[i] != '?' { i := 0
for i < len(in) && len(inArgs) > 0 {
if in[i] == '?' {
out = append(out, in[:i]...)
in = in[i+1:]
i = 0
replace, replaceArgs := expandArgument(inArgs[0])
inArgs = inArgs[1:]
if len(replace) > 0 {
replace, replaceArgs = expandQuery(replace, replaceArgs)
out = append(out, replace...)
} else {
out = append(out, '?')
}
outArgs = append(outArgs, replaceArgs...)
continue continue
} }
if len(args) > argn { i = i + 1
k, values := fn(args[argn]) }
k, values = expandQuery(k, values, fn)
if len(out) < 1 {
return in, inArgs
}
out = append(out, in[:len(in)]...)
in = nil
outArgs = append(outArgs, inArgs[:len(inArgs)]...)
inArgs = nil
return out, outArgs
}
func expandArgument(arg interface{}) ([]byte, []interface{}) {
values, isSlice := toInterfaceArguments(arg)
if k != "" { if isSlice {
in = in[:i] + k + in[i+1:] if len(values) == 0 {
i += len(k) - 1 return []byte("(NULL)"), nil
} }
if len(values) > 0 { buf := bytes.Repeat([]byte(" ?,"), len(values))
argx = append(argx, values...) buf[0] = '('
buf[len(buf)-1] = ')'
return buf, values
} }
argn++
if len(values) == 1 {
switch t := arg.(type) {
case *adapter.RawExpr:
return expandQuery([]byte(t.Raw()), t.Arguments())
case hasPaginator:
p, err := t.Paginator()
if err == nil {
return append([]byte{'('}, append([]byte(p.String()), ')')...), p.Arguments()
} }
panic(err.Error())
case isCompilable:
s, err := t.Compile()
if err == nil {
return append([]byte{'('}, append([]byte(s), ')')...), t.Arguments()
} }
if len(argx) < len(args) { panic(err.Error())
argx = append(argx, args[argn:]...)
} }
return in, argx } else if len(values) == 0 {
return []byte("NULL"), nil
}
return nil, []interface{}{arg}
} }
// toInterfaceArguments converts the given value into an array of interfaces. // toInterfaceArguments converts the given value into an array of interfaces.
@ -57,7 +107,7 @@ func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool)
// Byte slice gets transformed into a string. // Byte slice gets transformed into a string.
if v.Type().Elem().Kind() == reflect.Uint8 { if v.Type().Elem().Kind() == reflect.Uint8 {
return []interface{}{string(value.([]byte))}, false return []interface{}{string(v.Bytes())}, false
} }
total = v.Len() total = v.Len()
@ -108,42 +158,9 @@ func toColumnsValuesAndArguments(columnNames []string, columnValues []interface{
return columns, values, arguments, nil return columns, values, arguments, nil
} }
func preprocessFn(arg interface{}) (string, []interface{}) {
values, isSlice := toInterfaceArguments(arg)
if isSlice {
if len(values) == 0 {
return `(NULL)`, nil
}
return `(?` + strings.Repeat(`, ?`, len(values)-1) + `)`, values
}
if len(values) == 1 {
switch t := arg.(type) {
case *adapter.RawExpr:
return Preprocess(t.Raw(), t.Arguments())
case hasPaginator:
p, err := t.Paginator()
if err == nil {
return `(` + p.String() + `)`, p.Arguments()
}
panic(err.Error())
case isCompilable:
c, err := t.Compile()
if err == nil {
return `(` + c + `)`, t.Arguments()
}
panic(err.Error())
}
} else if len(values) == 0 {
return `NULL`, nil
}
return "", []interface{}{arg}
}
// Preprocess expands arguments that needs to be expanded and compiles a query // Preprocess expands arguments that needs to be expanded and compiles a query
// into a single string. // into a single string.
func Preprocess(in string, args []interface{}) (string, []interface{}) { func Preprocess(in string, args []interface{}) (string, []interface{}) {
return expandQuery(in, args, preprocessFn) b, args := expandQuery([]byte(in), args)
return string(b), args
} }

@ -60,7 +60,7 @@ func (iq *inserterQuery) processValues() ([]*exql.Values, []interface{}, error)
l := len(enqueuedValue) l := len(enqueuedValue)
placeholders := make([]exql.Fragment, l) placeholders := make([]exql.Fragment, l)
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
placeholders[i] = exql.RawValue(`?`) placeholders[i] = sqlPlaceholder
} }
values = append(values, exql.NewValueGroup(placeholders...)) values = append(values, exql.NewValueGroup(placeholders...))
} }

@ -257,7 +257,7 @@ func (sel *selector) OrderBy(columns ...interface{}) db.Selector {
case *adapter.RawExpr: case *adapter.RawExpr:
query, args := Preprocess(value.Raw(), value.Arguments()) query, args := Preprocess(value.Raw(), value.Arguments())
sort = &exql.SortColumn{ sort = &exql.SortColumn{
Column: exql.RawValue(query), Column: &exql.Raw{Value: query},
} }
sq.orderByArgs = append(sq.orderByArgs, args...) sq.orderByArgs = append(sq.orderByArgs, args...)
case *adapter.FuncExpr: case *adapter.FuncExpr:
@ -269,21 +269,21 @@ func (sel *selector) OrderBy(columns ...interface{}) db.Selector {
} }
fnName, fnArgs = Preprocess(fnName, fnArgs) fnName, fnArgs = Preprocess(fnName, fnArgs)
sort = &exql.SortColumn{ sort = &exql.SortColumn{
Column: exql.RawValue(fnName), Column: &exql.Raw{Value: fnName},
} }
sq.orderByArgs = append(sq.orderByArgs, fnArgs...) sq.orderByArgs = append(sq.orderByArgs, fnArgs...)
case string: case string:
if strings.HasPrefix(value, "-") { if strings.HasPrefix(value, "-") {
sort = &exql.SortColumn{ sort = &exql.SortColumn{
Column: exql.ColumnWithName(value[1:]), Column: exql.ColumnWithName(value[1:]),
Order: exql.Descendent, Order: exql.Order_Descendent,
} }
} else { } else {
chunks := strings.SplitN(value, " ", 2) chunks := strings.SplitN(value, " ", 2)
order := exql.Ascendent order := exql.Order_Ascendent
if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" { if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" {
order = exql.Descendent order = exql.Order_Descendent
} }
sort = &exql.SortColumn{ sort = &exql.SortColumn{
@ -418,7 +418,7 @@ func (sel *selector) As(alias string) db.Selector {
if err != nil { if err != nil {
return err return err
} }
sq.table.Columns[last] = exql.RawValue(raw.Value + " AS " + compiled) sq.table.Columns[last] = &exql.Raw{Value: raw.Value + " AS " + compiled}
} }
return nil return nil
}) })

@ -21,7 +21,7 @@ func newTemplateWithUtils(template *exql.Template) *templateWithUtils {
func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) { func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) {
switch t := in.(type) { switch t := in.(type) {
case *adapter.RawExpr: case *adapter.RawExpr:
return exql.RawValue(t.Raw()), t.Arguments() return &exql.Raw{Value: t.Raw()}, t.Arguments()
case *adapter.FuncExpr: case *adapter.FuncExpr:
fnName := t.Name() fnName := t.Name()
fnArgs := []interface{}{} fnArgs := []interface{}{}
@ -35,7 +35,7 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []
fnArgs = append(fnArgs, args...) fnArgs = append(fnArgs, args...)
} }
} }
return exql.RawValue(fnName + `(` + strings.Join(fragments, `, `) + `)`), fnArgs return &exql.Raw{Value: fnName + `(` + strings.Join(fragments, `, `) + `)`}, fnArgs
default: default:
return sqlPlaceholder, []interface{}{in} return sqlPlaceholder, []interface{}{in}
} }
@ -51,7 +51,7 @@ func (tu *templateWithUtils) toWhereWithArguments(term interface{}) (where exql.
if s, ok := t[0].(string); ok { if s, ok := t[0].(string); ok {
if strings.ContainsAny(s, "?") || len(t) == 1 { if strings.ContainsAny(s, "?") || len(t) == 1 {
s, args = Preprocess(s, t[1:]) s, args = Preprocess(s, t[1:])
where.Conditions = []exql.Fragment{exql.RawValue(s)} where.Conditions = []exql.Fragment{&exql.Raw{Value: s}}
} else { } else {
var val interface{} var val interface{}
key := s key := s
@ -80,7 +80,7 @@ func (tu *templateWithUtils) toWhereWithArguments(term interface{}) (where exql.
return return
case *adapter.RawExpr: case *adapter.RawExpr:
r, v := Preprocess(t.Raw(), t.Arguments()) r, v := Preprocess(t.Raw(), t.Arguments())
where.Conditions = []exql.Fragment{exql.RawValue(r)} where.Conditions = []exql.Fragment{&exql.Raw{Value: r}}
args = append(args, v...) args = append(args, v...)
return return
case adapter.Constraints: case adapter.Constraints:
@ -172,10 +172,10 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal
} }
} else { } else {
if rawValue, ok := t.Key().(*adapter.RawExpr); ok { if rawValue, ok := t.Key().(*adapter.RawExpr); ok {
columnValue.Column = exql.RawValue(rawValue.Raw()) columnValue.Column = &exql.Raw{Value: rawValue.Raw()}
args = append(args, rawValue.Arguments()...) args = append(args, rawValue.Arguments()...)
} else { } else {
columnValue.Column = exql.RawValue(fmt.Sprintf("%v", t.Key())) columnValue.Column = &exql.Raw{Value: fmt.Sprintf("%v", t.Key())}
} }
} }
@ -190,14 +190,14 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal
fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
} }
fnName, fnArgs = Preprocess(fnName, fnArgs) fnName, fnArgs = Preprocess(fnName, fnArgs)
columnValue.Value = exql.RawValue(fnName) columnValue.Value = &exql.Raw{Value: fnName}
args = append(args, fnArgs...) args = append(args, fnArgs...)
case *db.RawExpr: case *db.RawExpr:
q, a := Preprocess(value.Raw(), value.Arguments()) q, a := Preprocess(value.Raw(), value.Arguments())
columnValue.Value = exql.RawValue(q) columnValue.Value = &exql.Raw{Value: q}
args = append(args, a...) args = append(args, a...)
case driver.Valuer: case driver.Valuer:
columnValue.Value = exql.RawValue("?") columnValue.Value = sqlPlaceholder
args = append(args, value) args = append(args, value)
case *db.Comparison: case *db.Comparison:
wrapper := &operatorWrapper{ wrapper := &operatorWrapper{
@ -210,7 +210,7 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal
q, a = Preprocess(q, a) q, a = Preprocess(q, a)
columnValue = exql.ColumnValue{ columnValue = exql.ColumnValue{
Column: exql.RawValue(q), Column: &exql.Raw{Value: q},
} }
if a != nil { if a != nil {
args = append(args, a...) args = append(args, a...)
@ -229,7 +229,7 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal
q, a = Preprocess(q, a) q, a = Preprocess(q, a)
columnValue = exql.ColumnValue{ columnValue = exql.ColumnValue{
Column: exql.RawValue(q), Column: &exql.Raw{Value: q},
} }
if a != nil { if a != nil {
args = append(args, a...) args = append(args, a...)
@ -249,7 +249,7 @@ func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnVal
case *adapter.RawExpr: case *adapter.RawExpr:
columnValue := exql.ColumnValue{} columnValue := exql.ColumnValue{}
p, q := Preprocess(t.Raw(), t.Arguments()) p, q := Preprocess(t.Raw(), t.Arguments())
columnValue.Column = exql.RawValue(p) columnValue.Column = &exql.Raw{Value: p}
cv.ColumnValues = append(cv.ColumnValues, &columnValue) cv.ColumnValues = append(cv.ColumnValues, &columnValue)
args = append(args, q...) args = append(args, q...)
return cv, args return cv, args
@ -294,7 +294,7 @@ func (tu *templateWithUtils) setColumnValues(term interface{}) (cv exql.ColumnVa
columnValue := exql.ColumnValue{ columnValue := exql.ColumnValue{
Column: exql.ColumnWithName(column), Column: exql.ColumnWithName(column),
Operator: tu.AssignmentOperator, Operator: tu.AssignmentOperator,
Value: exql.RawValue(format), Value: &exql.Raw{Value: format},
} }
ps := strings.Count(format, "?") ps := strings.Count(format, "?")
@ -313,7 +313,7 @@ func (tu *templateWithUtils) setColumnValues(term interface{}) (cv exql.ColumnVa
case *adapter.RawExpr: case *adapter.RawExpr:
columnValue := exql.ColumnValue{} columnValue := exql.ColumnValue{}
p, q := Preprocess(t.Raw(), t.Arguments()) p, q := Preprocess(t.Raw(), t.Arguments())
columnValue.Column = exql.RawValue(p) columnValue.Column = &exql.Raw{Value: p}
cv.ColumnValues = append(cv.ColumnValues, &columnValue) cv.ColumnValues = append(cv.ColumnValues, &columnValue)
args = append(args, q...) args = append(args, q...)
return cv, args return cv, args

@ -12,7 +12,7 @@ import (
"errors" "errors"
"math/bits" "math/bits"
"golang.org/x/crypto/internal/subtle" "golang.org/x/crypto/internal/alias"
) )
const ( const (
@ -189,7 +189,7 @@ func (s *Cipher) XORKeyStream(dst, src []byte) {
panic("chacha20: output smaller than input") panic("chacha20: output smaller than input")
} }
dst = dst[:len(src)] dst = dst[:len(src)]
if subtle.InexactOverlap(dst, src) { if alias.InexactOverlap(dst, src) {
panic("chacha20: invalid buffer overlap") panic("chacha20: invalid buffer overlap")
} }

@ -5,9 +5,8 @@
//go:build !purego //go:build !purego
// +build !purego // +build !purego
// Package subtle implements functions that are often useful in cryptographic // Package alias implements memory aliasing tests.
// code but require careful thought to use correctly. package alias
package subtle // import "golang.org/x/crypto/internal/subtle"
import "unsafe" import "unsafe"

@ -5,9 +5,8 @@
//go:build purego //go:build purego
// +build purego // +build purego
// Package subtle implements functions that are often useful in cryptographic // Package alias implements memory aliasing tests.
// code but require careful thought to use correctly. package alias
package subtle // import "golang.org/x/crypto/internal/subtle"
// This is the Google App Engine standard variant based on reflect // This is the Google App Engine standard variant based on reflect
// because the unsafe package and cgo are disallowed. // because the unsafe package and cgo are disallowed.

@ -7,6 +7,8 @@
package sha3 package sha3
import "math/bits"
// rc stores the round constants for use in the ι step. // rc stores the round constants for use in the ι step.
var rc = [24]uint64{ var rc = [24]uint64{
0x0000000000000001, 0x0000000000000001,
@ -60,13 +62,13 @@ func keccakF1600(a *[25]uint64) {
bc0 = a[0] ^ d0 bc0 = a[0] ^ d0
t = a[6] ^ d1 t = a[6] ^ d1
bc1 = t<<44 | t>>(64-44) bc1 = bits.RotateLeft64(t, 44)
t = a[12] ^ d2 t = a[12] ^ d2
bc2 = t<<43 | t>>(64-43) bc2 = bits.RotateLeft64(t, 43)
t = a[18] ^ d3 t = a[18] ^ d3
bc3 = t<<21 | t>>(64-21) bc3 = bits.RotateLeft64(t, 21)
t = a[24] ^ d4 t = a[24] ^ d4
bc4 = t<<14 | t>>(64-14) bc4 = bits.RotateLeft64(t, 14)
a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i] a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i]
a[6] = bc1 ^ (bc3 &^ bc2) a[6] = bc1 ^ (bc3 &^ bc2)
a[12] = bc2 ^ (bc4 &^ bc3) a[12] = bc2 ^ (bc4 &^ bc3)
@ -74,15 +76,15 @@ func keccakF1600(a *[25]uint64) {
a[24] = bc4 ^ (bc1 &^ bc0) a[24] = bc4 ^ (bc1 &^ bc0)
t = a[10] ^ d0 t = a[10] ^ d0
bc2 = t<<3 | t>>(64-3) bc2 = bits.RotateLeft64(t, 3)
t = a[16] ^ d1 t = a[16] ^ d1
bc3 = t<<45 | t>>(64-45) bc3 = bits.RotateLeft64(t, 45)
t = a[22] ^ d2 t = a[22] ^ d2
bc4 = t<<61 | t>>(64-61) bc4 = bits.RotateLeft64(t, 61)
t = a[3] ^ d3 t = a[3] ^ d3
bc0 = t<<28 | t>>(64-28) bc0 = bits.RotateLeft64(t, 28)
t = a[9] ^ d4 t = a[9] ^ d4
bc1 = t<<20 | t>>(64-20) bc1 = bits.RotateLeft64(t, 20)
a[10] = bc0 ^ (bc2 &^ bc1) a[10] = bc0 ^ (bc2 &^ bc1)
a[16] = bc1 ^ (bc3 &^ bc2) a[16] = bc1 ^ (bc3 &^ bc2)
a[22] = bc2 ^ (bc4 &^ bc3) a[22] = bc2 ^ (bc4 &^ bc3)
@ -90,15 +92,15 @@ func keccakF1600(a *[25]uint64) {
a[9] = bc4 ^ (bc1 &^ bc0) a[9] = bc4 ^ (bc1 &^ bc0)
t = a[20] ^ d0 t = a[20] ^ d0
bc4 = t<<18 | t>>(64-18) bc4 = bits.RotateLeft64(t, 18)
t = a[1] ^ d1 t = a[1] ^ d1
bc0 = t<<1 | t>>(64-1) bc0 = bits.RotateLeft64(t, 1)
t = a[7] ^ d2 t = a[7] ^ d2
bc1 = t<<6 | t>>(64-6) bc1 = bits.RotateLeft64(t, 6)
t = a[13] ^ d3 t = a[13] ^ d3
bc2 = t<<25 | t>>(64-25) bc2 = bits.RotateLeft64(t, 25)
t = a[19] ^ d4 t = a[19] ^ d4
bc3 = t<<8 | t>>(64-8) bc3 = bits.RotateLeft64(t, 8)
a[20] = bc0 ^ (bc2 &^ bc1) a[20] = bc0 ^ (bc2 &^ bc1)
a[1] = bc1 ^ (bc3 &^ bc2) a[1] = bc1 ^ (bc3 &^ bc2)
a[7] = bc2 ^ (bc4 &^ bc3) a[7] = bc2 ^ (bc4 &^ bc3)
@ -106,15 +108,15 @@ func keccakF1600(a *[25]uint64) {
a[19] = bc4 ^ (bc1 &^ bc0) a[19] = bc4 ^ (bc1 &^ bc0)
t = a[5] ^ d0 t = a[5] ^ d0
bc1 = t<<36 | t>>(64-36) bc1 = bits.RotateLeft64(t, 36)
t = a[11] ^ d1 t = a[11] ^ d1
bc2 = t<<10 | t>>(64-10) bc2 = bits.RotateLeft64(t, 10)
t = a[17] ^ d2 t = a[17] ^ d2
bc3 = t<<15 | t>>(64-15) bc3 = bits.RotateLeft64(t, 15)
t = a[23] ^ d3 t = a[23] ^ d3
bc4 = t<<56 | t>>(64-56) bc4 = bits.RotateLeft64(t, 56)
t = a[4] ^ d4 t = a[4] ^ d4
bc0 = t<<27 | t>>(64-27) bc0 = bits.RotateLeft64(t, 27)
a[5] = bc0 ^ (bc2 &^ bc1) a[5] = bc0 ^ (bc2 &^ bc1)
a[11] = bc1 ^ (bc3 &^ bc2) a[11] = bc1 ^ (bc3 &^ bc2)
a[17] = bc2 ^ (bc4 &^ bc3) a[17] = bc2 ^ (bc4 &^ bc3)
@ -122,15 +124,15 @@ func keccakF1600(a *[25]uint64) {
a[4] = bc4 ^ (bc1 &^ bc0) a[4] = bc4 ^ (bc1 &^ bc0)
t = a[15] ^ d0 t = a[15] ^ d0
bc3 = t<<41 | t>>(64-41) bc3 = bits.RotateLeft64(t, 41)
t = a[21] ^ d1 t = a[21] ^ d1
bc4 = t<<2 | t>>(64-2) bc4 = bits.RotateLeft64(t, 2)
t = a[2] ^ d2 t = a[2] ^ d2
bc0 = t<<62 | t>>(64-62) bc0 = bits.RotateLeft64(t, 62)
t = a[8] ^ d3 t = a[8] ^ d3
bc1 = t<<55 | t>>(64-55) bc1 = bits.RotateLeft64(t, 55)
t = a[14] ^ d4 t = a[14] ^ d4
bc2 = t<<39 | t>>(64-39) bc2 = bits.RotateLeft64(t, 39)
a[15] = bc0 ^ (bc2 &^ bc1) a[15] = bc0 ^ (bc2 &^ bc1)
a[21] = bc1 ^ (bc3 &^ bc2) a[21] = bc1 ^ (bc3 &^ bc2)
a[2] = bc2 ^ (bc4 &^ bc3) a[2] = bc2 ^ (bc4 &^ bc3)
@ -151,13 +153,13 @@ func keccakF1600(a *[25]uint64) {
bc0 = a[0] ^ d0 bc0 = a[0] ^ d0
t = a[16] ^ d1 t = a[16] ^ d1
bc1 = t<<44 | t>>(64-44) bc1 = bits.RotateLeft64(t, 44)
t = a[7] ^ d2 t = a[7] ^ d2
bc2 = t<<43 | t>>(64-43) bc2 = bits.RotateLeft64(t, 43)
t = a[23] ^ d3 t = a[23] ^ d3
bc3 = t<<21 | t>>(64-21) bc3 = bits.RotateLeft64(t, 21)
t = a[14] ^ d4 t = a[14] ^ d4
bc4 = t<<14 | t>>(64-14) bc4 = bits.RotateLeft64(t, 14)
a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i+1] a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i+1]
a[16] = bc1 ^ (bc3 &^ bc2) a[16] = bc1 ^ (bc3 &^ bc2)
a[7] = bc2 ^ (bc4 &^ bc3) a[7] = bc2 ^ (bc4 &^ bc3)
@ -165,15 +167,15 @@ func keccakF1600(a *[25]uint64) {
a[14] = bc4 ^ (bc1 &^ bc0) a[14] = bc4 ^ (bc1 &^ bc0)
t = a[20] ^ d0 t = a[20] ^ d0
bc2 = t<<3 | t>>(64-3) bc2 = bits.RotateLeft64(t, 3)
t = a[11] ^ d1 t = a[11] ^ d1
bc3 = t<<45 | t>>(64-45) bc3 = bits.RotateLeft64(t, 45)
t = a[2] ^ d2 t = a[2] ^ d2
bc4 = t<<61 | t>>(64-61) bc4 = bits.RotateLeft64(t, 61)
t = a[18] ^ d3 t = a[18] ^ d3
bc0 = t<<28 | t>>(64-28) bc0 = bits.RotateLeft64(t, 28)
t = a[9] ^ d4 t = a[9] ^ d4
bc1 = t<<20 | t>>(64-20) bc1 = bits.RotateLeft64(t, 20)
a[20] = bc0 ^ (bc2 &^ bc1) a[20] = bc0 ^ (bc2 &^ bc1)
a[11] = bc1 ^ (bc3 &^ bc2) a[11] = bc1 ^ (bc3 &^ bc2)
a[2] = bc2 ^ (bc4 &^ bc3) a[2] = bc2 ^ (bc4 &^ bc3)
@ -181,15 +183,15 @@ func keccakF1600(a *[25]uint64) {
a[9] = bc4 ^ (bc1 &^ bc0) a[9] = bc4 ^ (bc1 &^ bc0)
t = a[15] ^ d0 t = a[15] ^ d0
bc4 = t<<18 | t>>(64-18) bc4 = bits.RotateLeft64(t, 18)
t = a[6] ^ d1 t = a[6] ^ d1
bc0 = t<<1 | t>>(64-1) bc0 = bits.RotateLeft64(t, 1)
t = a[22] ^ d2 t = a[22] ^ d2
bc1 = t<<6 | t>>(64-6) bc1 = bits.RotateLeft64(t, 6)
t = a[13] ^ d3 t = a[13] ^ d3
bc2 = t<<25 | t>>(64-25) bc2 = bits.RotateLeft64(t, 25)
t = a[4] ^ d4 t = a[4] ^ d4
bc3 = t<<8 | t>>(64-8) bc3 = bits.RotateLeft64(t, 8)
a[15] = bc0 ^ (bc2 &^ bc1) a[15] = bc0 ^ (bc2 &^ bc1)
a[6] = bc1 ^ (bc3 &^ bc2) a[6] = bc1 ^ (bc3 &^ bc2)
a[22] = bc2 ^ (bc4 &^ bc3) a[22] = bc2 ^ (bc4 &^ bc3)
@ -197,15 +199,15 @@ func keccakF1600(a *[25]uint64) {
a[4] = bc4 ^ (bc1 &^ bc0) a[4] = bc4 ^ (bc1 &^ bc0)
t = a[10] ^ d0 t = a[10] ^ d0
bc1 = t<<36 | t>>(64-36) bc1 = bits.RotateLeft64(t, 36)
t = a[1] ^ d1 t = a[1] ^ d1
bc2 = t<<10 | t>>(64-10) bc2 = bits.RotateLeft64(t, 10)
t = a[17] ^ d2 t = a[17] ^ d2
bc3 = t<<15 | t>>(64-15) bc3 = bits.RotateLeft64(t, 15)
t = a[8] ^ d3 t = a[8] ^ d3
bc4 = t<<56 | t>>(64-56) bc4 = bits.RotateLeft64(t, 56)
t = a[24] ^ d4 t = a[24] ^ d4
bc0 = t<<27 | t>>(64-27) bc0 = bits.RotateLeft64(t, 27)
a[10] = bc0 ^ (bc2 &^ bc1) a[10] = bc0 ^ (bc2 &^ bc1)
a[1] = bc1 ^ (bc3 &^ bc2) a[1] = bc1 ^ (bc3 &^ bc2)
a[17] = bc2 ^ (bc4 &^ bc3) a[17] = bc2 ^ (bc4 &^ bc3)
@ -213,15 +215,15 @@ func keccakF1600(a *[25]uint64) {
a[24] = bc4 ^ (bc1 &^ bc0) a[24] = bc4 ^ (bc1 &^ bc0)
t = a[5] ^ d0 t = a[5] ^ d0
bc3 = t<<41 | t>>(64-41) bc3 = bits.RotateLeft64(t, 41)
t = a[21] ^ d1 t = a[21] ^ d1
bc4 = t<<2 | t>>(64-2) bc4 = bits.RotateLeft64(t, 2)
t = a[12] ^ d2 t = a[12] ^ d2
bc0 = t<<62 | t>>(64-62) bc0 = bits.RotateLeft64(t, 62)
t = a[3] ^ d3 t = a[3] ^ d3
bc1 = t<<55 | t>>(64-55) bc1 = bits.RotateLeft64(t, 55)
t = a[19] ^ d4 t = a[19] ^ d4
bc2 = t<<39 | t>>(64-39) bc2 = bits.RotateLeft64(t, 39)
a[5] = bc0 ^ (bc2 &^ bc1) a[5] = bc0 ^ (bc2 &^ bc1)
a[21] = bc1 ^ (bc3 &^ bc2) a[21] = bc1 ^ (bc3 &^ bc2)
a[12] = bc2 ^ (bc4 &^ bc3) a[12] = bc2 ^ (bc4 &^ bc3)
@ -242,13 +244,13 @@ func keccakF1600(a *[25]uint64) {
bc0 = a[0] ^ d0 bc0 = a[0] ^ d0
t = a[11] ^ d1 t = a[11] ^ d1
bc1 = t<<44 | t>>(64-44) bc1 = bits.RotateLeft64(t, 44)
t = a[22] ^ d2 t = a[22] ^ d2
bc2 = t<<43 | t>>(64-43) bc2 = bits.RotateLeft64(t, 43)
t = a[8] ^ d3 t = a[8] ^ d3
bc3 = t<<21 | t>>(64-21) bc3 = bits.RotateLeft64(t, 21)
t = a[19] ^ d4 t = a[19] ^ d4
bc4 = t<<14 | t>>(64-14) bc4 = bits.RotateLeft64(t, 14)
a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i+2] a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i+2]
a[11] = bc1 ^ (bc3 &^ bc2) a[11] = bc1 ^ (bc3 &^ bc2)
a[22] = bc2 ^ (bc4 &^ bc3) a[22] = bc2 ^ (bc4 &^ bc3)
@ -256,15 +258,15 @@ func keccakF1600(a *[25]uint64) {
a[19] = bc4 ^ (bc1 &^ bc0) a[19] = bc4 ^ (bc1 &^ bc0)
t = a[15] ^ d0 t = a[15] ^ d0
bc2 = t<<3 | t>>(64-3) bc2 = bits.RotateLeft64(t, 3)
t = a[1] ^ d1 t = a[1] ^ d1
bc3 = t<<45 | t>>(64-45) bc3 = bits.RotateLeft64(t, 45)
t = a[12] ^ d2 t = a[12] ^ d2
bc4 = t<<61 | t>>(64-61) bc4 = bits.RotateLeft64(t, 61)
t = a[23] ^ d3 t = a[23] ^ d3
bc0 = t<<28 | t>>(64-28) bc0 = bits.RotateLeft64(t, 28)
t = a[9] ^ d4 t = a[9] ^ d4
bc1 = t<<20 | t>>(64-20) bc1 = bits.RotateLeft64(t, 20)
a[15] = bc0 ^ (bc2 &^ bc1) a[15] = bc0 ^ (bc2 &^ bc1)
a[1] = bc1 ^ (bc3 &^ bc2) a[1] = bc1 ^ (bc3 &^ bc2)
a[12] = bc2 ^ (bc4 &^ bc3) a[12] = bc2 ^ (bc4 &^ bc3)
@ -272,15 +274,15 @@ func keccakF1600(a *[25]uint64) {
a[9] = bc4 ^ (bc1 &^ bc0) a[9] = bc4 ^ (bc1 &^ bc0)
t = a[5] ^ d0 t = a[5] ^ d0
bc4 = t<<18 | t>>(64-18) bc4 = bits.RotateLeft64(t, 18)
t = a[16] ^ d1 t = a[16] ^ d1
bc0 = t<<1 | t>>(64-1) bc0 = bits.RotateLeft64(t, 1)
t = a[2] ^ d2 t = a[2] ^ d2
bc1 = t<<6 | t>>(64-6) bc1 = bits.RotateLeft64(t, 6)
t = a[13] ^ d3 t = a[13] ^ d3
bc2 = t<<25 | t>>(64-25) bc2 = bits.RotateLeft64(t, 25)
t = a[24] ^ d4 t = a[24] ^ d4
bc3 = t<<8 | t>>(64-8) bc3 = bits.RotateLeft64(t, 8)
a[5] = bc0 ^ (bc2 &^ bc1) a[5] = bc0 ^ (bc2 &^ bc1)
a[16] = bc1 ^ (bc3 &^ bc2) a[16] = bc1 ^ (bc3 &^ bc2)
a[2] = bc2 ^ (bc4 &^ bc3) a[2] = bc2 ^ (bc4 &^ bc3)
@ -288,15 +290,15 @@ func keccakF1600(a *[25]uint64) {
a[24] = bc4 ^ (bc1 &^ bc0) a[24] = bc4 ^ (bc1 &^ bc0)
t = a[20] ^ d0 t = a[20] ^ d0
bc1 = t<<36 | t>>(64-36) bc1 = bits.RotateLeft64(t, 36)
t = a[6] ^ d1 t = a[6] ^ d1
bc2 = t<<10 | t>>(64-10) bc2 = bits.RotateLeft64(t, 10)
t = a[17] ^ d2 t = a[17] ^ d2
bc3 = t<<15 | t>>(64-15) bc3 = bits.RotateLeft64(t, 15)
t = a[3] ^ d3 t = a[3] ^ d3
bc4 = t<<56 | t>>(64-56) bc4 = bits.RotateLeft64(t, 56)
t = a[14] ^ d4 t = a[14] ^ d4
bc0 = t<<27 | t>>(64-27) bc0 = bits.RotateLeft64(t, 27)
a[20] = bc0 ^ (bc2 &^ bc1) a[20] = bc0 ^ (bc2 &^ bc1)
a[6] = bc1 ^ (bc3 &^ bc2) a[6] = bc1 ^ (bc3 &^ bc2)
a[17] = bc2 ^ (bc4 &^ bc3) a[17] = bc2 ^ (bc4 &^ bc3)
@ -304,15 +306,15 @@ func keccakF1600(a *[25]uint64) {
a[14] = bc4 ^ (bc1 &^ bc0) a[14] = bc4 ^ (bc1 &^ bc0)
t = a[10] ^ d0 t = a[10] ^ d0
bc3 = t<<41 | t>>(64-41) bc3 = bits.RotateLeft64(t, 41)
t = a[21] ^ d1 t = a[21] ^ d1
bc4 = t<<2 | t>>(64-2) bc4 = bits.RotateLeft64(t, 2)
t = a[7] ^ d2 t = a[7] ^ d2
bc0 = t<<62 | t>>(64-62) bc0 = bits.RotateLeft64(t, 62)
t = a[18] ^ d3 t = a[18] ^ d3
bc1 = t<<55 | t>>(64-55) bc1 = bits.RotateLeft64(t, 55)
t = a[4] ^ d4 t = a[4] ^ d4
bc2 = t<<39 | t>>(64-39) bc2 = bits.RotateLeft64(t, 39)
a[10] = bc0 ^ (bc2 &^ bc1) a[10] = bc0 ^ (bc2 &^ bc1)
a[21] = bc1 ^ (bc3 &^ bc2) a[21] = bc1 ^ (bc3 &^ bc2)
a[7] = bc2 ^ (bc4 &^ bc3) a[7] = bc2 ^ (bc4 &^ bc3)
@ -333,13 +335,13 @@ func keccakF1600(a *[25]uint64) {
bc0 = a[0] ^ d0 bc0 = a[0] ^ d0
t = a[1] ^ d1 t = a[1] ^ d1
bc1 = t<<44 | t>>(64-44) bc1 = bits.RotateLeft64(t, 44)
t = a[2] ^ d2 t = a[2] ^ d2
bc2 = t<<43 | t>>(64-43) bc2 = bits.RotateLeft64(t, 43)
t = a[3] ^ d3 t = a[3] ^ d3
bc3 = t<<21 | t>>(64-21) bc3 = bits.RotateLeft64(t, 21)
t = a[4] ^ d4 t = a[4] ^ d4
bc4 = t<<14 | t>>(64-14) bc4 = bits.RotateLeft64(t, 14)
a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i+3] a[0] = bc0 ^ (bc2 &^ bc1) ^ rc[i+3]
a[1] = bc1 ^ (bc3 &^ bc2) a[1] = bc1 ^ (bc3 &^ bc2)
a[2] = bc2 ^ (bc4 &^ bc3) a[2] = bc2 ^ (bc4 &^ bc3)
@ -347,15 +349,15 @@ func keccakF1600(a *[25]uint64) {
a[4] = bc4 ^ (bc1 &^ bc0) a[4] = bc4 ^ (bc1 &^ bc0)
t = a[5] ^ d0 t = a[5] ^ d0
bc2 = t<<3 | t>>(64-3) bc2 = bits.RotateLeft64(t, 3)
t = a[6] ^ d1 t = a[6] ^ d1
bc3 = t<<45 | t>>(64-45) bc3 = bits.RotateLeft64(t, 45)
t = a[7] ^ d2 t = a[7] ^ d2
bc4 = t<<61 | t>>(64-61) bc4 = bits.RotateLeft64(t, 61)
t = a[8] ^ d3 t = a[8] ^ d3
bc0 = t<<28 | t>>(64-28) bc0 = bits.RotateLeft64(t, 28)
t = a[9] ^ d4 t = a[9] ^ d4
bc1 = t<<20 | t>>(64-20) bc1 = bits.RotateLeft64(t, 20)
a[5] = bc0 ^ (bc2 &^ bc1) a[5] = bc0 ^ (bc2 &^ bc1)
a[6] = bc1 ^ (bc3 &^ bc2) a[6] = bc1 ^ (bc3 &^ bc2)
a[7] = bc2 ^ (bc4 &^ bc3) a[7] = bc2 ^ (bc4 &^ bc3)
@ -363,15 +365,15 @@ func keccakF1600(a *[25]uint64) {
a[9] = bc4 ^ (bc1 &^ bc0) a[9] = bc4 ^ (bc1 &^ bc0)
t = a[10] ^ d0 t = a[10] ^ d0
bc4 = t<<18 | t>>(64-18) bc4 = bits.RotateLeft64(t, 18)
t = a[11] ^ d1 t = a[11] ^ d1
bc0 = t<<1 | t>>(64-1) bc0 = bits.RotateLeft64(t, 1)
t = a[12] ^ d2 t = a[12] ^ d2
bc1 = t<<6 | t>>(64-6) bc1 = bits.RotateLeft64(t, 6)
t = a[13] ^ d3 t = a[13] ^ d3
bc2 = t<<25 | t>>(64-25) bc2 = bits.RotateLeft64(t, 25)
t = a[14] ^ d4 t = a[14] ^ d4
bc3 = t<<8 | t>>(64-8) bc3 = bits.RotateLeft64(t, 8)
a[10] = bc0 ^ (bc2 &^ bc1) a[10] = bc0 ^ (bc2 &^ bc1)
a[11] = bc1 ^ (bc3 &^ bc2) a[11] = bc1 ^ (bc3 &^ bc2)
a[12] = bc2 ^ (bc4 &^ bc3) a[12] = bc2 ^ (bc4 &^ bc3)
@ -379,15 +381,15 @@ func keccakF1600(a *[25]uint64) {
a[14] = bc4 ^ (bc1 &^ bc0) a[14] = bc4 ^ (bc1 &^ bc0)
t = a[15] ^ d0 t = a[15] ^ d0
bc1 = t<<36 | t>>(64-36) bc1 = bits.RotateLeft64(t, 36)
t = a[16] ^ d1 t = a[16] ^ d1
bc2 = t<<10 | t>>(64-10) bc2 = bits.RotateLeft64(t, 10)
t = a[17] ^ d2 t = a[17] ^ d2
bc3 = t<<15 | t>>(64-15) bc3 = bits.RotateLeft64(t, 15)
t = a[18] ^ d3 t = a[18] ^ d3
bc4 = t<<56 | t>>(64-56) bc4 = bits.RotateLeft64(t, 56)
t = a[19] ^ d4 t = a[19] ^ d4
bc0 = t<<27 | t>>(64-27) bc0 = bits.RotateLeft64(t, 27)
a[15] = bc0 ^ (bc2 &^ bc1) a[15] = bc0 ^ (bc2 &^ bc1)
a[16] = bc1 ^ (bc3 &^ bc2) a[16] = bc1 ^ (bc3 &^ bc2)
a[17] = bc2 ^ (bc4 &^ bc3) a[17] = bc2 ^ (bc4 &^ bc3)
@ -395,15 +397,15 @@ func keccakF1600(a *[25]uint64) {
a[19] = bc4 ^ (bc1 &^ bc0) a[19] = bc4 ^ (bc1 &^ bc0)
t = a[20] ^ d0 t = a[20] ^ d0
bc3 = t<<41 | t>>(64-41) bc3 = bits.RotateLeft64(t, 41)
t = a[21] ^ d1 t = a[21] ^ d1
bc4 = t<<2 | t>>(64-2) bc4 = bits.RotateLeft64(t, 2)
t = a[22] ^ d2 t = a[22] ^ d2
bc0 = t<<62 | t>>(64-62) bc0 = bits.RotateLeft64(t, 62)
t = a[23] ^ d3 t = a[23] ^ d3
bc1 = t<<55 | t>>(64-55) bc1 = bits.RotateLeft64(t, 55)
t = a[24] ^ d4 t = a[24] ^ d4
bc2 = t<<39 | t>>(64-39) bc2 = bits.RotateLeft64(t, 39)
a[20] = bc0 ^ (bc2 &^ bc1) a[20] = bc0 ^ (bc2 &^ bc1)
a[21] = bc1 ^ (bc3 &^ bc2) a[21] = bc1 ^ (bc3 &^ bc2)
a[22] = bc2 ^ (bc4 &^ bc3) a[22] = bc2 ^ (bc4 &^ bc3)

@ -251,7 +251,7 @@ type algorithmOpenSSHCertSigner struct {
// private key is held by signer. It returns an error if the public key in cert // private key is held by signer. It returns an error if the public key in cert
// doesn't match the key used by signer. // doesn't match the key used by signer.
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
return nil, errors.New("ssh: signer and cert have different public key") return nil, errors.New("ssh: signer and cert have different public key")
} }

@ -15,7 +15,6 @@ import (
"fmt" "fmt"
"hash" "hash"
"io" "io"
"io/ioutil"
"golang.org/x/crypto/chacha20" "golang.org/x/crypto/chacha20"
"golang.org/x/crypto/internal/poly1305" "golang.org/x/crypto/internal/poly1305"
@ -97,13 +96,13 @@ func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream,
// are not supported and will not be negotiated, even if explicitly requested in // are not supported and will not be negotiated, even if explicitly requested in
// ClientConfig.Crypto.Ciphers. // ClientConfig.Crypto.Ciphers.
var cipherModes = map[string]*cipherMode{ var cipherModes = map[string]*cipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms // Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC. // are defined in the order specified in the RFC.
"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)}, "aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)}, "aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)}, "aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers. // Ciphers from RFC 4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC. // They are defined in the order specified in the RFC.
"arcfour128": {16, 0, streamCipherMode(1536, newRC4)}, "arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
"arcfour256": {32, 0, streamCipherMode(1536, newRC4)}, "arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
@ -111,7 +110,7 @@ var cipherModes = map[string]*cipherMode{
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
// RC4) has problems with weak keys, and should be used with caution." // RC4) has problems with weak keys, and should be used with caution."
// RFC4345 introduces improved versions of Arcfour. // RFC 4345 introduces improved versions of Arcfour.
"arcfour": {16, 0, streamCipherMode(0, newRC4)}, "arcfour": {16, 0, streamCipherMode(0, newRC4)},
// AEAD ciphers // AEAD ciphers
@ -497,7 +496,7 @@ func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error)
// data, to make distinguishing between // data, to make distinguishing between
// failing MAC and failing length check more // failing MAC and failing length check more
// difficult. // difficult.
io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) io.CopyN(io.Discard, r, int64(c.oracleCamouflage))
} }
} }
return p, err return p, err
@ -642,7 +641,7 @@ const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
// //
// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00 // https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
// //
// the methods here also implement padding, which RFC4253 Section 6 // the methods here also implement padding, which RFC 4253 Section 6
// also requires of stream ciphers. // also requires of stream ciphers.
type chacha20Poly1305Cipher struct { type chacha20Poly1305Cipher struct {
lengthKey [32]byte lengthKey [32]byte

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"io" "io"
"math" "math"
"strings"
"sync" "sync"
_ "crypto/sha1" _ "crypto/sha1"
@ -118,6 +119,20 @@ func algorithmsForKeyFormat(keyFormat string) []string {
} }
} }
// supportedPubKeyAuthAlgos specifies the supported client public key
// authentication algorithms. Note that this doesn't include certificate types
// since those use the underlying algorithm. This list is sent to the client if
// it supports the server-sig-algs extension. Order is irrelevant.
var supportedPubKeyAuthAlgos = []string{
KeyAlgoED25519,
KeyAlgoSKED25519, KeyAlgoSKECDSA256,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA,
KeyAlgoDSA,
}
var supportedPubKeyAuthAlgosList = strings.Join(supportedPubKeyAuthAlgos, ",")
// unexpectedMessageError results when the SSH message that we received didn't // unexpectedMessageError results when the SSH message that we received didn't
// match what we wanted. // match what we wanted.
func unexpectedMessageError(expected, got uint8) error { func unexpectedMessageError(expected, got uint8) error {
@ -149,7 +164,7 @@ type directionAlgorithms struct {
// rekeyBytes returns a rekeying intervals in bytes. // rekeyBytes returns a rekeying intervals in bytes.
func (a *directionAlgorithms) rekeyBytes() int64 { func (a *directionAlgorithms) rekeyBytes() int64 {
// According to RFC4344 block ciphers should rekey after // According to RFC 4344 block ciphers should rekey after
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is // 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
// 128. // 128.
switch a.Cipher { switch a.Cipher {
@ -158,7 +173,7 @@ func (a *directionAlgorithms) rekeyBytes() int64 {
} }
// For others, stick with RFC4253 recommendation to rekey after 1 Gb of data. // For others, stick with RFC 4253 recommendation to rekey after 1 Gb of data.
return 1 << 30 return 1 << 30
} }

@ -52,7 +52,7 @@ type Conn interface {
// SendRequest sends a global request, and returns the // SendRequest sends a global request, and returns the
// reply. If wantReply is true, it returns the response status // reply. If wantReply is true, it returns the response status
// and payload. See also RFC4254, section 4. // and payload. See also RFC 4254, section 4.
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
// OpenChannel tries to open an channel. If the request is // OpenChannel tries to open an channel. If the request is

@ -615,7 +615,8 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
return err return err
} }
if t.sessionID == nil { firstKeyExchange := t.sessionID == nil
if firstKeyExchange {
t.sessionID = result.H t.sessionID = result.H
} }
result.SessionID = t.sessionID result.SessionID = t.sessionID
@ -626,6 +627,24 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
return err return err
} }
// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
// message with the server-sig-algs extension if the client supports it. See
// RFC 8308, Sections 2.4 and 3.1.
if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
extInfo := &extInfoMsg{
NumExtensions: 1,
Payload: make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)),
}
extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
return err
}
}
if packet, err := t.conn.readPacket(); err != nil { if packet, err := t.conn.readPacket(); err != nil {
return err return err
} else if packet[0] != msgNewKeys { } else if packet[0] != msgNewKeys {

@ -184,7 +184,7 @@ func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey
return "", nil, nil, "", nil, io.EOF return "", nil, nil, "", nil, io.EOF
} }
// ParseAuthorizedKeys parses a public key from an authorized_keys // ParseAuthorizedKey parses a public key from an authorized_keys
// file used in OpenSSH according to the sshd(8) manual page. // file used in OpenSSH according to the sshd(8) manual page.
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
for len(in) > 0 { for len(in) > 0 {

@ -68,7 +68,7 @@ type kexInitMsg struct {
// See RFC 4253, section 8. // See RFC 4253, section 8.
// Diffie-Helman // Diffie-Hellman
const msgKexDHInit = 30 const msgKexDHInit = 30
type kexDHInitMsg struct { type kexDHInitMsg struct {

@ -68,8 +68,16 @@ type ServerConfig struct {
// NoClientAuth is true if clients are allowed to connect without // NoClientAuth is true if clients are allowed to connect without
// authenticating. // authenticating.
// To determine NoClientAuth at runtime, set NoClientAuth to true
// and the optional NoClientAuthCallback to a non-nil value.
NoClientAuth bool NoClientAuth bool
// NoClientAuthCallback, if non-nil, is called when a user
// attempts to authenticate with auth method "none".
// NoClientAuth must also be set to true for this be used, or
// this func is unused.
NoClientAuthCallback func(ConnMetadata) (*Permissions, error)
// MaxAuthTries specifies the maximum number of authentication attempts // MaxAuthTries specifies the maximum number of authentication attempts
// permitted per connection. If set to a negative number, the number of // permitted per connection. If set to a negative number, the number of
// attempts are unlimited. If set to zero, the number of attempts are limited // attempts are unlimited. If set to zero, the number of attempts are limited
@ -283,15 +291,6 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
return perms, err return perms, err
} }
func isAcceptableAlgo(algo string) bool {
switch algo {
case KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoSKECDSA256, KeyAlgoED25519, KeyAlgoSKED25519,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
return true
}
return false
}
func checkSourceAddress(addr net.Addr, sourceAddrs string) error { func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
if addr == nil { if addr == nil {
return errors.New("ssh: no address known for client, but source-address match required") return errors.New("ssh: no address known for client, but source-address match required")
@ -455,8 +454,12 @@ userAuthLoop:
switch userAuthReq.Method { switch userAuthReq.Method {
case "none": case "none":
if config.NoClientAuth { if config.NoClientAuth {
if config.NoClientAuthCallback != nil {
perms, authErr = config.NoClientAuthCallback(s)
} else {
authErr = nil authErr = nil
} }
}
// allow initial attempt of 'none' without penalty // allow initial attempt of 'none' without penalty
if authFailures == 0 { if authFailures == 0 {
@ -502,7 +505,7 @@ userAuthLoop:
return nil, parseError(msgUserAuthRequest) return nil, parseError(msgUserAuthRequest)
} }
algo := string(algoBytes) algo := string(algoBytes)
if !isAcceptableAlgo(algo) { if !contains(supportedPubKeyAuthAlgos, underlyingAlgo(algo)) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
break break
} }
@ -560,7 +563,7 @@ userAuthLoop:
// algorithm name that corresponds to algo with // algorithm name that corresponds to algo with
// sig.Format. This is usually the same, but // sig.Format. This is usually the same, but
// for certs, the names differ. // for certs, the names differ.
if !isAcceptableAlgo(sig.Format) { if !contains(supportedPubKeyAuthAlgos, sig.Format) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format) authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
break break
} }

@ -13,7 +13,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"sync" "sync"
) )
@ -124,7 +123,7 @@ type Session struct {
// output and error. // output and error.
// //
// If either is nil, Run connects the corresponding file // If either is nil, Run connects the corresponding file
// descriptor to an instance of ioutil.Discard. There is a // descriptor to an instance of io.Discard. There is a
// fixed amount of buffering that is shared for the two streams. // fixed amount of buffering that is shared for the two streams.
// If either blocks it may eventually cause the remote // If either blocks it may eventually cause the remote
// command to block. // command to block.
@ -506,7 +505,7 @@ func (s *Session) stdout() {
return return
} }
if s.Stdout == nil { if s.Stdout == nil {
s.Stdout = ioutil.Discard s.Stdout = io.Discard
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stdout, s.ch) _, err := io.Copy(s.Stdout, s.ch)
@ -519,7 +518,7 @@ func (s *Session) stderr() {
return return
} }
if s.Stderr == nil { if s.Stderr == nil {
s.Stderr = ioutil.Discard s.Stderr = io.Discard
} }
s.copyFuncs = append(s.copyFuncs, func() error { s.copyFuncs = append(s.copyFuncs, func() error {
_, err := io.Copy(s.Stderr, s.ch.Stderr()) _, err := io.Copy(s.Stderr, s.ch.Stderr())

@ -32,7 +32,7 @@ var DeadlineExceeded = context.DeadlineExceeded
// call cancel as soon as the operations running in this Context complete. // call cancel as soon as the operations running in this Context complete.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
ctx, f := context.WithCancel(parent) ctx, f := context.WithCancel(parent)
return ctx, CancelFunc(f) return ctx, f
} }
// WithDeadline returns a copy of the parent context with the deadline adjusted // WithDeadline returns a copy of the parent context with the deadline adjusted
@ -46,7 +46,7 @@ func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
// call cancel as soon as the operations running in this Context complete. // call cancel as soon as the operations running in this Context complete.
func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
ctx, f := context.WithDeadline(parent, deadline) ctx, f := context.WithDeadline(parent, deadline)
return ctx, CancelFunc(f) return ctx, f
} }
// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). // WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).

@ -734,7 +734,7 @@ func inHeadIM(p *parser) bool {
return false return false
} }
// 12.2.6.4.5. // Section 12.2.6.4.5.
func inHeadNoscriptIM(p *parser) bool { func inHeadNoscriptIM(p *parser) bool {
switch p.tok.Type { switch p.tok.Type {
case DoctypeToken: case DoctypeToken:

@ -85,7 +85,7 @@ func render1(w writer, n *Node) error {
if _, err := w.WriteString("<!--"); err != nil { if _, err := w.WriteString("<!--"); err != nil {
return err return err
} }
if _, err := w.WriteString(n.Data); err != nil { if err := escape(w, n.Data); err != nil {
return err return err
} }
if _, err := w.WriteString("-->"); err != nil { if _, err := w.WriteString("-->"); err != nil {
@ -96,7 +96,7 @@ func render1(w writer, n *Node) error {
if _, err := w.WriteString("<!DOCTYPE "); err != nil { if _, err := w.WriteString("<!DOCTYPE "); err != nil {
return err return err
} }
if _, err := w.WriteString(n.Data); err != nil { if err := escape(w, n.Data); err != nil {
return err return err
} }
if n.Attr != nil { if n.Attr != nil {

@ -110,9 +110,9 @@ func (t Token) String() string {
case SelfClosingTagToken: case SelfClosingTagToken:
return "<" + t.tagString() + "/>" return "<" + t.tagString() + "/>"
case CommentToken: case CommentToken:
return "<!--" + t.Data + "-->" return "<!--" + EscapeString(t.Data) + "-->"
case DoctypeToken: case DoctypeToken:
return "<!DOCTYPE " + t.Data + ">" return "<!DOCTYPE " + EscapeString(t.Data) + ">"
} }
return "Invalid(" + strconv.Itoa(int(t.Type)) + ")" return "Invalid(" + strconv.Itoa(int(t.Type)) + ")"
} }
@ -605,7 +605,10 @@ func (z *Tokenizer) readComment() {
z.data.end = z.data.start z.data.end = z.data.start
} }
}() }()
for dashCount := 2; ; {
var dashCount int
beginning := true
for {
c := z.readByte() c := z.readByte()
if z.err != nil { if z.err != nil {
// Ignore up to two dashes at EOF. // Ignore up to two dashes at EOF.
@ -620,7 +623,7 @@ func (z *Tokenizer) readComment() {
dashCount++ dashCount++
continue continue
case '>': case '>':
if dashCount >= 2 { if dashCount >= 2 || beginning {
z.data.end = z.raw.end - len("-->") z.data.end = z.raw.end - len("-->")
return return
} }
@ -638,6 +641,7 @@ func (z *Tokenizer) readComment() {
} }
} }
dashCount = 0 dashCount = 0
beginning = false
} }
} }

@ -23,7 +23,7 @@ const frameHeaderLen = 9
var padZeros = make([]byte, 255) // zeros for padding var padZeros = make([]byte, 255) // zeros for padding
// A FrameType is a registered frame type as defined in // A FrameType is a registered frame type as defined in
// http://http2.github.io/http2-spec/#rfc.section.11.2 // https://httpwg.org/specs/rfc7540.html#rfc.section.11.2
type FrameType uint8 type FrameType uint8
const ( const (
@ -146,7 +146,7 @@ func typeFrameParser(t FrameType) frameParser {
// A FrameHeader is the 9 byte header of all HTTP/2 frames. // A FrameHeader is the 9 byte header of all HTTP/2 frames.
// //
// See http://http2.github.io/http2-spec/#FrameHeader // See https://httpwg.org/specs/rfc7540.html#FrameHeader
type FrameHeader struct { type FrameHeader struct {
valid bool // caller can access []byte fields in the Frame valid bool // caller can access []byte fields in the Frame
@ -575,7 +575,7 @@ func (fr *Framer) checkFrameOrder(f Frame) error {
// A DataFrame conveys arbitrary, variable-length sequences of octets // A DataFrame conveys arbitrary, variable-length sequences of octets
// associated with a stream. // associated with a stream.
// See http://http2.github.io/http2-spec/#rfc.section.6.1 // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.1
type DataFrame struct { type DataFrame struct {
FrameHeader FrameHeader
data []byte data []byte
@ -698,7 +698,7 @@ func (f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []by
// endpoints communicate, such as preferences and constraints on peer // endpoints communicate, such as preferences and constraints on peer
// behavior. // behavior.
// //
// See http://http2.github.io/http2-spec/#SETTINGS // See https://httpwg.org/specs/rfc7540.html#SETTINGS
type SettingsFrame struct { type SettingsFrame struct {
FrameHeader FrameHeader
p []byte p []byte
@ -837,7 +837,7 @@ func (f *Framer) WriteSettingsAck() error {
// A PingFrame is a mechanism for measuring a minimal round trip time // A PingFrame is a mechanism for measuring a minimal round trip time
// from the sender, as well as determining whether an idle connection // from the sender, as well as determining whether an idle connection
// is still functional. // is still functional.
// See http://http2.github.io/http2-spec/#rfc.section.6.7 // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.7
type PingFrame struct { type PingFrame struct {
FrameHeader FrameHeader
Data [8]byte Data [8]byte
@ -870,7 +870,7 @@ func (f *Framer) WritePing(ack bool, data [8]byte) error {
} }
// A GoAwayFrame informs the remote peer to stop creating streams on this connection. // A GoAwayFrame informs the remote peer to stop creating streams on this connection.
// See http://http2.github.io/http2-spec/#rfc.section.6.8 // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.8
type GoAwayFrame struct { type GoAwayFrame struct {
FrameHeader FrameHeader
LastStreamID uint32 LastStreamID uint32
@ -934,7 +934,7 @@ func parseUnknownFrame(_ *frameCache, fh FrameHeader, countError func(string), p
} }
// A WindowUpdateFrame is used to implement flow control. // A WindowUpdateFrame is used to implement flow control.
// See http://http2.github.io/http2-spec/#rfc.section.6.9 // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.9
type WindowUpdateFrame struct { type WindowUpdateFrame struct {
FrameHeader FrameHeader
Increment uint32 // never read with high bit set Increment uint32 // never read with high bit set
@ -1123,7 +1123,7 @@ func (f *Framer) WriteHeaders(p HeadersFrameParam) error {
} }
// A PriorityFrame specifies the sender-advised priority of a stream. // A PriorityFrame specifies the sender-advised priority of a stream.
// See http://http2.github.io/http2-spec/#rfc.section.6.3 // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.3
type PriorityFrame struct { type PriorityFrame struct {
FrameHeader FrameHeader
PriorityParam PriorityParam
@ -1193,7 +1193,7 @@ func (f *Framer) WritePriority(streamID uint32, p PriorityParam) error {
} }
// A RSTStreamFrame allows for abnormal termination of a stream. // A RSTStreamFrame allows for abnormal termination of a stream.
// See http://http2.github.io/http2-spec/#rfc.section.6.4 // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.4
type RSTStreamFrame struct { type RSTStreamFrame struct {
FrameHeader FrameHeader
ErrCode ErrCode ErrCode ErrCode
@ -1225,7 +1225,7 @@ func (f *Framer) WriteRSTStream(streamID uint32, code ErrCode) error {
} }
// A ContinuationFrame is used to continue a sequence of header block fragments. // A ContinuationFrame is used to continue a sequence of header block fragments.
// See http://http2.github.io/http2-spec/#rfc.section.6.10 // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.10
type ContinuationFrame struct { type ContinuationFrame struct {
FrameHeader FrameHeader
headerFragBuf []byte headerFragBuf []byte
@ -1266,7 +1266,7 @@ func (f *Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlock
} }
// A PushPromiseFrame is used to initiate a server stream. // A PushPromiseFrame is used to initiate a server stream.
// See http://http2.github.io/http2-spec/#rfc.section.6.6 // See https://httpwg.org/specs/rfc7540.html#rfc.section.6.6
type PushPromiseFrame struct { type PushPromiseFrame struct {
FrameHeader FrameHeader
PromiseID uint32 PromiseID uint32

@ -70,6 +70,15 @@ func NewHandler(h http.Handler, s *http2.Server) http.Handler {
} }
} }
// extractServer extracts existing http.Server instance from http.Request or create an empty http.Server
func extractServer(r *http.Request) *http.Server {
server, ok := r.Context().Value(http.ServerContextKey).(*http.Server)
if ok {
return server
}
return new(http.Server)
}
// ServeHTTP implement the h2c support that is enabled by h2c.GetH2CHandler. // ServeHTTP implement the h2c support that is enabled by h2c.GetH2CHandler.
func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Handle h2c with prior knowledge (RFC 7540 Section 3.4) // Handle h2c with prior knowledge (RFC 7540 Section 3.4)
@ -87,6 +96,7 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer conn.Close() defer conn.Close()
s.s.ServeConn(conn, &http2.ServeConnOpts{ s.s.ServeConn(conn, &http2.ServeConnOpts{
Context: r.Context(), Context: r.Context(),
BaseConfig: extractServer(r),
Handler: s.Handler, Handler: s.Handler,
SawClientPreface: true, SawClientPreface: true,
}) })
@ -99,11 +109,13 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if http2VerboseLogs { if http2VerboseLogs {
log.Printf("h2c: error h2c upgrade: %v", err) log.Printf("h2c: error h2c upgrade: %v", err)
} }
w.WriteHeader(http.StatusInternalServerError)
return return
} }
defer conn.Close() defer conn.Close()
s.s.ServeConn(conn, &http2.ServeConnOpts{ s.s.ServeConn(conn, &http2.ServeConnOpts{
Context: r.Context(), Context: r.Context(),
BaseConfig: extractServer(r),
Handler: s.Handler, Handler: s.Handler,
UpgradeRequest: r, UpgradeRequest: r,
Settings: settings, Settings: settings,
@ -156,7 +168,10 @@ func h2cUpgrade(w http.ResponseWriter, r *http.Request) (_ net.Conn, settings []
return nil, nil, errors.New("h2c: connection does not support Hijack") return nil, nil, errors.New("h2c: connection does not support Hijack")
} }
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil {
return nil, nil, err
}
r.Body = io.NopCloser(bytes.NewBuffer(body)) r.Body = io.NopCloser(bytes.NewBuffer(body))
conn, rw, err := hijacker.Hijack() conn, rw, err := hijacker.Hijack()

@ -27,7 +27,14 @@ func buildCommonHeaderMaps() {
"accept-language", "accept-language",
"accept-ranges", "accept-ranges",
"age", "age",
"access-control-allow-credentials",
"access-control-allow-headers",
"access-control-allow-methods",
"access-control-allow-origin", "access-control-allow-origin",
"access-control-expose-headers",
"access-control-max-age",
"access-control-request-headers",
"access-control-request-method",
"allow", "allow",
"authorization", "authorization",
"cache-control", "cache-control",
@ -53,6 +60,7 @@ func buildCommonHeaderMaps() {
"link", "link",
"location", "location",
"max-forwards", "max-forwards",
"origin",
"proxy-authenticate", "proxy-authenticate",
"proxy-authorization", "proxy-authorization",
"range", "range",
@ -68,6 +76,8 @@ func buildCommonHeaderMaps() {
"vary", "vary",
"via", "via",
"www-authenticate", "www-authenticate",
"x-forwarded-for",
"x-forwarded-proto",
} }
commonLowerHeader = make(map[string]string, len(common)) commonLowerHeader = make(map[string]string, len(common))
commonCanonHeader = make(map[string]string, len(common)) commonCanonHeader = make(map[string]string, len(common))
@ -85,3 +95,11 @@ func lowerHeader(v string) (lower string, ascii bool) {
} }
return asciiToLower(v) return asciiToLower(v)
} }
func canonicalHeader(v string) string {
buildCommonHeaderMapsOnce()
if s, ok := commonCanonHeader[v]; ok {
return s
}
return http.CanonicalHeaderKey(v)
}

@ -191,7 +191,7 @@ func appendTableSize(dst []byte, v uint32) []byte {
// bit prefix, to dst and returns the extended buffer. // bit prefix, to dst and returns the extended buffer.
// //
// See // See
// http://http2.github.io/http2-spec/compression.html#integer.representation // https://httpwg.org/specs/rfc7541.html#integer.representation
func appendVarInt(dst []byte, n byte, i uint64) []byte { func appendVarInt(dst []byte, n byte, i uint64) []byte {
k := uint64((1 << n) - 1) k := uint64((1 << n) - 1)
if i < k { if i < k {

@ -59,7 +59,7 @@ func (hf HeaderField) String() string {
// Size returns the size of an entry per RFC 7541 section 4.1. // Size returns the size of an entry per RFC 7541 section 4.1.
func (hf HeaderField) Size() uint32 { func (hf HeaderField) Size() uint32 {
// http://http2.github.io/http2-spec/compression.html#rfc.section.4.1 // https://httpwg.org/specs/rfc7541.html#rfc.section.4.1
// "The size of the dynamic table is the sum of the size of // "The size of the dynamic table is the sum of the size of
// its entries. The size of an entry is the sum of its name's // its entries. The size of an entry is the sum of its name's
// length in octets (as defined in Section 5.2), its value's // length in octets (as defined in Section 5.2), its value's
@ -158,7 +158,7 @@ func (d *Decoder) SetAllowedMaxDynamicTableSize(v uint32) {
} }
type dynamicTable struct { type dynamicTable struct {
// http://http2.github.io/http2-spec/compression.html#rfc.section.2.3.2 // https://httpwg.org/specs/rfc7541.html#rfc.section.2.3.2
table headerFieldTable table headerFieldTable
size uint32 // in bytes size uint32 // in bytes
maxSize uint32 // current maxSize maxSize uint32 // current maxSize
@ -307,27 +307,27 @@ func (d *Decoder) parseHeaderFieldRepr() error {
case b&128 != 0: case b&128 != 0:
// Indexed representation. // Indexed representation.
// High bit set? // High bit set?
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.1 // https://httpwg.org/specs/rfc7541.html#rfc.section.6.1
return d.parseFieldIndexed() return d.parseFieldIndexed()
case b&192 == 64: case b&192 == 64:
// 6.2.1 Literal Header Field with Incremental Indexing // 6.2.1 Literal Header Field with Incremental Indexing
// 0b10xxxxxx: top two bits are 10 // 0b10xxxxxx: top two bits are 10
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.1 // https://httpwg.org/specs/rfc7541.html#rfc.section.6.2.1
return d.parseFieldLiteral(6, indexedTrue) return d.parseFieldLiteral(6, indexedTrue)
case b&240 == 0: case b&240 == 0:
// 6.2.2 Literal Header Field without Indexing // 6.2.2 Literal Header Field without Indexing
// 0b0000xxxx: top four bits are 0000 // 0b0000xxxx: top four bits are 0000
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.2 // https://httpwg.org/specs/rfc7541.html#rfc.section.6.2.2
return d.parseFieldLiteral(4, indexedFalse) return d.parseFieldLiteral(4, indexedFalse)
case b&240 == 16: case b&240 == 16:
// 6.2.3 Literal Header Field never Indexed // 6.2.3 Literal Header Field never Indexed
// 0b0001xxxx: top four bits are 0001 // 0b0001xxxx: top four bits are 0001
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.3 // https://httpwg.org/specs/rfc7541.html#rfc.section.6.2.3
return d.parseFieldLiteral(4, indexedNever) return d.parseFieldLiteral(4, indexedNever)
case b&224 == 32: case b&224 == 32:
// 6.3 Dynamic Table Size Update // 6.3 Dynamic Table Size Update
// Top three bits are '001'. // Top three bits are '001'.
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.3 // https://httpwg.org/specs/rfc7541.html#rfc.section.6.3
return d.parseDynamicTableSizeUpdate() return d.parseDynamicTableSizeUpdate()
} }
@ -420,7 +420,7 @@ var errVarintOverflow = DecodingError{errors.New("varint integer overflow")}
// readVarInt reads an unsigned variable length integer off the // readVarInt reads an unsigned variable length integer off the
// beginning of p. n is the parameter as described in // beginning of p. n is the parameter as described in
// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1. // https://httpwg.org/specs/rfc7541.html#rfc.section.5.1.
// //
// n must always be between 1 and 8. // n must always be between 1 and 8.
// //

@ -0,0 +1,188 @@
// go generate gen.go
// Code generated by the command above; DO NOT EDIT.
package hpack
var staticTable = &headerFieldTable{
evictCount: 0,
byName: map[string]uint64{
":authority": 1,
":method": 3,
":path": 5,
":scheme": 7,
":status": 14,
"accept-charset": 15,
"accept-encoding": 16,
"accept-language": 17,
"accept-ranges": 18,
"accept": 19,
"access-control-allow-origin": 20,
"age": 21,
"allow": 22,
"authorization": 23,
"cache-control": 24,
"content-disposition": 25,
"content-encoding": 26,
"content-language": 27,
"content-length": 28,
"content-location": 29,
"content-range": 30,
"content-type": 31,
"cookie": 32,
"date": 33,
"etag": 34,
"expect": 35,
"expires": 36,
"from": 37,
"host": 38,
"if-match": 39,
"if-modified-since": 40,
"if-none-match": 41,
"if-range": 42,
"if-unmodified-since": 43,
"last-modified": 44,
"link": 45,
"location": 46,
"max-forwards": 47,
"proxy-authenticate": 48,
"proxy-authorization": 49,
"range": 50,
"referer": 51,
"refresh": 52,
"retry-after": 53,
"server": 54,
"set-cookie": 55,
"strict-transport-security": 56,
"transfer-encoding": 57,
"user-agent": 58,
"vary": 59,
"via": 60,
"www-authenticate": 61,
},
byNameValue: map[pairNameValue]uint64{
{name: ":authority", value: ""}: 1,
{name: ":method", value: "GET"}: 2,
{name: ":method", value: "POST"}: 3,
{name: ":path", value: "/"}: 4,
{name: ":path", value: "/index.html"}: 5,
{name: ":scheme", value: "http"}: 6,
{name: ":scheme", value: "https"}: 7,
{name: ":status", value: "200"}: 8,
{name: ":status", value: "204"}: 9,
{name: ":status", value: "206"}: 10,
{name: ":status", value: "304"}: 11,
{name: ":status", value: "400"}: 12,
{name: ":status", value: "404"}: 13,
{name: ":status", value: "500"}: 14,
{name: "accept-charset", value: ""}: 15,
{name: "accept-encoding", value: "gzip, deflate"}: 16,
{name: "accept-language", value: ""}: 17,
{name: "accept-ranges", value: ""}: 18,
{name: "accept", value: ""}: 19,
{name: "access-control-allow-origin", value: ""}: 20,
{name: "age", value: ""}: 21,
{name: "allow", value: ""}: 22,
{name: "authorization", value: ""}: 23,
{name: "cache-control", value: ""}: 24,
{name: "content-disposition", value: ""}: 25,
{name: "content-encoding", value: ""}: 26,
{name: "content-language", value: ""}: 27,
{name: "content-length", value: ""}: 28,
{name: "content-location", value: ""}: 29,
{name: "content-range", value: ""}: 30,
{name: "content-type", value: ""}: 31,
{name: "cookie", value: ""}: 32,
{name: "date", value: ""}: 33,
{name: "etag", value: ""}: 34,
{name: "expect", value: ""}: 35,
{name: "expires", value: ""}: 36,
{name: "from", value: ""}: 37,
{name: "host", value: ""}: 38,
{name: "if-match", value: ""}: 39,
{name: "if-modified-since", value: ""}: 40,
{name: "if-none-match", value: ""}: 41,
{name: "if-range", value: ""}: 42,
{name: "if-unmodified-since", value: ""}: 43,
{name: "last-modified", value: ""}: 44,
{name: "link", value: ""}: 45,
{name: "location", value: ""}: 46,
{name: "max-forwards", value: ""}: 47,
{name: "proxy-authenticate", value: ""}: 48,
{name: "proxy-authorization", value: ""}: 49,
{name: "range", value: ""}: 50,
{name: "referer", value: ""}: 51,
{name: "refresh", value: ""}: 52,
{name: "retry-after", value: ""}: 53,
{name: "server", value: ""}: 54,
{name: "set-cookie", value: ""}: 55,
{name: "strict-transport-security", value: ""}: 56,
{name: "transfer-encoding", value: ""}: 57,
{name: "user-agent", value: ""}: 58,
{name: "vary", value: ""}: 59,
{name: "via", value: ""}: 60,
{name: "www-authenticate", value: ""}: 61,
},
ents: []HeaderField{
{Name: ":authority", Value: "", Sensitive: false},
{Name: ":method", Value: "GET", Sensitive: false},
{Name: ":method", Value: "POST", Sensitive: false},
{Name: ":path", Value: "/", Sensitive: false},
{Name: ":path", Value: "/index.html", Sensitive: false},
{Name: ":scheme", Value: "http", Sensitive: false},
{Name: ":scheme", Value: "https", Sensitive: false},
{Name: ":status", Value: "200", Sensitive: false},
{Name: ":status", Value: "204", Sensitive: false},
{Name: ":status", Value: "206", Sensitive: false},
{Name: ":status", Value: "304", Sensitive: false},
{Name: ":status", Value: "400", Sensitive: false},
{Name: ":status", Value: "404", Sensitive: false},
{Name: ":status", Value: "500", Sensitive: false},
{Name: "accept-charset", Value: "", Sensitive: false},
{Name: "accept-encoding", Value: "gzip, deflate", Sensitive: false},
{Name: "accept-language", Value: "", Sensitive: false},
{Name: "accept-ranges", Value: "", Sensitive: false},
{Name: "accept", Value: "", Sensitive: false},
{Name: "access-control-allow-origin", Value: "", Sensitive: false},
{Name: "age", Value: "", Sensitive: false},
{Name: "allow", Value: "", Sensitive: false},
{Name: "authorization", Value: "", Sensitive: false},
{Name: "cache-control", Value: "", Sensitive: false},
{Name: "content-disposition", Value: "", Sensitive: false},
{Name: "content-encoding", Value: "", Sensitive: false},
{Name: "content-language", Value: "", Sensitive: false},
{Name: "content-length", Value: "", Sensitive: false},
{Name: "content-location", Value: "", Sensitive: false},
{Name: "content-range", Value: "", Sensitive: false},
{Name: "content-type", Value: "", Sensitive: false},
{Name: "cookie", Value: "", Sensitive: false},
{Name: "date", Value: "", Sensitive: false},
{Name: "etag", Value: "", Sensitive: false},
{Name: "expect", Value: "", Sensitive: false},
{Name: "expires", Value: "", Sensitive: false},
{Name: "from", Value: "", Sensitive: false},
{Name: "host", Value: "", Sensitive: false},
{Name: "if-match", Value: "", Sensitive: false},
{Name: "if-modified-since", Value: "", Sensitive: false},
{Name: "if-none-match", Value: "", Sensitive: false},
{Name: "if-range", Value: "", Sensitive: false},
{Name: "if-unmodified-since", Value: "", Sensitive: false},
{Name: "last-modified", Value: "", Sensitive: false},
{Name: "link", Value: "", Sensitive: false},
{Name: "location", Value: "", Sensitive: false},
{Name: "max-forwards", Value: "", Sensitive: false},
{Name: "proxy-authenticate", Value: "", Sensitive: false},
{Name: "proxy-authorization", Value: "", Sensitive: false},
{Name: "range", Value: "", Sensitive: false},
{Name: "referer", Value: "", Sensitive: false},
{Name: "refresh", Value: "", Sensitive: false},
{Name: "retry-after", Value: "", Sensitive: false},
{Name: "server", Value: "", Sensitive: false},
{Name: "set-cookie", Value: "", Sensitive: false},
{Name: "strict-transport-security", Value: "", Sensitive: false},
{Name: "transfer-encoding", Value: "", Sensitive: false},
{Name: "user-agent", Value: "", Sensitive: false},
{Name: "vary", Value: "", Sensitive: false},
{Name: "via", Value: "", Sensitive: false},
{Name: "www-authenticate", Value: "", Sensitive: false},
},
}

@ -96,8 +96,7 @@ func (t *headerFieldTable) evictOldest(n int) {
// meaning t.ents is reversed for dynamic tables. Hence, when t is a dynamic // meaning t.ents is reversed for dynamic tables. Hence, when t is a dynamic
// table, the return value i actually refers to the entry t.ents[t.len()-i]. // table, the return value i actually refers to the entry t.ents[t.len()-i].
// //
// All tables are assumed to be a dynamic tables except for the global // All tables are assumed to be a dynamic tables except for the global staticTable.
// staticTable pointer.
// //
// See Section 2.3.3. // See Section 2.3.3.
func (t *headerFieldTable) search(f HeaderField) (i uint64, nameValueMatch bool) { func (t *headerFieldTable) search(f HeaderField) (i uint64, nameValueMatch bool) {
@ -125,81 +124,6 @@ func (t *headerFieldTable) idToIndex(id uint64) uint64 {
return k + 1 return k + 1
} }
// http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-07#appendix-B
var staticTable = newStaticTable()
var staticTableEntries = [...]HeaderField{
{Name: ":authority"},
{Name: ":method", Value: "GET"},
{Name: ":method", Value: "POST"},
{Name: ":path", Value: "/"},
{Name: ":path", Value: "/index.html"},
{Name: ":scheme", Value: "http"},
{Name: ":scheme", Value: "https"},
{Name: ":status", Value: "200"},
{Name: ":status", Value: "204"},
{Name: ":status", Value: "206"},
{Name: ":status", Value: "304"},
{Name: ":status", Value: "400"},
{Name: ":status", Value: "404"},
{Name: ":status", Value: "500"},
{Name: "accept-charset"},
{Name: "accept-encoding", Value: "gzip, deflate"},
{Name: "accept-language"},
{Name: "accept-ranges"},
{Name: "accept"},
{Name: "access-control-allow-origin"},
{Name: "age"},
{Name: "allow"},
{Name: "authorization"},
{Name: "cache-control"},
{Name: "content-disposition"},
{Name: "content-encoding"},
{Name: "content-language"},
{Name: "content-length"},
{Name: "content-location"},
{Name: "content-range"},
{Name: "content-type"},
{Name: "cookie"},
{Name: "date"},
{Name: "etag"},
{Name: "expect"},
{Name: "expires"},
{Name: "from"},
{Name: "host"},
{Name: "if-match"},
{Name: "if-modified-since"},
{Name: "if-none-match"},
{Name: "if-range"},
{Name: "if-unmodified-since"},
{Name: "last-modified"},
{Name: "link"},
{Name: "location"},
{Name: "max-forwards"},
{Name: "proxy-authenticate"},
{Name: "proxy-authorization"},
{Name: "range"},
{Name: "referer"},
{Name: "refresh"},
{Name: "retry-after"},
{Name: "server"},
{Name: "set-cookie"},
{Name: "strict-transport-security"},
{Name: "transfer-encoding"},
{Name: "user-agent"},
{Name: "vary"},
{Name: "via"},
{Name: "www-authenticate"},
}
func newStaticTable() *headerFieldTable {
t := &headerFieldTable{}
t.init()
for _, e := range staticTableEntries[:] {
t.addEntry(e)
}
return t
}
var huffmanCodes = [256]uint32{ var huffmanCodes = [256]uint32{
0x1ff8, 0x1ff8,
0x7fffd8, 0x7fffd8,

@ -55,14 +55,14 @@ const (
ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
// SETTINGS_MAX_FRAME_SIZE default // SETTINGS_MAX_FRAME_SIZE default
// http://http2.github.io/http2-spec/#rfc.section.6.5.2 // https://httpwg.org/specs/rfc7540.html#rfc.section.6.5.2
initialMaxFrameSize = 16384 initialMaxFrameSize = 16384
// NextProtoTLS is the NPN/ALPN protocol negotiated during // NextProtoTLS is the NPN/ALPN protocol negotiated during
// HTTP/2's TLS setup. // HTTP/2's TLS setup.
NextProtoTLS = "h2" NextProtoTLS = "h2"
// http://http2.github.io/http2-spec/#SettingValues // https://httpwg.org/specs/rfc7540.html#SettingValues
initialHeaderTableSize = 4096 initialHeaderTableSize = 4096
initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size
@ -111,7 +111,7 @@ func (st streamState) String() string {
// Setting is a setting parameter: which setting it is, and its value. // Setting is a setting parameter: which setting it is, and its value.
type Setting struct { type Setting struct {
// ID is which setting is being set. // ID is which setting is being set.
// See http://http2.github.io/http2-spec/#SettingValues // See https://httpwg.org/specs/rfc7540.html#SettingFormat
ID SettingID ID SettingID
// Val is the value. // Val is the value.
@ -143,7 +143,7 @@ func (s Setting) Valid() error {
} }
// A SettingID is an HTTP/2 setting as defined in // A SettingID is an HTTP/2 setting as defined in
// http://http2.github.io/http2-spec/#iana-settings // https://httpwg.org/specs/rfc7540.html#iana-settings
type SettingID uint16 type SettingID uint16
const ( const (

@ -143,7 +143,7 @@ type Server struct {
} }
func (s *Server) initialConnRecvWindowSize() int32 { func (s *Server) initialConnRecvWindowSize() int32 {
if s.MaxUploadBufferPerConnection > initialWindowSize { if s.MaxUploadBufferPerConnection >= initialWindowSize {
return s.MaxUploadBufferPerConnection return s.MaxUploadBufferPerConnection
} }
return 1 << 20 return 1 << 20
@ -622,7 +622,9 @@ type stream struct {
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100) wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline *time.Timer // nil if unused
writeDeadline *time.Timer // nil if unused writeDeadline *time.Timer // nil if unused
closeErr error // set before cw is closed
trailer http.Header // accumulated trailers trailer http.Header // accumulated trailers
reqTrailer http.Header // handler's Request.Trailer reqTrailer http.Header // handler's Request.Trailer
@ -948,6 +950,8 @@ func (sc *serverConn) serve() {
} }
case *startPushRequest: case *startPushRequest:
sc.startPush(v) sc.startPush(v)
case func(*serverConn):
v(sc)
default: default:
panic(fmt.Sprintf("unexpected type %T", v)) panic(fmt.Sprintf("unexpected type %T", v))
} }
@ -1371,6 +1375,9 @@ func (sc *serverConn) startGracefulShutdownInternal() {
func (sc *serverConn) goAway(code ErrCode) { func (sc *serverConn) goAway(code ErrCode) {
sc.serveG.check() sc.serveG.check()
if sc.inGoAway { if sc.inGoAway {
if sc.goAwayCode == ErrCodeNo {
sc.goAwayCode = code
}
return return
} }
sc.inGoAway = true sc.inGoAway = true
@ -1458,6 +1465,21 @@ func (sc *serverConn) processFrame(f Frame) error {
sc.sawFirstSettings = true sc.sawFirstSettings = true
} }
// Discard frames for streams initiated after the identified last
// stream sent in a GOAWAY, or all frames after sending an error.
// We still need to return connection-level flow control for DATA frames.
// RFC 9113 Section 6.8.
if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) {
if f, ok := f.(*DataFrame); ok {
if sc.inflow.available() < int32(f.Length) {
return sc.countError("data_flow", streamError(f.Header().StreamID, ErrCodeFlowControl))
}
sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
}
return nil
}
switch f := f.(type) { switch f := f.(type) {
case *SettingsFrame: case *SettingsFrame:
return sc.processSettings(f) return sc.processSettings(f)
@ -1500,9 +1522,6 @@ func (sc *serverConn) processPing(f *PingFrame) error {
// PROTOCOL_ERROR." // PROTOCOL_ERROR."
return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol)) return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol))
} }
if sc.inGoAway && sc.goAwayCode != ErrCodeNo {
return nil
}
sc.writeFrame(FrameWriteRequest{write: writePingAck{f}}) sc.writeFrame(FrameWriteRequest{write: writePingAck{f}})
return nil return nil
} }
@ -1564,6 +1583,9 @@ func (sc *serverConn) closeStream(st *stream, err error) {
panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
} }
st.state = stateClosed st.state = stateClosed
if st.readDeadline != nil {
st.readDeadline.Stop()
}
if st.writeDeadline != nil { if st.writeDeadline != nil {
st.writeDeadline.Stop() st.writeDeadline.Stop()
} }
@ -1589,6 +1611,14 @@ func (sc *serverConn) closeStream(st *stream, err error) {
p.CloseWithError(err) p.CloseWithError(err)
} }
if e, ok := err.(StreamError); ok {
if e.Cause != nil {
err = e.Cause
} else {
err = errStreamClosed
}
}
st.closeErr = err
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.CloseStream(st.id) sc.writeSched.CloseStream(st.id)
} }
@ -1685,16 +1715,6 @@ func (sc *serverConn) processSettingInitialWindowSize(val uint32) error {
func (sc *serverConn) processData(f *DataFrame) error { func (sc *serverConn) processData(f *DataFrame) error {
sc.serveG.check() sc.serveG.check()
id := f.Header().StreamID id := f.Header().StreamID
if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || id > sc.maxClientStreamID) {
// Discard all DATA frames if the GOAWAY is due to an
// error, or:
//
// Section 6.8: After sending a GOAWAY frame, the sender
// can discard frames for streams initiated by the
// receiver with identifiers higher than the identified
// last stream.
return nil
}
data := f.Data() data := f.Data()
state, st := sc.state(id) state, st := sc.state(id)
@ -1747,6 +1767,12 @@ func (sc *serverConn) processData(f *DataFrame) error {
// Sender sending more than they'd declared? // Sender sending more than they'd declared?
if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
if sc.inflow.available() < int32(f.Length) {
return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
}
sc.inflow.take(int32(f.Length))
sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
// RFC 7540, sec 8.1.2.6: A request or response is also malformed if the // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the
// value of a content-length header field does not equal the sum of the // value of a content-length header field does not equal the sum of the
@ -1831,19 +1857,27 @@ func (st *stream) copyTrailersToHandlerRequest() {
} }
} }
// onReadTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's ReadTimeout has fired.
func (st *stream) onReadTimeout() {
// Wrap the ErrDeadlineExceeded to avoid callers depending on us
// returning the bare error.
st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
}
// onWriteTimeout is run on its own goroutine (from time.AfterFunc) // onWriteTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's WriteTimeout has fired. // when the stream's WriteTimeout has fired.
func (st *stream) onWriteTimeout() { func (st *stream) onWriteTimeout() {
st.sc.writeFrameFromHandler(FrameWriteRequest{write: streamError(st.id, ErrCodeInternal)}) st.sc.writeFrameFromHandler(FrameWriteRequest{write: StreamError{
StreamID: st.id,
Code: ErrCodeInternal,
Cause: os.ErrDeadlineExceeded,
}})
} }
func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
sc.serveG.check() sc.serveG.check()
id := f.StreamID id := f.StreamID
if sc.inGoAway {
// Ignore.
return nil
}
// http://tools.ietf.org/html/rfc7540#section-5.1.1 // http://tools.ietf.org/html/rfc7540#section-5.1.1
// Streams initiated by a client MUST use odd-numbered stream // Streams initiated by a client MUST use odd-numbered stream
// identifiers. [...] An endpoint that receives an unexpected // identifiers. [...] An endpoint that receives an unexpected
@ -1946,6 +1980,9 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// (in Go 1.8), though. That's a more sane option anyway. // (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout != 0 { if sc.hs.ReadTimeout != 0 {
sc.conn.SetReadDeadline(time.Time{}) sc.conn.SetReadDeadline(time.Time{})
if st.body != nil {
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
} }
go sc.runHandler(rw, req, handler) go sc.runHandler(rw, req, handler)
@ -2014,9 +2051,6 @@ func (sc *serverConn) checkPriority(streamID uint32, p PriorityParam) error {
} }
func (sc *serverConn) processPriority(f *PriorityFrame) error { func (sc *serverConn) processPriority(f *PriorityFrame) error {
if sc.inGoAway {
return nil
}
if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil { if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
return err return err
} }
@ -2090,12 +2124,6 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol)) return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol))
} }
bodyOpen := !f.StreamEnded()
if rp.method == "HEAD" && bodyOpen {
// HEAD requests can't have bodies
return nil, nil, sc.countError("head_body", streamError(f.StreamID, ErrCodeProtocol))
}
rp.header = make(http.Header) rp.header = make(http.Header)
for _, hf := range f.RegularFields() { for _, hf := range f.RegularFields() {
rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
@ -2108,6 +2136,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
bodyOpen := !f.StreamEnded()
if bodyOpen { if bodyOpen {
if vv, ok := rp.header["Content-Length"]; ok { if vv, ok := rp.header["Content-Length"]; ok {
if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
@ -2223,6 +2252,9 @@ func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler
didPanic := true didPanic := true
defer func() { defer func() {
rw.rws.stream.cancelCtx() rw.rws.stream.cancelCtx()
if req.MultipartForm != nil {
req.MultipartForm.RemoveAll()
}
if didPanic { if didPanic {
e := recover() e := recover()
sc.writeFrameFromHandler(FrameWriteRequest{ sc.writeFrameFromHandler(FrameWriteRequest{
@ -2334,7 +2366,7 @@ func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
// a larger Read than this. Very unlikely, but we handle it here // a larger Read than this. Very unlikely, but we handle it here
// rather than elsewhere for now. // rather than elsewhere for now.
const maxUint31 = 1<<31 - 1 const maxUint31 = 1<<31 - 1
for n >= maxUint31 { for n > maxUint31 {
sc.sendWindowUpdate32(st, maxUint31) sc.sendWindowUpdate32(st, maxUint31)
n -= maxUint31 n -= maxUint31
} }
@ -2454,7 +2486,15 @@ type responseWriterState struct {
type chunkWriter struct{ rws *responseWriterState } type chunkWriter struct{ rws *responseWriterState }
func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } func (cw chunkWriter) Write(p []byte) (n int, err error) {
n, err = cw.rws.writeChunk(p)
if err == errStreamClosed {
// If writing failed because the stream has been closed,
// return the reason it was closed.
err = cw.rws.stream.closeErr
}
return n, err
}
func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 }
@ -2493,6 +2533,10 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
rws.writeHeader(200) rws.writeHeader(200)
} }
if rws.handlerDone {
rws.promoteUndeclaredTrailers()
}
isHeadResp := rws.req.Method == "HEAD" isHeadResp := rws.req.Method == "HEAD"
if !rws.sentHeader { if !rws.sentHeader {
rws.sentHeader = true rws.sentHeader = true
@ -2564,10 +2608,6 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
return 0, nil return 0, nil
} }
if rws.handlerDone {
rws.promoteUndeclaredTrailers()
}
// only send trailers if they have actually been defined by the // only send trailers if they have actually been defined by the
// server handler. // server handler.
hasNonemptyTrailers := rws.hasNonemptyTrailers() hasNonemptyTrailers := rws.hasNonemptyTrailers()
@ -2648,23 +2688,85 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() {
} }
} }
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onReadTimeout()
return nil
}
w.rws.conn.sendServeMsg(func(sc *serverConn) {
if st.readDeadline != nil {
if !st.readDeadline.Stop() {
// Deadline already exceeded, or stream has been closed.
return
}
}
if deadline.IsZero() {
st.readDeadline = nil
} else if st.readDeadline == nil {
st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
} else {
st.readDeadline.Reset(deadline.Sub(time.Now()))
}
})
return nil
}
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onWriteTimeout()
return nil
}
w.rws.conn.sendServeMsg(func(sc *serverConn) {
if st.writeDeadline != nil {
if !st.writeDeadline.Stop() {
// Deadline already exceeded, or stream has been closed.
return
}
}
if deadline.IsZero() {
st.writeDeadline = nil
} else if st.writeDeadline == nil {
st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
} else {
st.writeDeadline.Reset(deadline.Sub(time.Now()))
}
})
return nil
}
func (w *responseWriter) Flush() { func (w *responseWriter) Flush() {
w.FlushError()
}
func (w *responseWriter) FlushError() error {
rws := w.rws rws := w.rws
if rws == nil { if rws == nil {
panic("Header called after Handler finished") panic("Header called after Handler finished")
} }
var err error
if rws.bw.Buffered() > 0 { if rws.bw.Buffered() > 0 {
if err := rws.bw.Flush(); err != nil { err = rws.bw.Flush()
// Ignore the error. The frame writer already knows.
return
}
} else { } else {
// The bufio.Writer won't call chunkWriter.Write // The bufio.Writer won't call chunkWriter.Write
// (writeChunk with zero bytes, so we have to do it // (writeChunk with zero bytes, so we have to do it
// ourselves to force the HTTP response header and/or // ourselves to force the HTTP response header and/or
// final DATA frame (with END_STREAM) to be sent. // final DATA frame (with END_STREAM) to be sent.
rws.writeChunk(nil) _, err = chunkWriter{rws}.Write(nil)
if err == nil {
select {
case <-rws.stream.cw:
err = rws.stream.closeErr
default:
}
} }
}
return err
} }
func (w *responseWriter) CloseNotify() <-chan bool { func (w *responseWriter) CloseNotify() <-chan bool {

@ -16,6 +16,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"log" "log"
"math" "math"
mathrand "math/rand" mathrand "math/rand"
@ -67,13 +68,23 @@ const (
// A Transport internally caches connections to servers. It is safe // A Transport internally caches connections to servers. It is safe
// for concurrent use by multiple goroutines. // for concurrent use by multiple goroutines.
type Transport struct { type Transport struct {
// DialTLS specifies an optional dial function for creating // DialTLSContext specifies an optional dial function with context for
// TLS connections for requests. // creating TLS connections for requests.
// //
// If DialTLS is nil, tls.Dial is used. // If DialTLSContext and DialTLS is nil, tls.Dial is used.
// //
// If the returned net.Conn has a ConnectionState method like tls.Conn, // If the returned net.Conn has a ConnectionState method like tls.Conn,
// it will be used to set http.Response.TLS. // it will be used to set http.Response.TLS.
DialTLSContext func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error)
// DialTLS specifies an optional dial function for creating
// TLS connections for requests.
//
// If DialTLSContext and DialTLS is nil, tls.Dial is used.
//
// Deprecated: Use DialTLSContext instead, which allows the transport
// to cancel dials as soon as they are no longer needed.
// If both are set, DialTLSContext takes priority.
DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error) DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error)
// TLSClientConfig specifies the TLS configuration to use with // TLSClientConfig specifies the TLS configuration to use with
@ -249,6 +260,7 @@ func (t *Transport) initConnPool() {
type ClientConn struct { type ClientConn struct {
t *Transport t *Transport
tconn net.Conn // usually *tls.Conn, except specialized impls tconn net.Conn // usually *tls.Conn, except specialized impls
tconnClosed bool
tlsState *tls.ConnectionState // nil only for specialized impls tlsState *tls.ConnectionState // nil only for specialized impls
reused uint32 // whether conn is being reused; atomic reused uint32 // whether conn is being reused; atomic
singleUse bool // whether being used for a single http.Request singleUse bool // whether being used for a single http.Request
@ -335,7 +347,7 @@ type clientStream struct {
reqBody io.ReadCloser reqBody io.ReadCloser
reqBodyContentLength int64 // -1 means unknown reqBodyContentLength int64 // -1 means unknown
reqBodyClosed bool // body has been closed; guarded by cc.mu reqBodyClosed chan struct{} // guarded by cc.mu; non-nil on Close, closed when done
// owned by writeRequest: // owned by writeRequest:
sentEndStream bool // sent an END_STREAM flag to the peer sentEndStream bool // sent an END_STREAM flag to the peer
@ -375,9 +387,8 @@ func (cs *clientStream) abortStreamLocked(err error) {
cs.abortErr = err cs.abortErr = err
close(cs.abort) close(cs.abort)
}) })
if cs.reqBody != nil && !cs.reqBodyClosed { if cs.reqBody != nil {
cs.reqBody.Close() cs.closeReqBodyLocked()
cs.reqBodyClosed = true
} }
// TODO(dneil): Clean up tests where cs.cc.cond is nil. // TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil { if cs.cc.cond != nil {
@ -390,13 +401,24 @@ func (cs *clientStream) abortRequestBodyWrite() {
cc := cs.cc cc := cs.cc
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
if cs.reqBody != nil && !cs.reqBodyClosed { if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.reqBody.Close() cs.closeReqBodyLocked()
cs.reqBodyClosed = true
cc.cond.Broadcast() cc.cond.Broadcast()
} }
} }
func (cs *clientStream) closeReqBodyLocked() {
if cs.reqBodyClosed != nil {
return
}
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
go func() {
cs.reqBody.Close()
close(reqBodyClosed)
}()
}
type stickyErrWriter struct { type stickyErrWriter struct {
conn net.Conn conn net.Conn
timeout time.Duration timeout time.Duration
@ -480,6 +502,15 @@ func authorityAddr(scheme string, authority string) (addr string) {
return net.JoinHostPort(host, port) return net.JoinHostPort(host, port)
} }
var retryBackoffHook func(time.Duration) *time.Timer
func backoffNewTimer(d time.Duration) *time.Timer {
if retryBackoffHook != nil {
return retryBackoffHook(d)
}
return time.NewTimer(d)
}
// RoundTripOpt is like RoundTrip, but takes options. // RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
@ -505,11 +536,14 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
} }
backoff := float64(uint(1) << (uint(retry) - 1)) backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64()) backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
timer := backoffNewTimer(d)
select { select {
case <-time.After(time.Second * time.Duration(backoff)): case <-timer.C:
t.vlogf("RoundTrip retrying after failure: %v", err) t.vlogf("RoundTrip retrying after failure: %v", err)
continue continue
case <-req.Context().Done(): case <-req.Context().Done():
timer.Stop()
err = req.Context().Err() err = req.Context().Err()
} }
} }
@ -592,7 +626,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b
if err != nil { if err != nil {
return nil, err return nil, err
} }
tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host)) tconn, err := t.dialTLS(ctx, "tcp", addr, t.newTLSConfig(host))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -613,12 +647,14 @@ func (t *Transport) newTLSConfig(host string) *tls.Config {
return cfg return cfg
} }
func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { func (t *Transport) dialTLS(ctx context.Context, network, addr string, tlsCfg *tls.Config) (net.Conn, error) {
if t.DialTLS != nil { if t.DialTLSContext != nil {
return t.DialTLS return t.DialTLSContext(ctx, network, addr, tlsCfg)
} else if t.DialTLS != nil {
return t.DialTLS(network, addr, tlsCfg)
} }
return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
tlsCn, err := t.dialTLSWithContext(ctx, network, addr, cfg) tlsCn, err := t.dialTLSWithContext(ctx, network, addr, tlsCfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -630,7 +666,6 @@ func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Confi
return nil, errors.New("http2: could not negotiate protocol mutually") return nil, errors.New("http2: could not negotiate protocol mutually")
} }
return tlsCn, nil return tlsCn, nil
}
} }
// disableKeepAlives reports whether connections should be closed as // disableKeepAlives reports whether connections should be closed as
@ -910,10 +945,10 @@ func (cc *ClientConn) onIdleTimeout() {
cc.closeIfIdle() cc.closeIfIdle()
} }
func (cc *ClientConn) closeConn() error { func (cc *ClientConn) closeConn() {
t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn) t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn)
defer t.Stop() defer t.Stop()
return cc.tconn.Close() cc.tconn.Close()
} }
// A tls.Conn.Close can hang for a long time if the peer is unresponsive. // A tls.Conn.Close can hang for a long time if the peer is unresponsive.
@ -979,7 +1014,8 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error {
shutdownEnterWaitStateHook() shutdownEnterWaitStateHook()
select { select {
case <-done: case <-done:
return cc.closeConn() cc.closeConn()
return nil
case <-ctx.Done(): case <-ctx.Done():
cc.mu.Lock() cc.mu.Lock()
// Free the goroutine above // Free the goroutine above
@ -1016,7 +1052,7 @@ func (cc *ClientConn) sendGoAway() error {
// closes the client connection immediately. In-flight requests are interrupted. // closes the client connection immediately. In-flight requests are interrupted.
// err is sent to streams. // err is sent to streams.
func (cc *ClientConn) closeForError(err error) error { func (cc *ClientConn) closeForError(err error) {
cc.mu.Lock() cc.mu.Lock()
cc.closed = true cc.closed = true
for _, cs := range cc.streams { for _, cs := range cc.streams {
@ -1024,7 +1060,7 @@ func (cc *ClientConn) closeForError(err error) error {
} }
cc.cond.Broadcast() cc.cond.Broadcast()
cc.mu.Unlock() cc.mu.Unlock()
return cc.closeConn() cc.closeConn()
} }
// Close closes the client connection immediately. // Close closes the client connection immediately.
@ -1032,16 +1068,17 @@ func (cc *ClientConn) closeForError(err error) error {
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. // In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
func (cc *ClientConn) Close() error { func (cc *ClientConn) Close() error {
err := errors.New("http2: client connection force closed via ClientConn.Close") err := errors.New("http2: client connection force closed via ClientConn.Close")
return cc.closeForError(err) cc.closeForError(err)
return nil
} }
// closes the client connection immediately. In-flight requests are interrupted. // closes the client connection immediately. In-flight requests are interrupted.
func (cc *ClientConn) closeForLostPing() error { func (cc *ClientConn) closeForLostPing() {
err := errors.New("http2: client connection lost") err := errors.New("http2: client connection lost")
if f := cc.t.CountError; f != nil { if f := cc.t.CountError; f != nil {
f("conn_close_lost_ping") f("conn_close_lost_ping")
} }
return cc.closeForError(err) cc.closeForError(err)
} }
// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not // errRequestCanceled is a copy of net/http's errRequestCanceled because it's not
@ -1051,7 +1088,7 @@ var errRequestCanceled = errors.New("net/http: request canceled")
func commaSeparatedTrailers(req *http.Request) (string, error) { func commaSeparatedTrailers(req *http.Request) (string, error) {
keys := make([]string, 0, len(req.Trailer)) keys := make([]string, 0, len(req.Trailer))
for k := range req.Trailer { for k := range req.Trailer {
k = http.CanonicalHeaderKey(k) k = canonicalHeader(k)
switch k { switch k {
case "Transfer-Encoding", "Trailer", "Content-Length": case "Transfer-Encoding", "Trailer", "Content-Length":
return "", fmt.Errorf("invalid Trailer key %q", k) return "", fmt.Errorf("invalid Trailer key %q", k)
@ -1419,11 +1456,19 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
// and in multiple cases: server replies <=299 and >299 // and in multiple cases: server replies <=299 and >299
// while still writing request body // while still writing request body
cc.mu.Lock() cc.mu.Lock()
mustCloseBody := false
if cs.reqBody != nil && cs.reqBodyClosed == nil {
mustCloseBody = true
cs.reqBodyClosed = make(chan struct{})
}
bodyClosed := cs.reqBodyClosed bodyClosed := cs.reqBodyClosed
cs.reqBodyClosed = true
cc.mu.Unlock() cc.mu.Unlock()
if !bodyClosed && cs.reqBody != nil { if mustCloseBody {
cs.reqBody.Close() cs.reqBody.Close()
close(bodyClosed)
}
if bodyClosed != nil {
<-bodyClosed
} }
if err != nil && cs.sentEndStream { if err != nil && cs.sentEndStream {
@ -1580,7 +1625,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
var sawEOF bool var sawEOF bool
for !sawEOF { for !sawEOF {
n, err := body.Read(buf[:len(buf)]) n, err := body.Read(buf)
if hasContentLen { if hasContentLen {
remainLen -= int64(n) remainLen -= int64(n)
if remainLen == 0 && err == nil { if remainLen == 0 && err == nil {
@ -1603,7 +1648,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
} }
if err != nil { if err != nil {
cc.mu.Lock() cc.mu.Lock()
bodyClosed := cs.reqBodyClosed bodyClosed := cs.reqBodyClosed != nil
cc.mu.Unlock() cc.mu.Unlock()
switch { switch {
case bodyClosed: case bodyClosed:
@ -1698,7 +1743,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
if cc.closed { if cc.closed {
return 0, errClientConnClosed return 0, errClientConnClosed
} }
if cs.reqBodyClosed { if cs.reqBodyClosed != nil {
return 0, errStopReqBodyWrite return 0, errStopReqBodyWrite
} }
select { select {
@ -1883,7 +1928,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
// Header list size is ok. Write the headers. // Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) { enumerateHeaders(func(name, value string) {
name, ascii := asciiToLower(name) name, ascii := lowerHeader(name)
if !ascii { if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x). // field names have to be ASCII characters (just as in HTTP/1.x).
@ -1936,7 +1981,7 @@ func (cc *ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) {
} }
for k, vv := range trailer { for k, vv := range trailer {
lowKey, ascii := asciiToLower(k) lowKey, ascii := lowerHeader(k)
if !ascii { if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x). // field names have to be ASCII characters (just as in HTTP/1.x).
@ -1994,7 +2039,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
// wake up RoundTrip if there is a pending request. // wake up RoundTrip if there is a pending request.
cc.cond.Broadcast() cc.cond.Broadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
if VerboseLogs { if VerboseLogs {
cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2)
@ -2070,6 +2115,7 @@ func (rl *clientConnReadLoop) cleanup() {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
cc.closed = true cc.closed = true
for _, cs := range cc.streams { for _, cs := range cc.streams {
select { select {
case <-cs.peerClosed: case <-cs.peerClosed:
@ -2268,7 +2314,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
Status: status + " " + http.StatusText(statusCode), Status: status + " " + http.StatusText(statusCode),
} }
for _, hf := range regularFields { for _, hf := range regularFields {
key := http.CanonicalHeaderKey(hf.Name) key := canonicalHeader(hf.Name)
if key == "Trailer" { if key == "Trailer" {
t := res.Trailer t := res.Trailer
if t == nil { if t == nil {
@ -2276,7 +2322,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
res.Trailer = t res.Trailer = t
} }
foreachHeaderElement(hf.Value, func(v string) { foreachHeaderElement(hf.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil t[canonicalHeader(v)] = nil
}) })
} else { } else {
vv := header[key] vv := header[key]
@ -2381,7 +2427,7 @@ func (rl *clientConnReadLoop) processTrailers(cs *clientStream, f *MetaHeadersFr
trailer := make(http.Header) trailer := make(http.Header)
for _, hf := range f.RegularFields() { for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name) key := canonicalHeader(hf.Name)
trailer[key] = append(trailer[key], hf.Value) trailer[key] = append(trailer[key], hf.Value)
} }
cs.trailer = trailer cs.trailer = trailer
@ -2663,7 +2709,6 @@ func (rl *clientConnReadLoop) processGoAway(f *GoAwayFrame) error {
if fn := cc.t.CountError; fn != nil { if fn := cc.t.CountError; fn != nil {
fn("recv_goaway_" + f.ErrCode.stringToken()) fn("recv_goaway_" + f.ErrCode.stringToken())
} }
} }
cc.setGoAway(f) cc.setGoAway(f)
return nil return nil
@ -2953,7 +2998,11 @@ func (gz *gzipReader) Read(p []byte) (n int, err error) {
} }
func (gz *gzipReader) Close() error { func (gz *gzipReader) Close() error {
return gz.body.Close() if err := gz.body.Close(); err != nil {
return err
}
gz.zerr = fs.ErrClosed
return nil
} }
type errorReader struct{ err error } type errorReader struct{ err error }
@ -3017,7 +3066,7 @@ func traceGotConn(req *http.Request, cc *ClientConn, reused bool) {
cc.mu.Lock() cc.mu.Lock()
ci.WasIdle = len(cc.streams) == 0 && reused ci.WasIdle = len(cc.streams) == 0 && reused
if ci.WasIdle && !cc.lastActive.IsZero() { if ci.WasIdle && !cc.lastActive.IsZero() {
ci.IdleTime = time.Now().Sub(cc.lastActive) ci.IdleTime = time.Since(cc.lastActive)
} }
cc.mu.Unlock() cc.mu.Unlock()

@ -395,7 +395,7 @@ func New(family, title string) Trace {
} }
func (tr *trace) Finish() { func (tr *trace) Finish() {
elapsed := time.Now().Sub(tr.Start) elapsed := time.Since(tr.Start)
tr.mu.Lock() tr.mu.Lock()
tr.Elapsed = elapsed tr.Elapsed = elapsed
tr.mu.Unlock() tr.mu.Unlock()

@ -6,7 +6,10 @@ package cpu
import "runtime" import "runtime"
const cacheLineSize = 64 // cacheLineSize is used to prevent false sharing of cache lines.
// We choose 128 because Apple Silicon, a.k.a. M1, has 128-byte cache line size.
// It doesn't cost much and is much more future-proof.
const cacheLineSize = 128
func initOptions() { func initOptions() {
options = []option{ options = []option{
@ -41,13 +44,10 @@ func archInit() {
switch runtime.GOOS { switch runtime.GOOS {
case "freebsd": case "freebsd":
readARM64Registers() readARM64Registers()
case "linux", "netbsd": case "linux", "netbsd", "openbsd":
doinit() doinit()
default: default:
// Most platforms don't seem to allow reading these registers. // Many platforms don't seem to allow reading these registers.
//
// OpenBSD:
// See https://golang.org/issue/31746
setMinimalFeatures() setMinimalFeatures()
} }
} }

@ -0,0 +1,65 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cpu
import (
"syscall"
"unsafe"
)
// Minimal copy of functionality from x/sys/unix so the cpu package can call
// sysctl without depending on x/sys/unix.
const (
// From OpenBSD's sys/sysctl.h.
_CTL_MACHDEP = 7
// From OpenBSD's machine/cpu.h.
_CPU_ID_AA64ISAR0 = 2
_CPU_ID_AA64ISAR1 = 3
)
// Implemented in the runtime package (runtime/sys_openbsd3.go)
func syscall_syscall6(fn, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err syscall.Errno)
//go:linkname syscall_syscall6 syscall.syscall6
func sysctl(mib []uint32, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) {
_, _, errno := syscall_syscall6(libc_sysctl_trampoline_addr, uintptr(unsafe.Pointer(&mib[0])), uintptr(len(mib)), uintptr(unsafe.Pointer(old)), uintptr(unsafe.Pointer(oldlen)), uintptr(unsafe.Pointer(new)), uintptr(newlen))
if errno != 0 {
return errno
}
return nil
}
var libc_sysctl_trampoline_addr uintptr
//go:cgo_import_dynamic libc_sysctl sysctl "libc.so"
func sysctlUint64(mib []uint32) (uint64, bool) {
var out uint64
nout := unsafe.Sizeof(out)
if err := sysctl(mib, (*byte)(unsafe.Pointer(&out)), &nout, nil, 0); err != nil {
return 0, false
}
return out, true
}
func doinit() {
setMinimalFeatures()
// Get ID_AA64ISAR0 and ID_AA64ISAR1 from sysctl.
isar0, ok := sysctlUint64([]uint32{_CTL_MACHDEP, _CPU_ID_AA64ISAR0})
if !ok {
return
}
isar1, ok := sysctlUint64([]uint32{_CTL_MACHDEP, _CPU_ID_AA64ISAR1})
if !ok {
return
}
parseARM64SystemRegisters(isar0, isar1, 0)
Initialized = true
}

@ -0,0 +1,11 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "textflag.h"
TEXT libc_sysctl_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_sysctl(SB)
GLOBL ·libc_sysctl_trampoline_addr(SB), RODATA, $8
DATA ·libc_sysctl_trampoline_addr(SB)/8, $libc_sysctl_trampoline<>(SB)

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build !linux && !netbsd && arm64 //go:build !linux && !netbsd && !openbsd && arm64
// +build !linux,!netbsd,arm64 // +build !linux,!netbsd,!openbsd,arm64
package cpu package cpu

@ -0,0 +1,15 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !aix && !linux && (ppc64 || ppc64le)
// +build !aix
// +build !linux
// +build ppc64 ppc64le
package cpu
func archInit() {
PPC64.IsPOWER8 = true
Initialized = true
}

@ -29,8 +29,6 @@ import (
"bytes" "bytes"
"strings" "strings"
"unsafe" "unsafe"
"golang.org/x/sys/internal/unsafeheader"
) )
// ByteSliceFromString returns a NUL-terminated slice of bytes // ByteSliceFromString returns a NUL-terminated slice of bytes
@ -82,13 +80,7 @@ func BytePtrToString(p *byte) string {
ptr = unsafe.Pointer(uintptr(ptr) + 1) ptr = unsafe.Pointer(uintptr(ptr) + 1)
} }
var s []byte return string(unsafe.Slice(p, n))
h := (*unsafeheader.Slice)(unsafe.Pointer(&s))
h.Data = unsafe.Pointer(p)
h.Len = n
h.Cap = n
return string(s)
} }
// Single-word zero for use when we need a valid pointer to 0 bytes. // Single-word zero for use when we need a valid pointer to 0 bytes.

@ -0,0 +1,31 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (darwin || freebsd || netbsd || openbsd) && gc
// +build darwin freebsd netbsd openbsd
// +build gc
#include "textflag.h"
//
// System call support for ppc64, BSD
//
// Just jump to package syscall's implementation for all these functions.
// The runtime may know about them.
TEXT ·Syscall(SB),NOSPLIT,$0-56
JMP syscall·Syscall(SB)
TEXT ·Syscall6(SB),NOSPLIT,$0-80
JMP syscall·Syscall6(SB)
TEXT ·Syscall9(SB),NOSPLIT,$0-104
JMP syscall·Syscall9(SB)
TEXT ·RawSyscall(SB),NOSPLIT,$0-56
JMP syscall·RawSyscall(SB)
TEXT ·RawSyscall6(SB),NOSPLIT,$0-80
JMP syscall·RawSyscall6(SB)

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos
package unix package unix

@ -4,9 +4,7 @@
package unix package unix
import ( import "unsafe"
"unsafe"
)
// IoctlRetInt performs an ioctl operation specified by req on a device // IoctlRetInt performs an ioctl operation specified by req on a device
// associated with opened file descriptor fd, and returns a non-negative // associated with opened file descriptor fd, and returns a non-negative
@ -217,3 +215,19 @@ func IoctlKCMAttach(fd int, info KCMAttach) error {
func IoctlKCMUnattach(fd int, info KCMUnattach) error { func IoctlKCMUnattach(fd int, info KCMUnattach) error {
return ioctlPtr(fd, SIOCKCMUNATTACH, unsafe.Pointer(&info)) return ioctlPtr(fd, SIOCKCMUNATTACH, unsafe.Pointer(&info))
} }
// IoctlLoopGetStatus64 gets the status of the loop device associated with the
// file descriptor fd using the LOOP_GET_STATUS64 operation.
func IoctlLoopGetStatus64(fd int) (*LoopInfo64, error) {
var value LoopInfo64
if err := ioctlPtr(fd, LOOP_GET_STATUS64, unsafe.Pointer(&value)); err != nil {
return nil, err
}
return &value, nil
}
// IoctlLoopSetStatus64 sets the status of the loop device associated with the
// file descriptor fd using the LOOP_SET_STATUS64 operation.
func IoctlLoopSetStatus64(fd int, value *LoopInfo64) error {
return ioctlPtr(fd, LOOP_SET_STATUS64, unsafe.Pointer(value))
}

@ -52,6 +52,20 @@ func ParseSocketControlMessage(b []byte) ([]SocketControlMessage, error) {
return msgs, nil return msgs, nil
} }
// ParseOneSocketControlMessage parses a single socket control message from b, returning the message header,
// message data (a slice of b), and the remainder of b after that single message.
// When there are no remaining messages, len(remainder) == 0.
func ParseOneSocketControlMessage(b []byte) (hdr Cmsghdr, data []byte, remainder []byte, err error) {
h, dbuf, err := socketControlMessageHeaderAndData(b)
if err != nil {
return Cmsghdr{}, nil, nil, err
}
if i := cmsgAlignOf(int(h.Len)); i < len(b) {
remainder = b[i:]
}
return *h, dbuf, remainder, nil
}
func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) { func socketControlMessageHeaderAndData(b []byte) (*Cmsghdr, []byte, error) {
h := (*Cmsghdr)(unsafe.Pointer(&b[0])) h := (*Cmsghdr)(unsafe.Pointer(&b[0]))
if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) { if h.Len < SizeofCmsghdr || uint64(h.Len) > uint64(len(b)) {

@ -1,27 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package unix
func itoa(val int) string { // do it here rather than with fmt to avoid dependency
if val < 0 {
return "-" + uitoa(uint(-val))
}
return uitoa(uint(val))
}
func uitoa(val uint) string {
var buf [32]byte // big enough for int64
i := len(buf) - 1
for val >= 10 {
buf[i] = byte(val%10 + '0')
i--
val /= 10
}
buf[i] = byte(val + '0')
return string(buf[i:])
}

@ -29,8 +29,6 @@ import (
"bytes" "bytes"
"strings" "strings"
"unsafe" "unsafe"
"golang.org/x/sys/internal/unsafeheader"
) )
// ByteSliceFromString returns a NUL-terminated slice of bytes // ByteSliceFromString returns a NUL-terminated slice of bytes
@ -82,13 +80,7 @@ func BytePtrToString(p *byte) string {
ptr = unsafe.Pointer(uintptr(ptr) + 1) ptr = unsafe.Pointer(uintptr(ptr) + 1)
} }
var s []byte return string(unsafe.Slice(p, n))
h := (*unsafeheader.Slice)(unsafe.Pointer(&s))
h.Data = unsafe.Pointer(p)
h.Len = n
h.Cap = n
return string(s)
} }
// Single-word zero for use when we need a valid pointer to 0 bytes. // Single-word zero for use when we need a valid pointer to 0 bytes.

@ -218,13 +218,62 @@ func Accept(fd int) (nfd int, sa Sockaddr, err error) {
} }
func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) {
// Recvmsg not implemented on AIX var msg Msghdr
return -1, -1, -1, ENOSYS msg.Name = (*byte)(unsafe.Pointer(rsa))
msg.Namelen = uint32(SizeofSockaddrAny)
var dummy byte
if len(oob) > 0 {
// receive at least one normal byte
if emptyIovecs(iov) {
var iova [1]Iovec
iova[0].Base = &dummy
iova[0].SetLen(1)
iov = iova[:]
}
msg.Control = (*byte)(unsafe.Pointer(&oob[0]))
msg.SetControllen(len(oob))
}
if len(iov) > 0 {
msg.Iov = &iov[0]
msg.SetIovlen(len(iov))
}
if n, err = recvmsg(fd, &msg, flags); n == -1 {
return
}
oobn = int(msg.Controllen)
recvflags = int(msg.Flags)
return
} }
func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) {
// SendmsgN not implemented on AIX var msg Msghdr
return -1, ENOSYS msg.Name = (*byte)(unsafe.Pointer(ptr))
msg.Namelen = uint32(salen)
var dummy byte
var empty bool
if len(oob) > 0 {
// send at least one normal byte
empty = emptyIovecs(iov)
if empty {
var iova [1]Iovec
iova[0].Base = &dummy
iova[0].SetLen(1)
iov = iova[:]
}
msg.Control = (*byte)(unsafe.Pointer(&oob[0]))
msg.SetControllen(len(oob))
}
if len(iov) > 0 {
msg.Iov = &iov[0]
msg.SetIovlen(len(iov))
}
if n, err = sendmsg(fd, &msg, flags); err != nil {
return 0, err
}
if len(oob) > 0 && empty {
n = 0
}
return n, nil
} }
func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) { func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) {

@ -363,7 +363,7 @@ func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Sockle
var empty bool var empty bool
if len(oob) > 0 { if len(oob) > 0 {
// send at least one normal byte // send at least one normal byte
empty := emptyIovecs(iov) empty = emptyIovecs(iov)
if empty { if empty {
var iova [1]Iovec var iova [1]Iovec
iova[0].Base = &dummy iova[0].Base = &dummy

@ -1,32 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && go1.12 && !go1.13
// +build darwin,go1.12,!go1.13
package unix
import (
"unsafe"
)
const _SYS_GETDIRENTRIES64 = 344
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// To implement this using libSystem we'd need syscall_syscallPtr for
// fdopendir. However, syscallPtr was only added in Go 1.13, so we fall
// back to raw syscalls for this func on Go 1.12.
var p unsafe.Pointer
if len(buf) > 0 {
p = unsafe.Pointer(&buf[0])
} else {
p = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(_SYS_GETDIRENTRIES64, uintptr(fd), uintptr(p), uintptr(len(buf)), uintptr(unsafe.Pointer(basep)), 0, 0)
n = int(r0)
if e1 != 0 {
return n, errnoErr(e1)
}
return n, nil
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save