You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

276 lines
8.8 KiB

  1. // Copyright 2015 go-swagger maintainers
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package security
  15. import (
  16. "context"
  17. "net/http"
  18. "strings"
  19. "github.com/go-openapi/errors"
  20. "github.com/go-openapi/runtime"
  21. )
  // Token locations understood by APIKeyAuth and APIKeyAuthCtx for their "in"
// argument (compared after lower-casing).
const (
	// header selects an HTTP request header as the token source.
	header = "header"
	// query selects a URL query-string parameter as the token source.
	query = "query"
)
  26. // HttpAuthenticator is a function that authenticates a HTTP request
  27. func HttpAuthenticator(handler func(*http.Request) (bool, interface{}, error)) runtime.Authenticator {
  28. return runtime.AuthenticatorFunc(func(params interface{}) (bool, interface{}, error) {
  29. if request, ok := params.(*http.Request); ok {
  30. return handler(request)
  31. }
  32. if scoped, ok := params.(*ScopedAuthRequest); ok {
  33. return handler(scoped.Request)
  34. }
  35. return false, nil, nil
  36. })
  37. }
  38. // ScopedAuthenticator is a function that authenticates a HTTP request against a list of valid scopes
  39. func ScopedAuthenticator(handler func(*ScopedAuthRequest) (bool, interface{}, error)) runtime.Authenticator {
  40. return runtime.AuthenticatorFunc(func(params interface{}) (bool, interface{}, error) {
  41. if request, ok := params.(*ScopedAuthRequest); ok {
  42. return handler(request)
  43. }
  44. return false, nil, nil
  45. })
  46. }
  // Credential-checking callbacks supplied to the authenticator constructors in
// this package. Each returns an arbitrary result value, which the authenticator
// hands back to its caller, or an error. The *Ctx variants also receive the
// request's context and return the context the authenticator installs on the
// request afterwards.
type (
	// UserPassAuthentication validates a username/password pair.
	UserPassAuthentication func(string, string) (interface{}, error)

	// UserPassAuthenticationCtx validates a username/password pair with access to a context.Context.
	UserPassAuthenticationCtx func(context.Context, string, string) (context.Context, interface{}, error)

	// TokenAuthentication validates an opaque token such as an API key.
	TokenAuthentication func(string) (interface{}, error)

	// TokenAuthenticationCtx validates an opaque token with access to a context.Context.
	TokenAuthenticationCtx func(context.Context, string) (context.Context, interface{}, error)

	// ScopedTokenAuthentication validates a token against the scopes an operation requires.
	ScopedTokenAuthentication func(string, []string) (interface{}, error)

	// ScopedTokenAuthenticationCtx validates a token against required scopes with access to a context.Context.
	ScopedTokenAuthenticationCtx func(context.Context, string, []string) (context.Context, interface{}, error)
)
  // DefaultRealmName is the realm that BasicAuth and BasicAuthCtx always use,
// and the realm that BasicAuthRealm and BasicAuthRealmCtx fall back to when
// they are given an empty realm. The realm is recorded in the request context
// when basic authentication fails. It is read when an authenticator is built,
// so changing it later does not affect authenticators created earlier.
var DefaultRealmName = "API"
  // secCtxKey is an unexported key type for values this package stores in a
// request context. A package-private type keeps the keys from colliding with
// context keys defined by other packages.
type secCtxKey uint8

const (
	// failedBasicAuth holds the realm name recorded by the basic auth
	// authenticators when credentials are missing or rejected.
	failedBasicAuth secCtxKey = iota
	// oauth2SchemeName holds the security scheme name recorded by the bearer
	// authenticators once a token has been found.
	oauth2SchemeName
)

// FailedBasicAuth returns the realm name recorded on r's context by a basic
// auth authenticator that did not authenticate the request. It returns "" when
// nothing was recorded.
func FailedBasicAuth(r *http.Request) string {
	return FailedBasicAuthCtx(r.Context())
}

// FailedBasicAuthCtx is like FailedBasicAuth but reads from ctx directly.
// It returns "" when no realm was recorded or the stored value is not a string.
func FailedBasicAuthCtx(ctx context.Context) string {
	realm, _ := ctx.Value(failedBasicAuth).(string) // "" when absent or not a string
	return realm
}

// OAuth2SchemeName returns the security scheme name recorded on r's context by
// a bearer authenticator that found a token. It returns "" when nothing was
// recorded.
func OAuth2SchemeName(r *http.Request) string {
	return OAuth2SchemeNameCtx(r.Context())
}

// OAuth2SchemeNameCtx is like OAuth2SchemeName but reads from ctx directly.
// It returns "" when no scheme name was recorded or the stored value is not a
// string.
func OAuth2SchemeNameCtx(ctx context.Context) string {
	scheme, _ := ctx.Value(oauth2SchemeName).(string) // "" when absent or not a string
	return scheme
}
  85. // BasicAuth creates a basic auth authenticator with the provided authentication function
  86. func BasicAuth(authenticate UserPassAuthentication) runtime.Authenticator {
  87. return BasicAuthRealm(DefaultRealmName, authenticate)
  88. }
  89. // BasicAuthRealm creates a basic auth authenticator with the provided authentication function and realm name
  90. func BasicAuthRealm(realm string, authenticate UserPassAuthentication) runtime.Authenticator {
  91. if realm == "" {
  92. realm = DefaultRealmName
  93. }
  94. return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
  95. if usr, pass, ok := r.BasicAuth(); ok {
  96. p, err := authenticate(usr, pass)
  97. if err != nil {
  98. *r = *r.WithContext(context.WithValue(r.Context(), failedBasicAuth, realm))
  99. }
  100. return true, p, err
  101. }
  102. *r = *r.WithContext(context.WithValue(r.Context(), failedBasicAuth, realm))
  103. return false, nil, nil
  104. })
  105. }
  106. // BasicAuthCtx creates a basic auth authenticator with the provided authentication function with support for context.Context
  107. func BasicAuthCtx(authenticate UserPassAuthenticationCtx) runtime.Authenticator {
  108. return BasicAuthRealmCtx(DefaultRealmName, authenticate)
  109. }
  110. // BasicAuthRealmCtx creates a basic auth authenticator with the provided authentication function and realm name with support for context.Context
  111. func BasicAuthRealmCtx(realm string, authenticate UserPassAuthenticationCtx) runtime.Authenticator {
  112. if realm == "" {
  113. realm = DefaultRealmName
  114. }
  115. return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
  116. if usr, pass, ok := r.BasicAuth(); ok {
  117. ctx, p, err := authenticate(r.Context(), usr, pass)
  118. if err != nil {
  119. ctx = context.WithValue(ctx, failedBasicAuth, realm)
  120. }
  121. *r = *r.WithContext(ctx)
  122. return true, p, err
  123. }
  124. *r = *r.WithContext(context.WithValue(r.Context(), failedBasicAuth, realm))
  125. return false, nil, nil
  126. })
  127. }
  128. // APIKeyAuth creates an authenticator that uses a token for authorization.
  129. // This token can be obtained from either a header or a query string
  130. func APIKeyAuth(name, in string, authenticate TokenAuthentication) runtime.Authenticator {
  131. inl := strings.ToLower(in)
  132. if inl != query && inl != header {
  133. // panic because this is most likely a typo
  134. panic(errors.New(500, "api key auth: in value needs to be either \"query\" or \"header\"."))
  135. }
  136. var getToken func(*http.Request) string
  137. switch inl {
  138. case header:
  139. getToken = func(r *http.Request) string { return r.Header.Get(name) }
  140. case query:
  141. getToken = func(r *http.Request) string { return r.URL.Query().Get(name) }
  142. }
  143. return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
  144. token := getToken(r)
  145. if token == "" {
  146. return false, nil, nil
  147. }
  148. p, err := authenticate(token)
  149. return true, p, err
  150. })
  151. }
  152. // APIKeyAuthCtx creates an authenticator that uses a token for authorization with support for context.Context.
  153. // This token can be obtained from either a header or a query string
  154. func APIKeyAuthCtx(name, in string, authenticate TokenAuthenticationCtx) runtime.Authenticator {
  155. inl := strings.ToLower(in)
  156. if inl != query && inl != header {
  157. // panic because this is most likely a typo
  158. panic(errors.New(500, "api key auth: in value needs to be either \"query\" or \"header\"."))
  159. }
  160. var getToken func(*http.Request) string
  161. switch inl {
  162. case header:
  163. getToken = func(r *http.Request) string { return r.Header.Get(name) }
  164. case query:
  165. getToken = func(r *http.Request) string { return r.URL.Query().Get(name) }
  166. }
  167. return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
  168. token := getToken(r)
  169. if token == "" {
  170. return false, nil, nil
  171. }
  172. ctx, p, err := authenticate(r.Context(), token)
  173. *r = *r.WithContext(ctx)
  174. return true, p, err
  175. })
  176. }
  // ScopedAuthRequest bundles an HTTP request with the scopes required by the
// operation it targets. It is the parameter type that ScopedAuthenticator
// handles, and HttpAuthenticator also accepts it.
type ScopedAuthRequest struct {
	// Request is the HTTP request being authenticated.
	Request *http.Request
	// RequiredScopes lists the scopes the operation requires.
	RequiredScopes []string
}
  182. // BearerAuth for use with oauth2 flows
  183. func BearerAuth(name string, authenticate ScopedTokenAuthentication) runtime.Authenticator {
  184. const prefix = "Bearer "
  185. return ScopedAuthenticator(func(r *ScopedAuthRequest) (bool, interface{}, error) {
  186. var token string
  187. hdr := r.Request.Header.Get("Authorization")
  188. if strings.HasPrefix(hdr, prefix) {
  189. token = strings.TrimPrefix(hdr, prefix)
  190. }
  191. if token == "" {
  192. qs := r.Request.URL.Query()
  193. token = qs.Get("access_token")
  194. }
  195. //#nosec
  196. ct, _, _ := runtime.ContentType(r.Request.Header)
  197. if token == "" && (ct == "application/x-www-form-urlencoded" || ct == "multipart/form-data") {
  198. token = r.Request.FormValue("access_token")
  199. }
  200. if token == "" {
  201. return false, nil, nil
  202. }
  203. rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
  204. *r.Request = *r.Request.WithContext(rctx)
  205. p, err := authenticate(token, r.RequiredScopes)
  206. return true, p, err
  207. })
  208. }
  209. // BearerAuthCtx for use with oauth2 flows with support for context.Context.
  210. func BearerAuthCtx(name string, authenticate ScopedTokenAuthenticationCtx) runtime.Authenticator {
  211. const prefix = "Bearer "
  212. return ScopedAuthenticator(func(r *ScopedAuthRequest) (bool, interface{}, error) {
  213. var token string
  214. hdr := r.Request.Header.Get("Authorization")
  215. if strings.HasPrefix(hdr, prefix) {
  216. token = strings.TrimPrefix(hdr, prefix)
  217. }
  218. if token == "" {
  219. qs := r.Request.URL.Query()
  220. token = qs.Get("access_token")
  221. }
  222. //#nosec
  223. ct, _, _ := runtime.ContentType(r.Request.Header)
  224. if token == "" && (ct == "application/x-www-form-urlencoded" || ct == "multipart/form-data") {
  225. token = r.Request.FormValue("access_token")
  226. }
  227. if token == "" {
  228. return false, nil, nil
  229. }
  230. rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
  231. ctx, p, err := authenticate(rctx, token, r.RequiredScopes)
  232. *r.Request = *r.Request.WithContext(ctx)
  233. return true, p, err
  234. })
  235. }