You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

589 lines
13 KiB

  1. // Copyright (C) MongoDB, Inc. 2017-present.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License"); you may
  4. // not use this file except in compliance with the License. You may obtain
  5. // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
  6. package bsonrw
  7. import (
  8. "errors"
  9. "fmt"
  10. "io"
  11. "math"
  12. "strconv"
  13. "sync"
  14. "go.mongodb.org/mongo-driver/bson/bsontype"
  15. "go.mongodb.org/mongo-driver/bson/primitive"
  16. "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
  17. )
  18. var _ ValueWriter = (*valueWriter)(nil)
  19. var vwPool = sync.Pool{
  20. New: func() interface{} {
  21. return new(valueWriter)
  22. },
  23. }
  24. // BSONValueWriterPool is a pool for BSON ValueWriters.
  25. type BSONValueWriterPool struct {
  26. pool sync.Pool
  27. }
  28. // NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON.
  29. func NewBSONValueWriterPool() *BSONValueWriterPool {
  30. return &BSONValueWriterPool{
  31. pool: sync.Pool{
  32. New: func() interface{} {
  33. return new(valueWriter)
  34. },
  35. },
  36. }
  37. }
  38. // Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination.
  39. func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter {
  40. vw := bvwp.pool.Get().(*valueWriter)
  41. if writer, ok := w.(*SliceWriter); ok {
  42. vw.reset(*writer)
  43. vw.w = writer
  44. return vw
  45. }
  46. vw.buf = vw.buf[:0]
  47. vw.w = w
  48. return vw
  49. }
  50. // Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing
  51. // happens and ok will be false.
  52. func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
  53. bvw, ok := vw.(*valueWriter)
  54. if !ok {
  55. return false
  56. }
  57. if _, ok := bvw.w.(*SliceWriter); ok {
  58. bvw.buf = nil
  59. }
  60. bvw.w = nil
  61. bvwp.pool.Put(bvw)
  62. return true
  63. }
  var (
      // maxSize is the largest total buffer length writeLength accepts. It is a
      // variable rather than a constant so tests can lower it instead of
      // allocating a buffer of math.MaxInt32 bytes.
      maxSize = math.MaxInt32

      // errNilWriter is returned by NewBSONValueWriter when given a nil writer.
      errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer")
  )
  // errMaxDocumentSizeExceeded is returned by writeLength when the encoded
  // buffer has grown beyond maxSize bytes; size is the offending length.
  type errMaxDocumentSizeExceeded struct {
      size int64
  }

  // Error implements the error interface.
  func (mdse errMaxDocumentSizeExceeded) Error() string {
      const format = "document size (%d) is larger than the max int32"
      return fmt.Sprintf(format, mdse.size)
  }
  // vwMode enumerates value-writer modes; the zero value is unused.
  // NOTE(review): the stack frames in this file are typed with `mode`, not
  // vwMode, and vwMode is not referenced elsewhere in the code shown here.
  type vwMode int

  const (
      _ vwMode = iota
      vwTopLevel
      vwDocument
      vwArray
      vwValue
      vwElement
      vwCodeWithScope
  )

  // String returns a human-readable name for vm, or "UnknownMode" for any
  // value outside the declared constants.
  func (vm vwMode) String() string {
      switch vm {
      case vwTopLevel:
          return "TopLevel"
      case vwDocument:
          return "DocumentMode"
      case vwArray:
          return "ArrayMode"
      case vwValue:
          return "ValueMode"
      case vwElement:
          return "ElementMode"
      case vwCodeWithScope:
          return "CodeWithScopeMode"
      default:
          return "UnknownMode"
      }
  }
  104. type vwState struct {
  105. mode mode
  106. key string
  107. arrkey int
  108. start int32
  109. }
  110. type valueWriter struct {
  111. w io.Writer
  112. buf []byte
  113. stack []vwState
  114. frame int64
  115. }
  116. func (vw *valueWriter) advanceFrame() {
  117. if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack
  118. length := len(vw.stack)
  119. if length+1 >= cap(vw.stack) {
  120. // double it
  121. buf := make([]vwState, 2*cap(vw.stack)+1)
  122. copy(buf, vw.stack)
  123. vw.stack = buf
  124. }
  125. vw.stack = vw.stack[:length+1]
  126. }
  127. vw.frame++
  128. }
  129. func (vw *valueWriter) push(m mode) {
  130. vw.advanceFrame()
  131. // Clean the stack
  132. vw.stack[vw.frame].mode = m
  133. vw.stack[vw.frame].key = ""
  134. vw.stack[vw.frame].arrkey = 0
  135. vw.stack[vw.frame].start = 0
  136. vw.stack[vw.frame].mode = m
  137. switch m {
  138. case mDocument, mArray, mCodeWithScope:
  139. vw.reserveLength()
  140. }
  141. }
  142. func (vw *valueWriter) reserveLength() {
  143. vw.stack[vw.frame].start = int32(len(vw.buf))
  144. vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00)
  145. }
  146. func (vw *valueWriter) pop() {
  147. switch vw.stack[vw.frame].mode {
  148. case mElement, mValue:
  149. vw.frame--
  150. case mDocument, mArray, mCodeWithScope:
  151. vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
  152. }
  153. }
  154. // NewBSONValueWriter creates a ValueWriter that writes BSON to w.
  155. //
  156. // This ValueWriter will only write entire documents to the io.Writer and it
  157. // will buffer the document as it is built.
  158. func NewBSONValueWriter(w io.Writer) (ValueWriter, error) {
  159. if w == nil {
  160. return nil, errNilWriter
  161. }
  162. return newValueWriter(w), nil
  163. }
  164. func newValueWriter(w io.Writer) *valueWriter {
  165. vw := new(valueWriter)
  166. stack := make([]vwState, 1, 5)
  167. stack[0] = vwState{mode: mTopLevel}
  168. vw.w = w
  169. vw.stack = stack
  170. return vw
  171. }
  172. func newValueWriterFromSlice(buf []byte) *valueWriter {
  173. vw := new(valueWriter)
  174. stack := make([]vwState, 1, 5)
  175. stack[0] = vwState{mode: mTopLevel}
  176. vw.stack = stack
  177. vw.buf = buf
  178. return vw
  179. }
  180. func (vw *valueWriter) reset(buf []byte) {
  181. if vw.stack == nil {
  182. vw.stack = make([]vwState, 1, 5)
  183. }
  184. vw.stack = vw.stack[:1]
  185. vw.stack[0] = vwState{mode: mTopLevel}
  186. vw.buf = buf
  187. vw.frame = 0
  188. vw.w = nil
  189. }
  190. func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error {
  191. te := TransitionError{
  192. name: name,
  193. current: vw.stack[vw.frame].mode,
  194. destination: destination,
  195. modes: modes,
  196. action: "write",
  197. }
  198. if vw.frame != 0 {
  199. te.parent = vw.stack[vw.frame-1].mode
  200. }
  201. return te
  202. }
  203. func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
  204. switch vw.stack[vw.frame].mode {
  205. case mElement:
  206. vw.buf = bsoncore.AppendHeader(vw.buf, t, vw.stack[vw.frame].key)
  207. case mValue:
  208. // TODO: Do this with a cache of the first 1000 or so array keys.
  209. vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
  210. default:
  211. modes := []mode{mElement, mValue}
  212. if addmodes != nil {
  213. modes = append(modes, addmodes...)
  214. }
  215. return vw.invalidTransitionError(destination, callerName, modes)
  216. }
  217. return nil
  218. }
  219. func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error {
  220. if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil {
  221. return err
  222. }
  223. vw.buf = append(vw.buf, b...)
  224. vw.pop()
  225. return nil
  226. }
  227. func (vw *valueWriter) WriteArray() (ArrayWriter, error) {
  228. if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil {
  229. return nil, err
  230. }
  231. vw.push(mArray)
  232. return vw, nil
  233. }
  234. func (vw *valueWriter) WriteBinary(b []byte) error {
  235. return vw.WriteBinaryWithSubtype(b, 0x00)
  236. }
  237. func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
  238. if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil {
  239. return err
  240. }
  241. vw.buf = bsoncore.AppendBinary(vw.buf, btype, b)
  242. vw.pop()
  243. return nil
  244. }
  245. func (vw *valueWriter) WriteBoolean(b bool) error {
  246. if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil {
  247. return err
  248. }
  249. vw.buf = bsoncore.AppendBoolean(vw.buf, b)
  250. vw.pop()
  251. return nil
  252. }
  253. func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
  254. if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil {
  255. return nil, err
  256. }
  257. // CodeWithScope is a different than other types because we need an extra
  258. // frame on the stack. In the EndDocument code, we write the document
  259. // length, pop, write the code with scope length, and pop. To simplify the
  260. // pop code, we push a spacer frame that we'll always jump over.
  261. vw.push(mCodeWithScope)
  262. vw.buf = bsoncore.AppendString(vw.buf, code)
  263. vw.push(mSpacer)
  264. vw.push(mDocument)
  265. return vw, nil
  266. }
  267. func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
  268. if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil {
  269. return err
  270. }
  271. vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid)
  272. vw.pop()
  273. return nil
  274. }
  275. func (vw *valueWriter) WriteDateTime(dt int64) error {
  276. if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil {
  277. return err
  278. }
  279. vw.buf = bsoncore.AppendDateTime(vw.buf, dt)
  280. vw.pop()
  281. return nil
  282. }
  283. func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error {
  284. if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil {
  285. return err
  286. }
  287. vw.buf = bsoncore.AppendDecimal128(vw.buf, d128)
  288. vw.pop()
  289. return nil
  290. }
  291. func (vw *valueWriter) WriteDouble(f float64) error {
  292. if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil {
  293. return err
  294. }
  295. vw.buf = bsoncore.AppendDouble(vw.buf, f)
  296. vw.pop()
  297. return nil
  298. }
  299. func (vw *valueWriter) WriteInt32(i32 int32) error {
  300. if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil {
  301. return err
  302. }
  303. vw.buf = bsoncore.AppendInt32(vw.buf, i32)
  304. vw.pop()
  305. return nil
  306. }
  307. func (vw *valueWriter) WriteInt64(i64 int64) error {
  308. if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil {
  309. return err
  310. }
  311. vw.buf = bsoncore.AppendInt64(vw.buf, i64)
  312. vw.pop()
  313. return nil
  314. }
  315. func (vw *valueWriter) WriteJavascript(code string) error {
  316. if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil {
  317. return err
  318. }
  319. vw.buf = bsoncore.AppendJavaScript(vw.buf, code)
  320. vw.pop()
  321. return nil
  322. }
  323. func (vw *valueWriter) WriteMaxKey() error {
  324. if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil {
  325. return err
  326. }
  327. vw.pop()
  328. return nil
  329. }
  330. func (vw *valueWriter) WriteMinKey() error {
  331. if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil {
  332. return err
  333. }
  334. vw.pop()
  335. return nil
  336. }
  337. func (vw *valueWriter) WriteNull() error {
  338. if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil {
  339. return err
  340. }
  341. vw.pop()
  342. return nil
  343. }
  344. func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error {
  345. if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil {
  346. return err
  347. }
  348. vw.buf = bsoncore.AppendObjectID(vw.buf, oid)
  349. vw.pop()
  350. return nil
  351. }
  352. func (vw *valueWriter) WriteRegex(pattern string, options string) error {
  353. if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil {
  354. return err
  355. }
  356. vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options))
  357. vw.pop()
  358. return nil
  359. }
  360. func (vw *valueWriter) WriteString(s string) error {
  361. if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil {
  362. return err
  363. }
  364. vw.buf = bsoncore.AppendString(vw.buf, s)
  365. vw.pop()
  366. return nil
  367. }
  368. func (vw *valueWriter) WriteDocument() (DocumentWriter, error) {
  369. if vw.stack[vw.frame].mode == mTopLevel {
  370. vw.reserveLength()
  371. return vw, nil
  372. }
  373. if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil {
  374. return nil, err
  375. }
  376. vw.push(mDocument)
  377. return vw, nil
  378. }
  379. func (vw *valueWriter) WriteSymbol(symbol string) error {
  380. if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil {
  381. return err
  382. }
  383. vw.buf = bsoncore.AppendSymbol(vw.buf, symbol)
  384. vw.pop()
  385. return nil
  386. }
  387. func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error {
  388. if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil {
  389. return err
  390. }
  391. vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i)
  392. vw.pop()
  393. return nil
  394. }
  395. func (vw *valueWriter) WriteUndefined() error {
  396. if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil {
  397. return err
  398. }
  399. vw.pop()
  400. return nil
  401. }
  402. func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
  403. switch vw.stack[vw.frame].mode {
  404. case mTopLevel, mDocument:
  405. default:
  406. return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument})
  407. }
  408. vw.push(mElement)
  409. vw.stack[vw.frame].key = key
  410. return vw, nil
  411. }
  412. func (vw *valueWriter) WriteDocumentEnd() error {
  413. switch vw.stack[vw.frame].mode {
  414. case mTopLevel, mDocument:
  415. default:
  416. return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode)
  417. }
  418. vw.buf = append(vw.buf, 0x00)
  419. err := vw.writeLength()
  420. if err != nil {
  421. return err
  422. }
  423. if vw.stack[vw.frame].mode == mTopLevel {
  424. if vw.w != nil {
  425. if sw, ok := vw.w.(*SliceWriter); ok {
  426. *sw = vw.buf
  427. } else {
  428. _, err = vw.w.Write(vw.buf)
  429. if err != nil {
  430. return err
  431. }
  432. // reset buffer
  433. vw.buf = vw.buf[:0]
  434. }
  435. }
  436. }
  437. vw.pop()
  438. if vw.stack[vw.frame].mode == mCodeWithScope {
  439. // We ignore the error here because of the gaurantee of writeLength.
  440. // See the docs for writeLength for more info.
  441. _ = vw.writeLength()
  442. vw.pop()
  443. }
  444. return nil
  445. }
  446. func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
  447. if vw.stack[vw.frame].mode != mArray {
  448. return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})
  449. }
  450. arrkey := vw.stack[vw.frame].arrkey
  451. vw.stack[vw.frame].arrkey++
  452. vw.push(mValue)
  453. vw.stack[vw.frame].arrkey = arrkey
  454. return vw, nil
  455. }
  456. func (vw *valueWriter) WriteArrayEnd() error {
  457. if vw.stack[vw.frame].mode != mArray {
  458. return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode)
  459. }
  460. vw.buf = append(vw.buf, 0x00)
  461. err := vw.writeLength()
  462. if err != nil {
  463. return err
  464. }
  465. vw.pop()
  466. return nil
  467. }
  468. // NOTE: We assume that if we call writeLength more than once the same function
  469. // within the same function without altering the vw.buf that this method will
  470. // not return an error. If this changes ensure that the following methods are
  471. // updated:
  472. //
  473. // - WriteDocumentEnd
  474. func (vw *valueWriter) writeLength() error {
  475. length := len(vw.buf)
  476. if length > maxSize {
  477. return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
  478. }
  479. length = length - int(vw.stack[vw.frame].start)
  480. start := vw.stack[vw.frame].start
  481. vw.buf[start+0] = byte(length)
  482. vw.buf[start+1] = byte(length >> 8)
  483. vw.buf[start+2] = byte(length >> 16)
  484. vw.buf[start+3] = byte(length >> 24)
  485. return nil
  486. }