1 module dparse.rollback_allocator;
2 
3 //version = debug_rollback_allocator;
4 
/**
 * Pointer-bump allocator with rollback functionality.
 *
 * Memory is carved out of a singly linked list of nodes obtained from
 * `AlignedMallocator`; `first` is the most recently allocated node and the
 * only one that is bumped. Every allocation is rounded up to
 * `memoryAlignment` bytes, so each returned block starts on an aligned
 * address. A checkpoint is the address of the next free byte in `first`
 * (or 0 when no node exists), and `rollback` frees everything allocated
 * after it.
 */
struct RollbackAllocator
{
public:

    // Alignment (in bytes) of every block handed out by `allocate`.
    // Must be a multiple of 8 and, because of the mask arithmetic in
    // `allocate`, a power of two.
    enum memoryAlignment = 16u;
    static assert(memoryAlignment % 8 == 0
        && (memoryAlignment & (memoryAlignment - 1)) == 0,
        "memoryAlignment must be a power of two and a multiple of 8");

    // Copying would make two allocators free the same nodes.
    @disable this(this);

    /// Frees every node still owned by the allocator.
    ~this()
    {
        while (first !is null)
            deallocateNode();
    }

    /**
     * Allocates `size` bytes of memory.
     *
     * The returned slice is exactly `size` bytes long and starts at an address
     * aligned to `memoryAlignment`. The reservation itself is rounded up to a
     * multiple of `memoryAlignment`, so the next allocation stays aligned too.
     * A new node is allocated when the current one cannot hold the request.
     *
     * Params:
     *     size = number of bytes requested; 0 is allowed.
     * Returns: a slice of `size` bytes owned by this allocator.
     * Throws: `OutOfMemoryError` if the underlying allocation fails or the
     *     aligned size would overflow `size_t`.
     */
    void[] allocate(const size_t size)
    out (arr)
    {
        assert(arr.length == size);
    }
    do
    {
        import core.exception : onOutOfMemoryError;

        // Rounding `size` up below would wrap around for sizes this large.
        if (size > size_t.max - memoryAlignment)
            onOutOfMemoryError();

        // Memory align the size
        immutable size_t s = size & ~(cast(size_t) memoryAlignment - 1);
        immutable size_t s2 = s == size ? size : s + memoryAlignment;

        // New nodes are requested with the *aligned* size: asking for only
        // `size` could yield a node smaller than `s2`, which would force a
        // second node allocation right away and leave the first one empty.
        if (first is null)
            allocateNode(s2);

        size_t fu = first.used;
        size_t end = fu + s2;
        if (end > first.mem.length)
        {
            allocateNode(s2);
            fu = first.used;
            end = fu + s2;
        }
        // Holds by construction: a fresh node always has room for `s2` bytes.
        assert(end <= first.mem.length);

        void[] m = first.mem[fu .. fu + size];
        first.used = end;
        return m;
    }

    /**
     * Rolls back the allocator to the given checkpoint.
     *
     * All nodes allocated after the one containing `point` are freed, and
     * that node's bump pointer is reset to `point`. A `point` of 0 (the
     * checkpoint of an allocator without nodes) frees every node.
     *
     * Params:
     *     point = a value returned by `setCheckpoint` that still lies inside
     *         memory owned by this allocator (checked by an assertion).
     */
    void rollback(size_t point)
    {
        if (point == 0)
        {
            while (first)
                deallocateNode();
            return;
        }

        assert(contains(point), "Attempted to roll back to a point not in the allocator.");

        // `first !is null` always holds right after the contains(point)
        // check, but may stop holding once nodes are deallocated.
        while (first !is null && !first.contains(point))
            deallocateNode();
        assert(first !is null);

        immutable begin = point - cast(size_t) first.mem.ptr;
        version (debug_rollback_allocator)
            (cast(ubyte[]) first.mem)[begin .. $] = 0;
        first.used = begin;
        assert(cast(size_t) first.mem.ptr + first.used == point);
    }

    /**
     * Get a checkpoint for the allocator.
     *
     * Returns: the address of the next free byte in the current node, or 0
     *     if the allocator currently owns no node.
     */
    size_t setCheckpoint() const nothrow @nogc
    {
        assert(first is null || first.used <= first.mem.length);
        return first is null ? 0 : cast(size_t) first.mem.ptr + first.used;
    }

    /**
     * Allocates a T and returns a pointer to it
     *
     * The object is constructed in place from `args`. For class types the
     * result is a class reference; otherwise it is a `T*`. At least one byte
     * is always reserved, even for zero-sized types.
     */
    auto make(T, Args...)(auto ref Args args)
    {
        import std.algorithm.comparison : max;
        import stdx.allocator : stateSize;
        import std.conv : emplace;

        void[] mem = allocate(max(stateSize!T, 1));
        if (mem.ptr is null)
            return null;
        static if (is(T == class))
            return emplace!T(mem, args);
        else
            return emplace(cast(T*) mem.ptr, args);
    }

private:

    // Whether `point` lies in any node owned by this allocator; used by the
    // assertion in `rollback`.
    bool contains(size_t point) const
    {
        for (const(Node)* n = first; n !is null; n = n.next)
            if (n.contains(point))
                return true;
        return false;
    }

    /// Header stored at the start of every block obtained from the mallocator.
    static struct Node
    {
        Node* next;   // previously allocated (older) node, or null
        size_t used;  // bytes of `mem` already handed out
        ubyte[] mem;  // usable memory directly following this header

        // The end bound is inclusive so that the checkpoint taken when a node
        // is completely full still belongs to that node.
        bool contains(size_t p) const pure nothrow @nogc @safe
        {
            return p >= cast(size_t) mem.ptr && p <= cast(size_t) mem.ptr + mem.length;
        }
    }

    /**
     * Prepends a new node with at least `size` usable bytes (and never less
     * than `ALLOC_SIZE - Node.sizeof` bytes) and makes it `first`.
     *
     * Throws: `OutOfMemoryError` if the allocation fails or its size would
     *     overflow `size_t`.
     */
    void allocateNode(size_t size)
    {
        import core.exception : onOutOfMemoryError;
        import std.algorithm : max;
        import std.conv : emplace;
        import stdx.allocator.mallocator : AlignedMallocator;

        enum ALLOC_SIZE = 1024 * 8;

        // `size + Node.sizeof` below must not wrap around.
        if (size > size_t.max - Node.sizeof)
            onOutOfMemoryError();

        ubyte[] m = cast(ubyte[]) AlignedMallocator.instance.alignedAllocate(max(size + Node.sizeof, ALLOC_SIZE), memoryAlignment);
        if (m is null)
            onOutOfMemoryError();

        version (debug_rollback_allocator)
            m[] = 0;
        // The header occupies the first Node.sizeof bytes; the rest is `mem`.
        Node* n = emplace!Node(cast(Node*) m.ptr, first, 0, m[Node.sizeof .. $]);
        assert((cast(size_t) n.mem.ptr) % memoryAlignment == 0,
            "Node memory is not aligned to memoryAlignment");
        first = n;
    }

    /// Frees `first` (header and memory together) and makes its successor `first`.
    void deallocateNode()
    {
        assert(first !is null);
        import stdx.allocator.mallocator : AlignedMallocator;

        Node* next = first.next;
        // Rebuild the exact block that allocateNode obtained: header + memory.
        ubyte[] mem = (cast(ubyte*) first)[0 .. Node.sizeof + first.mem.length];
        version (debug_rollback_allocator)
            mem[] = 0;
        AlignedMallocator.instance.deallocate(mem);
        first = next;
    }

    // Most recently allocated node; null when the allocator owns no memory.
    Node* first;
}
173 
@("most simple usage, including memory across multiple pointers")
unittest
{
    RollbackAllocator allocator;

    // Record a checkpoint before each of ten 4000-byte allocations,
    // which together span several nodes.
    size_t[10] points;
    foreach (ref point; points)
    {
        point = allocator.setCheckpoint();
        allocator.allocate(4000);
    }

    // Undo the allocations from newest to oldest.
    foreach_reverse (point; points)
        allocator.rollback(point);
}
190 
@("many allocates and frees while leaking memory")
unittest
{
    RollbackAllocator allocator;
    foreach (round; 0 .. 10)
    {
        // Three 4000-byte allocations, each preceded by a checkpoint.
        size_t[3] points;
        foreach (ref point; points)
        {
            point = allocator.setCheckpoint();
            allocator.allocate(4000);
        }

        // Roll back to the third and then the second checkpoint only, so the
        // allocation made after the first checkpoint survives every round.
        foreach_reverse (point; points[1 .. $])
            allocator.rollback(point);
    }
}
209 
@("allocating overly big")
unittest
{
    import std.stdio : stderr;

    RollbackAllocator allocator;
    size_t[200] points;
    size_t count;

    // Request sizes straddling the default node size.
    enum lowest = 1024 * 8 - 100;
    enum highest = 1024 * 8 + 100;
    foreach (bytes; lowest .. highest)
    {
        try
        {
            points[count++] = allocator.setCheckpoint();
            allocator.allocate(bytes);
        }
        catch (Error err)
        {
            stderr.writeln("Unittest: crashed in allocating ", bytes, " bytes");
            throw err;
        }
    }

    // Unwind every checkpoint, newest first, reporting which one failed.
    foreach_reverse (idx, point; points[0 .. count])
    {
        try
        {
            allocator.rollback(point);
        }
        catch (Error err)
        {
            stderr.writeln("Unittest: crashed in rolling back ", idx, " (address ", point, ")");
            throw err;
        }
    }
}