1 module dparse.rollback_allocator;
2 
3 import core.memory : GC;
4 
5 //version = debug_rollback_allocator;
6 
/**
 * Pointer-bump allocator with rollback functionality.
 *
 * Memory comes from `AlignedMallocator` in nodes of at least 8 KiB (larger
 * when a single request needs it). The nodes form a singly linked list whose
 * head, `first`, is the node currently being bumped into; every allocation
 * starts at a multiple of `memoryAlignment` bytes.
 *
 * `setCheckpoint` records the current position and `rollback` returns to it,
 * freeing every node created after that point. Node memory is registered with
 * `GC.addRange`, so GC references stored in allocated objects are scanned.
 * Destructors of objects created with `make` are never run by the allocator.
 */
struct RollbackAllocator
{
public:

    /**
     * Alignment, in bytes, of every allocation. Must be a power of two and at
     * least 8: `allocate` rounds sizes up with a bit mask, and
     * `AlignedMallocator.alignedAllocate` requires a power-of-two alignment.
     */
    enum memoryAlignment = 16u;

    static assert(memoryAlignment >= 8 && (memoryAlignment & (memoryAlignment - 1)) == 0,
        "memoryAlignment must be a power of two and at least 8");

    @disable this(this);

    /// Frees every node still owned by the allocator.
    ~this()
    {
        while (first !is null)
            deallocateNode();
    }

    /**
     * Allocates `size` bytes of memory.
     *
     * The returned slice has exactly `size` bytes and starts at an address
     * aligned to `memoryAlignment`. When the current node has no room for the
     * request, a new node is started and the remainder of the old node stays
     * unused until a rollback (or the destructor) frees it.
     *
     * Throws: `OutOfMemoryError` (via `onOutOfMemoryError`) when the
     * underlying allocation fails or `size` is too large to be represented
     * once rounded up and combined with the node header.
     */
    void[] allocate(const size_t size)
    out (arr)
    {
        assert(arr.length == size);
    }
    do
    {
        import core.exception : onOutOfMemoryError;

        // Rounding up and adding the node header below must not wrap around;
        // a wrapped size would make the final slice go out of bounds.
        if (size > size_t.max - Node.sizeof - memoryAlignment)
            onOutOfMemoryError();

        // Round up so the following allocation starts at an aligned address.
        enum size_t alignMask = memoryAlignment - 1;
        immutable size_t alignedSize = (size + alignMask) & ~alignMask;

        // New nodes are sized for `alignedSize`, so a freshly allocated node
        // always fits this request; only a partially used node can be too
        // small. (`used <= mem.length` always holds, so this cannot underflow.)
        if (first is null || first.mem.length - first.used < alignedSize)
            allocateNode(alignedSize);

        immutable size_t begin = first.used;
        first.used = begin + alignedSize;
        assert(first.used <= first.mem.length);
        return first.mem[begin .. begin + size];
    }

    /**
     * Rolls back the allocator to the given checkpoint.
     *
     * Every node allocated after the one containing `point` is freed and the
     * bump position of that node is reset to `point`. A `point` of `0` (the
     * checkpoint of an empty allocator) frees all nodes.
     *
     * `point` is expected to come from `setCheckpoint`; a value outside every
     * node is only rejected by the `contains` assert.
     */
    void rollback(size_t point)
    {
        if (point == 0)
        {
            while (first !is null)
                deallocateNode();
            return;
        }

        assert(contains(point), "Attempted to roll back to a point not in the allocator.");

        // Free the nodes created after the checkpoint. The null check only
        // matters for an invalid `point` when asserts are compiled out.
        while (first !is null && !first.contains(point))
            deallocateNode();
        assert(first !is null);

        immutable begin = point - cast(size_t) first.mem.ptr;
        version (debug_rollback_allocator)
            (cast(ubyte[]) first.mem)[begin .. $] = 0;
        first.used = begin;
        assert(cast(size_t) first.mem.ptr + first.used == point);
    }

    /**
     * Get a checkpoint for the allocator.
     *
     * Returns: the address of the next free byte in the current node, or `0`
     * when the allocator holds no nodes. Passing it to `rollback` undoes all
     * allocations made after this call.
     */
    size_t setCheckpoint() const nothrow @nogc
    {
        assert(first is null || first.used <= first.mem.length);
        return first is null ? 0 : cast(size_t) first.mem.ptr + first.used;
    }

    /**
     * Allocates memory for a `T` and constructs it in place from `args`.
     *
     * Returns: a `T` reference when `T` is a class, otherwise a `T*`.
     * At least one byte is requested so zero-sized types still get a
     * distinct address.
     */
    auto make(T, Args...)(auto ref Args args)
    {
        import std.algorithm.comparison : max;
        import stdx.allocator : stateSize;
        import std.conv : emplace;

        void[] mem = allocate(max(stateSize!T, 1));
        // Defensive: `allocate` reports failure through onOutOfMemoryError.
        if (mem.ptr is null)
            return null;
        static if (is(T == class))
            return emplace!T(mem, args);
        else
            return emplace(cast(T*) mem.ptr, args);
    }

private:

    /// Returns whether `point` lies inside (or one past the end of) any node.
    /// Used by the assert in `rollback`.
    bool contains(size_t point) const
    {
        for (const(Node)* n = first; n !is null; n = n.next)
            if (n.contains(point))
                return true;
        return false;
    }

    /// Header stored at the start of every block obtained from AlignedMallocator.
    static struct Node
    {
        /// The node allocated before this one, or null.
        Node* next;
        /// Number of bytes of `mem` handed out so far.
        size_t used;
        /// Usable memory that follows this header in the same block.
        ubyte[] mem;

        /// Whether `p` lies within `mem`; the end address is included because
        /// it is the checkpoint of a completely used node.
        bool contains(size_t p) const pure nothrow @nogc @safe
        {
            return p >= cast(size_t) mem.ptr && p <= cast(size_t) mem.ptr + mem.length;
        }
    }

    // `mem` starts right after the header, so the header size must preserve
    // the block's alignment.
    static assert(Node.sizeof % memoryAlignment == 0,
        "Node.sizeof must be a multiple of memoryAlignment");

    /// Pushes a new node with room for at least `size` bytes onto the list.
    void allocateNode(size_t size)
    {
        import core.exception : onOutOfMemoryError;
        import std.algorithm : max;
        import std.conv : emplace;
        import stdx.allocator.mallocator : AlignedMallocator;

        enum ALLOC_SIZE = 1024 * 8;

        // One block holds the Node header followed by the usable memory.
        ubyte[] m = cast(ubyte[]) AlignedMallocator.instance.alignedAllocate(
            max(size + Node.sizeof, ALLOC_SIZE), memoryAlignment);
        if (m is null)
            onOutOfMemoryError();
        // Let the GC scan the block: allocated objects may hold GC references.
        GC.addRange(m.ptr, m.length);

        version (debug_rollback_allocator)
            m[] = 0;
        Node* n = emplace!Node(cast(Node*) m.ptr, first, 0, m[Node.sizeof .. $]);
        assert((cast(size_t) n.mem.ptr) % memoryAlignment == 0,
            "Node memory is not aligned to memoryAlignment");
        first = n;
    }

    /// Frees the most recently allocated node and makes its predecessor current.
    void deallocateNode()
    {
        assert(first !is null);
        import stdx.allocator.mallocator : AlignedMallocator;

        Node* next = first.next;
        // Rebuild the original block: header plus the memory that follows it.
        ubyte[] mem = (cast(ubyte*) first)[0 .. Node.sizeof + first.mem.length];
        version (debug_rollback_allocator)
            mem[] = 0;
        GC.removeRange(mem.ptr);
        AlignedMallocator.instance.deallocate(mem);
        first = next;
    }

    /// Most recently allocated node (the one being bumped into), or null.
    Node* first;
}
177 
@("most simple usage, including memory across multiple pointers")
unittest
{
    RollbackAllocator rba;
    size_t[10] checkpoint;
    void[][10] blocks;
    foreach (i; 0 .. 10)
    {
        checkpoint[i] = rba.setCheckpoint();
        blocks[i] = rba.allocate(4000);
        // Every allocation has the requested length and alignment...
        assert(blocks[i].length == 4000);
        assert(cast(size_t) blocks[i].ptr % RollbackAllocator.memoryAlignment == 0);
        // ...and is writable; fill it with a per-block marker.
        (cast(ubyte[]) blocks[i])[] = cast(ubyte) i;
    }

    // No allocation overwrote another one (they span multiple nodes).
    foreach (i, block; blocks)
        foreach (b; cast(ubyte[]) block)
            assert(b == i);

    foreach_reverse (i; 0 .. 10)
    {
        rba.rollback(checkpoint[i]);
        // Rolling back restores exactly the position the checkpoint recorded.
        assert(rba.setCheckpoint() == checkpoint[i]);
    }
}
194 
@("many allocates and frees while leaking memory")
unittest
{
    // Each round records three checkpoints but only rolls back to the last
    // two, so the first 4000-byte block of every round stays allocated until
    // the allocator itself is destroyed.
    RollbackAllocator allocator;
    foreach (round; 0 .. 10)
    {
        size_t[3] marks;
        foreach (ref mark; marks)
        {
            mark = allocator.setCheckpoint();
            allocator.allocate(4000);
        }
        foreach (idx; [2, 1])
            allocator.rollback(marks[idx]);
    }
}
213 
@("allocating overly big")
unittest
{
    import std.stdio : stderr;

    enum lowest = 1024 * 8 - 100;
    enum highest = 1024 * 8 + 100;

    RollbackAllocator rba;
    size_t[highest - lowest] marks;
    size_t count = 0;

    // Sizes straddle the default 8 KiB node size, so some requests fit in an
    // ordinary node and others need a larger dedicated one.
    for (size_t bytes = lowest; bytes < highest; ++bytes)
    {
        try
        {
            marks[count] = rba.setCheckpoint();
            ++count;
            rba.allocate(bytes);
        }
        catch (Error err)
        {
            stderr.writeln("Unittest: crashed in allocating ", bytes, " bytes");
            throw err;
        }
    }

    // Undo the allocations newest-first.
    while (count > 0)
    {
        --count;
        immutable mark = marks[count];
        try
        {
            rba.rollback(mark);
        }
        catch (Error err)
        {
            stderr.writeln("Unittest: crashed in rolling back ", count, " (address ", mark, ")");
            throw err;
        }
    }
}