You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

473 lines
15 KiB

10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
10 years ago
  1. // Copyright 2013 The Martini Contrib Authors. All rights reserved.
  2. // Copyright 2014 The Gogs Authors. All rights reserved.
  3. // Use of this source code is governed by a MIT-style
  4. // license that can be found in the LICENSE file.
  5. package binding
  6. import (
  7. "encoding/json"
  8. "fmt"
  9. "io"
  10. "net/http"
  11. "reflect"
  12. "regexp"
  13. "strconv"
  14. "strings"
  15. "unicode/utf8"
  16. "github.com/go-martini/martini"
  17. )
  18. /*
  19. To the land of Middle-ware Earth:
  20. One func to rule them all,
  21. One func to find them,
  22. One func to bring them all,
  23. And in this package BIND them.
  24. */
  25. // Bind accepts a copy of an empty struct and populates it with
  26. // values from the request (if deserialization is successful). It
  27. // wraps up the functionality of the Form and Json middleware
  28. // according to the Content-Type of the request, and it guesses
  29. // if no Content-Type is specified. Bind invokes the ErrorHandler
  30. // middleware to bail out if errors occurred. If you want to perform
  31. // your own error handling, use Form or Json middleware directly.
  32. // An interface pointer can be added as a second argument in order
  33. // to map the struct to a specific interface.
  34. func Bind(obj interface{}, ifacePtr ...interface{}) martini.Handler {
  35. return func(context martini.Context, req *http.Request) {
  36. contentType := req.Header.Get("Content-Type")
  37. if strings.Contains(contentType, "form-urlencoded") {
  38. context.Invoke(Form(obj, ifacePtr...))
  39. } else if strings.Contains(contentType, "multipart/form-data") {
  40. context.Invoke(MultipartForm(obj, ifacePtr...))
  41. } else if strings.Contains(contentType, "json") {
  42. context.Invoke(Json(obj, ifacePtr...))
  43. } else {
  44. context.Invoke(Json(obj, ifacePtr...))
  45. if getErrors(context).Count() > 0 {
  46. context.Invoke(Form(obj, ifacePtr...))
  47. }
  48. }
  49. context.Invoke(ErrorHandler)
  50. }
  51. }
  52. // BindIgnErr will do the exactly same thing as Bind but without any
  53. // error handling, which user has freedom to deal with them.
  54. // This allows user take advantages of validation.
  55. func BindIgnErr(obj interface{}, ifacePtr ...interface{}) martini.Handler {
  56. return func(context martini.Context, req *http.Request) {
  57. contentType := req.Header.Get("Content-Type")
  58. if strings.Contains(contentType, "form-urlencoded") {
  59. context.Invoke(Form(obj, ifacePtr...))
  60. } else if strings.Contains(contentType, "multipart/form-data") {
  61. context.Invoke(MultipartForm(obj, ifacePtr...))
  62. } else if strings.Contains(contentType, "json") {
  63. context.Invoke(Json(obj, ifacePtr...))
  64. } else {
  65. context.Invoke(Json(obj, ifacePtr...))
  66. if getErrors(context).Count() > 0 {
  67. context.Invoke(Form(obj, ifacePtr...))
  68. }
  69. }
  70. }
  71. }
  72. // Form is middleware to deserialize form-urlencoded data from the request.
  73. // It gets data from the form-urlencoded body, if present, or from the
  74. // query string. It uses the http.Request.ParseForm() method
  75. // to perform deserialization, then reflection is used to map each field
  76. // into the struct with the proper type. Structs with primitive slice types
  77. // (bool, float, int, string) can support deserialization of repeated form
  78. // keys, for example: key=val1&key=val2&key=val3
  79. // An interface pointer can be added as a second argument in order
  80. // to map the struct to a specific interface.
  81. func Form(formStruct interface{}, ifacePtr ...interface{}) martini.Handler {
  82. return func(context martini.Context, req *http.Request) {
  83. ensureNotPointer(formStruct)
  84. formStruct := reflect.New(reflect.TypeOf(formStruct))
  85. errors := newErrors()
  86. parseErr := req.ParseForm()
  87. // Format validation of the request body or the URL would add considerable overhead,
  88. // and ParseForm does not complain when URL encoding is off.
  89. // Because an empty request body or url can also mean absence of all needed values,
  90. // it is not in all cases a bad request, so let's return 422.
  91. if parseErr != nil {
  92. errors.Overall[BindingDeserializationError] = parseErr.Error()
  93. }
  94. mapForm(formStruct, req.Form, errors)
  95. validateAndMap(formStruct, context, errors, ifacePtr...)
  96. }
  97. }
  98. func MultipartForm(formStruct interface{}, ifacePtr ...interface{}) martini.Handler {
  99. return func(context martini.Context, req *http.Request) {
  100. ensureNotPointer(formStruct)
  101. formStruct := reflect.New(reflect.TypeOf(formStruct))
  102. errors := newErrors()
  103. // Workaround for multipart forms returning nil instead of an error
  104. // when content is not multipart
  105. // https://code.google.com/p/go/issues/detail?id=6334
  106. multipartReader, err := req.MultipartReader()
  107. if err != nil {
  108. errors.Overall[BindingDeserializationError] = err.Error()
  109. } else {
  110. form, parseErr := multipartReader.ReadForm(MaxMemory)
  111. if parseErr != nil {
  112. errors.Overall[BindingDeserializationError] = parseErr.Error()
  113. }
  114. req.MultipartForm = form
  115. }
  116. mapForm(formStruct, req.MultipartForm.Value, errors)
  117. validateAndMap(formStruct, context, errors, ifacePtr...)
  118. }
  119. }
  120. // Json is middleware to deserialize a JSON payload from the request
  121. // into the struct that is passed in. The resulting struct is then
  122. // validated, but no error handling is actually performed here.
  123. // An interface pointer can be added as a second argument in order
  124. // to map the struct to a specific interface.
  125. func Json(jsonStruct interface{}, ifacePtr ...interface{}) martini.Handler {
  126. return func(context martini.Context, req *http.Request) {
  127. ensureNotPointer(jsonStruct)
  128. jsonStruct := reflect.New(reflect.TypeOf(jsonStruct))
  129. errors := newErrors()
  130. if req.Body != nil {
  131. defer req.Body.Close()
  132. }
  133. if err := json.NewDecoder(req.Body).Decode(jsonStruct.Interface()); err != nil && err != io.EOF {
  134. errors.Overall[BindingDeserializationError] = err.Error()
  135. }
  136. validateAndMap(jsonStruct, context, errors, ifacePtr...)
  137. }
  138. }
  139. // Validate is middleware to enforce required fields. If the struct
  140. // passed in is a Validator, then the user-defined Validate method
  141. // is executed, and its errors are mapped to the context. This middleware
  142. // performs no error handling: it merely detects them and maps them.
  143. func Validate(obj interface{}) martini.Handler {
  144. return func(context martini.Context, req *http.Request) {
  145. errors := newErrors()
  146. validateStruct(errors, obj)
  147. if validator, ok := obj.(Validator); ok {
  148. validator.Validate(errors, req, context)
  149. }
  150. context.Map(*errors)
  151. }
  152. }
  153. var (
  154. alphaDashPattern = regexp.MustCompile("[^\\d\\w-_]")
  155. alphaDashDotPattern = regexp.MustCompile("[^\\d\\w-_\\.]")
  156. emailPattern = regexp.MustCompile("[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?")
  157. urlPattern = regexp.MustCompile(`(http|https):\/\/[\w\-_]+(\.[\w\-_]+)+([\w\-\.,@?^=%&:/~\+#]*[\w\-\@?^=%&/~\+#])?`)
  158. )
  159. func validateStruct(errors *Errors, obj interface{}) {
  160. typ := reflect.TypeOf(obj)
  161. val := reflect.ValueOf(obj)
  162. if typ.Kind() == reflect.Ptr {
  163. typ = typ.Elem()
  164. val = val.Elem()
  165. }
  166. for i := 0; i < typ.NumField(); i++ {
  167. field := typ.Field(i)
  168. // Allow ignored fields in the struct
  169. if field.Tag.Get("form") == "-" {
  170. continue
  171. }
  172. fieldValue := val.Field(i).Interface()
  173. if field.Type.Kind() == reflect.Struct {
  174. validateStruct(errors, fieldValue)
  175. continue
  176. }
  177. zero := reflect.Zero(field.Type).Interface()
  178. // Match rules.
  179. for _, rule := range strings.Split(field.Tag.Get("binding"), ";") {
  180. if len(rule) == 0 {
  181. continue
  182. }
  183. switch {
  184. case rule == "Required":
  185. if reflect.DeepEqual(zero, fieldValue) {
  186. errors.Fields[field.Name] = BindingRequireError
  187. break
  188. }
  189. case rule == "AlphaDash":
  190. if alphaDashPattern.MatchString(fmt.Sprintf("%v", fieldValue)) {
  191. errors.Fields[field.Name] = BindingAlphaDashError
  192. break
  193. }
  194. case rule == "AlphaDashDot":
  195. if alphaDashDotPattern.MatchString(fmt.Sprintf("%v", fieldValue)) {
  196. errors.Fields[field.Name] = BindingAlphaDashDotError
  197. break
  198. }
  199. case strings.HasPrefix(rule, "MinSize("):
  200. min, err := strconv.Atoi(rule[8 : len(rule)-1])
  201. if err != nil {
  202. errors.Overall["MinSize"] = err.Error()
  203. break
  204. }
  205. if str, ok := fieldValue.(string); ok && utf8.RuneCountInString(str) < min {
  206. errors.Fields[field.Name] = BindingMinSizeError
  207. break
  208. }
  209. v := reflect.ValueOf(fieldValue)
  210. if v.Kind() == reflect.Slice && v.Len() < min {
  211. errors.Fields[field.Name] = BindingMinSizeError
  212. break
  213. }
  214. case strings.HasPrefix(rule, "MaxSize("):
  215. max, err := strconv.Atoi(rule[8 : len(rule)-1])
  216. if err != nil {
  217. errors.Overall["MaxSize"] = err.Error()
  218. break
  219. }
  220. if str, ok := fieldValue.(string); ok && utf8.RuneCountInString(str) > max {
  221. errors.Fields[field.Name] = BindingMaxSizeError
  222. break
  223. }
  224. v := reflect.ValueOf(fieldValue)
  225. if v.Kind() == reflect.Slice && v.Len() > max {
  226. errors.Fields[field.Name] = BindingMinSizeError
  227. break
  228. }
  229. case rule == "Email":
  230. if !emailPattern.MatchString(fmt.Sprintf("%v", fieldValue)) {
  231. errors.Fields[field.Name] = BindingEmailError
  232. break
  233. }
  234. case rule == "Url":
  235. str := fmt.Sprintf("%v", fieldValue)
  236. if len(str) == 0 {
  237. continue
  238. } else if !urlPattern.MatchString(str) {
  239. errors.Fields[field.Name] = BindingUrlError
  240. break
  241. }
  242. }
  243. }
  244. }
  245. }
  246. func mapForm(formStruct reflect.Value, form map[string][]string, errors *Errors) {
  247. typ := formStruct.Elem().Type()
  248. for i := 0; i < typ.NumField(); i++ {
  249. typeField := typ.Field(i)
  250. if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" {
  251. structField := formStruct.Elem().Field(i)
  252. if !structField.CanSet() {
  253. continue
  254. }
  255. inputValue, exists := form[inputFieldName]
  256. if !exists {
  257. continue
  258. }
  259. numElems := len(inputValue)
  260. if structField.Kind() == reflect.Slice && numElems > 0 {
  261. sliceOf := structField.Type().Elem().Kind()
  262. slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
  263. for i := 0; i < numElems; i++ {
  264. setWithProperType(sliceOf, inputValue[i], slice.Index(i), inputFieldName, errors)
  265. }
  266. formStruct.Elem().Field(i).Set(slice)
  267. } else {
  268. setWithProperType(typeField.Type.Kind(), inputValue[0], structField, inputFieldName, errors)
  269. }
  270. }
  271. }
  272. }
  273. // ErrorHandler simply counts the number of errors in the
  274. // context and, if more than 0, writes a 400 Bad Request
  275. // response and a JSON payload describing the errors with
  276. // the "Content-Type" set to "application/json".
  277. // Middleware remaining on the stack will not even see the request
  278. // if, by this point, there are any errors.
  279. // This is a "default" handler, of sorts, and you are
  280. // welcome to use your own instead. The Bind middleware
  281. // invokes this automatically for convenience.
  282. func ErrorHandler(errs Errors, resp http.ResponseWriter) {
  283. if errs.Count() > 0 {
  284. resp.Header().Set("Content-Type", "application/json; charset=utf-8")
  285. if _, ok := errs.Overall[BindingDeserializationError]; ok {
  286. resp.WriteHeader(http.StatusBadRequest)
  287. } else {
  288. resp.WriteHeader(422)
  289. }
  290. errOutput, _ := json.Marshal(errs)
  291. resp.Write(errOutput)
  292. return
  293. }
  294. }
  295. // This sets the value in a struct of an indeterminate type to the
  296. // matching value from the request (via Form middleware) in the
  297. // same type, so that not all deserialized values have to be strings.
  298. // Supported types are string, int, float, and bool.
  299. func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value, nameInTag string, errors *Errors) {
  300. switch valueKind {
  301. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  302. if val == "" {
  303. val = "0"
  304. }
  305. intVal, err := strconv.ParseInt(val, 10, 64)
  306. if err != nil {
  307. errors.Fields[nameInTag] = BindingIntegerTypeError
  308. } else {
  309. structField.SetInt(intVal)
  310. }
  311. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  312. if val == "" {
  313. val = "0"
  314. }
  315. uintVal, err := strconv.ParseUint(val, 10, 64)
  316. if err != nil {
  317. errors.Fields[nameInTag] = BindingIntegerTypeError
  318. } else {
  319. structField.SetUint(uintVal)
  320. }
  321. case reflect.Bool:
  322. structField.SetBool(val == "on")
  323. case reflect.Float32:
  324. if val == "" {
  325. val = "0.0"
  326. }
  327. floatVal, err := strconv.ParseFloat(val, 32)
  328. if err != nil {
  329. errors.Fields[nameInTag] = BindingFloatTypeError
  330. } else {
  331. structField.SetFloat(floatVal)
  332. }
  333. case reflect.Float64:
  334. if val == "" {
  335. val = "0.0"
  336. }
  337. floatVal, err := strconv.ParseFloat(val, 64)
  338. if err != nil {
  339. errors.Fields[nameInTag] = BindingFloatTypeError
  340. } else {
  341. structField.SetFloat(floatVal)
  342. }
  343. case reflect.String:
  344. structField.SetString(val)
  345. }
  346. }
  347. // Don't pass in pointers to bind to. Can lead to bugs. See:
  348. // https://github.com/codegangsta/martini-contrib/issues/40
  349. // https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659
  350. func ensureNotPointer(obj interface{}) {
  351. if reflect.TypeOf(obj).Kind() == reflect.Ptr {
  352. panic("Pointers are not accepted as binding models")
  353. }
  354. }
  355. // Performs validation and combines errors from validation
  356. // with errors from deserialization, then maps both the
  357. // resulting struct and the errors to the context.
  358. func validateAndMap(obj reflect.Value, context martini.Context, errors *Errors, ifacePtr ...interface{}) {
  359. context.Invoke(Validate(obj.Interface()))
  360. errors.Combine(getErrors(context))
  361. context.Map(*errors)
  362. context.Map(obj.Elem().Interface())
  363. if len(ifacePtr) > 0 {
  364. context.MapTo(obj.Elem().Interface(), ifacePtr[0])
  365. }
  366. }
  367. func newErrors() *Errors {
  368. return &Errors{make(map[string]string), make(map[string]string)}
  369. }
  370. func getErrors(context martini.Context) Errors {
  371. return context.Get(reflect.TypeOf(Errors{})).Interface().(Errors)
  372. }
  373. type (
  374. // Implement the Validator interface to define your own input
  375. // validation before the request even gets to your application.
  376. // The Validate method will be executed during the validation phase.
  377. Validator interface {
  378. Validate(*Errors, *http.Request, martini.Context)
  379. }
  380. )
  381. var (
  382. // Maximum amount of memory to use when parsing a multipart form.
  383. // Set this to whatever value you prefer; default is 10 MB.
  384. MaxMemory = int64(1024 * 1024 * 10)
  385. )
  386. // Errors represents the contract of the response body when the
  387. // binding step fails before getting to the application.
  388. type Errors struct {
  389. Overall map[string]string `json:"overall"`
  390. Fields map[string]string `json:"fields"`
  391. }
  392. // Total errors is the sum of errors with the request overall
  393. // and errors on individual fields.
  394. func (err Errors) Count() int {
  395. return len(err.Overall) + len(err.Fields)
  396. }
  397. func (this *Errors) Combine(other Errors) {
  398. for key, val := range other.Fields {
  399. if _, exists := this.Fields[key]; !exists {
  400. this.Fields[key] = val
  401. }
  402. }
  403. for key, val := range other.Overall {
  404. if _, exists := this.Overall[key]; !exists {
  405. this.Overall[key] = val
  406. }
  407. }
  408. }
  409. const (
  410. BindingRequireError string = "Required"
  411. BindingAlphaDashError string = "AlphaDash"
  412. BindingAlphaDashDotError string = "AlphaDashDot"
  413. BindingMinSizeError string = "MinSize"
  414. BindingMaxSizeError string = "MaxSize"
  415. BindingEmailError string = "Email"
  416. BindingUrlError string = "Url"
  417. BindingDeserializationError string = "DeserializationError"
  418. BindingIntegerTypeError string = "IntegerTypeError"
  419. BindingBooleanTypeError string = "BooleanTypeError"
  420. BindingFloatTypeError string = "FloatTypeError"
  421. )