You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

199 lines
5.0 KiB

  1. // Copyright 2019 The Gitea Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package session
  5. import (
  6. "encoding/json"
  7. "fmt"
  8. "sync"
  9. "gitea.com/macaron/session"
  10. couchbase "gitea.com/macaron/session/couchbase"
  11. memcache "gitea.com/macaron/session/memcache"
  12. mysql "gitea.com/macaron/session/mysql"
  13. nodb "gitea.com/macaron/session/nodb"
  14. postgres "gitea.com/macaron/session/postgres"
  15. redis "gitea.com/macaron/session/redis"
  16. )
  17. // VirtualSessionProvider represents a shadowed session provider implementation.
  18. type VirtualSessionProvider struct {
  19. lock sync.RWMutex
  20. provider session.Provider
  21. }
  22. // Init initializes the cookie session provider with given root path.
  23. func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error {
  24. var opts session.Options
  25. if err := json.Unmarshal([]byte(config), &opts); err != nil {
  26. return err
  27. }
  28. // Note that these options are unprepared so we can't just use NewManager here.
  29. // Nor can we access the provider map in session.
  30. // So we will just have to do this by hand.
  31. // This is only slightly more wrong than modules/setting/session.go:23
  32. switch opts.Provider {
  33. case "memory":
  34. o.provider = &session.MemProvider{}
  35. case "file":
  36. o.provider = &session.FileProvider{}
  37. case "redis":
  38. o.provider = &redis.RedisProvider{}
  39. case "mysql":
  40. o.provider = &mysql.MysqlProvider{}
  41. case "postgres":
  42. o.provider = &postgres.PostgresProvider{}
  43. case "couchbase":
  44. o.provider = &couchbase.CouchbaseProvider{}
  45. case "memcache":
  46. o.provider = &memcache.MemcacheProvider{}
  47. case "nodb":
  48. o.provider = &nodb.NodbProvider{}
  49. default:
  50. return fmt.Errorf("VirtualSessionProvider: Unknown Provider: %s", opts.Provider)
  51. }
  52. return o.provider.Init(gclifetime, opts.ProviderConfig)
  53. }
  54. // Read returns raw session store by session ID.
  55. func (o *VirtualSessionProvider) Read(sid string) (session.RawStore, error) {
  56. o.lock.RLock()
  57. defer o.lock.RUnlock()
  58. if o.provider.Exist(sid) {
  59. return o.provider.Read(sid)
  60. }
  61. kv := make(map[interface{}]interface{})
  62. kv["_old_uid"] = "0"
  63. return NewVirtualStore(o, sid, kv), nil
  64. }
  65. // Exist returns true if session with given ID exists.
  66. func (o *VirtualSessionProvider) Exist(sid string) bool {
  67. return true
  68. }
  69. // Destroy deletes a session by session ID.
  70. func (o *VirtualSessionProvider) Destroy(sid string) error {
  71. o.lock.Lock()
  72. defer o.lock.Unlock()
  73. return o.provider.Destroy(sid)
  74. }
  75. // Regenerate regenerates a session store from old session ID to new one.
  76. func (o *VirtualSessionProvider) Regenerate(oldsid, sid string) (session.RawStore, error) {
  77. o.lock.Lock()
  78. defer o.lock.Unlock()
  79. return o.provider.Regenerate(oldsid, sid)
  80. }
  81. // Count counts and returns number of sessions.
  82. func (o *VirtualSessionProvider) Count() int {
  83. o.lock.RLock()
  84. defer o.lock.RUnlock()
  85. return o.provider.Count()
  86. }
  87. // GC calls GC to clean expired sessions.
  88. func (o *VirtualSessionProvider) GC() {
  89. o.provider.GC()
  90. }
  91. func init() {
  92. session.Register("VirtualSession", &VirtualSessionProvider{})
  93. }
  94. // VirtualStore represents a virtual session store implementation.
  95. type VirtualStore struct {
  96. p *VirtualSessionProvider
  97. sid string
  98. lock sync.RWMutex
  99. data map[interface{}]interface{}
  100. released bool
  101. }
  102. // NewVirtualStore creates and returns a virtual session store.
  103. func NewVirtualStore(p *VirtualSessionProvider, sid string, kv map[interface{}]interface{}) *VirtualStore {
  104. return &VirtualStore{
  105. p: p,
  106. sid: sid,
  107. data: kv,
  108. }
  109. }
  110. // Set sets value to given key in session.
  111. func (s *VirtualStore) Set(key, val interface{}) error {
  112. s.lock.Lock()
  113. defer s.lock.Unlock()
  114. s.data[key] = val
  115. return nil
  116. }
  117. // Get gets value by given key in session.
  118. func (s *VirtualStore) Get(key interface{}) interface{} {
  119. s.lock.RLock()
  120. defer s.lock.RUnlock()
  121. return s.data[key]
  122. }
  123. // Delete delete a key from session.
  124. func (s *VirtualStore) Delete(key interface{}) error {
  125. s.lock.Lock()
  126. defer s.lock.Unlock()
  127. delete(s.data, key)
  128. return nil
  129. }
  130. // ID returns current session ID.
  131. func (s *VirtualStore) ID() string {
  132. return s.sid
  133. }
  134. // Release releases resource and save data to provider.
  135. func (s *VirtualStore) Release() error {
  136. s.lock.Lock()
  137. defer s.lock.Unlock()
  138. // Now need to lock the provider
  139. s.p.lock.Lock()
  140. defer s.p.lock.Unlock()
  141. if oldUID, ok := s.data["_old_uid"]; (ok && (oldUID != "0" || len(s.data) > 1)) || (!ok && len(s.data) > 0) {
  142. // Now ensure that we don't exist!
  143. realProvider := s.p.provider
  144. if !s.released && realProvider.Exist(s.sid) {
  145. // This is an error!
  146. return fmt.Errorf("new sid '%s' already exists", s.sid)
  147. }
  148. realStore, err := realProvider.Read(s.sid)
  149. if err != nil {
  150. return err
  151. }
  152. if err := realStore.Flush(); err != nil {
  153. return err
  154. }
  155. for key, value := range s.data {
  156. if err := realStore.Set(key, value); err != nil {
  157. return err
  158. }
  159. }
  160. err = realStore.Release()
  161. if err == nil {
  162. s.released = true
  163. }
  164. return err
  165. }
  166. return nil
  167. }
  168. // Flush deletes all session data.
  169. func (s *VirtualStore) Flush() error {
  170. s.lock.Lock()
  171. defer s.lock.Unlock()
  172. s.data = make(map[interface{}]interface{})
  173. return nil
  174. }