added cleaner migration function for peer graph data
[p2pool.git] / p2pool / util / graph.py
index 068dbf3..0c3adfc 100644 (file)
@@ -29,6 +29,11 @@ def keep_largest(n, squash_key=nothing, key=lambda x: x, add_func=lambda a, b: a
         return dict(items)
     return _
 
+def _shift_bins_so_t_is_not_past_end(bins, last_bin_end, bin_width, t):
+    # returns new_bins, new_last_bin_end
+    shift = max(0, int(math.ceil((t - last_bin_end)/bin_width)))
+    return _shift(bins, shift, {}), last_bin_end + shift*bin_width
+
 class DataView(object):
     def __init__(self, desc, ds_desc, last_bin_end, bins):
         assert len(bins) == desc.bin_count
@@ -43,19 +48,15 @@ class DataView(object):
             value = {'null': value}
         elif self.ds_desc.multivalue_undefined_means_0 and 'null' not in value:
             value = dict(value, null=0) # use null to hold sample counter
-        shift = max(0, int(math.ceil((t - self.last_bin_end)/self.desc.bin_width)))
-        self.bins = _shift(self.bins, shift, {})
-        self.last_bin_end += shift*self.desc.bin_width
+        self.bins, self.last_bin_end = _shift_bins_so_t_is_not_past_end(self.bins, self.last_bin_end, self.desc.bin_width, t)
         
-        bin = int(math.ceil((self.last_bin_end - self.desc.bin_width - t)/self.desc.bin_width))
+        bin = int(math.floor((self.last_bin_end - t)/self.desc.bin_width))
+        assert bin >= 0
         if bin < self.desc.bin_count:
             self.bins[bin] = self.ds_desc.keep_largest_func(combine_bins(self.bins[bin], dict((k, (v, 1)) for k, v in value.iteritems())))
     
     def get_data(self, t):
-        shift = max(0, int(math.ceil((t - self.last_bin_end)/self.desc.bin_width)))
-        bins = _shift(self.bins, shift, {})
-        last_bin_end = self.last_bin_end + shift*self.desc.bin_width
-        
+        bins, last_bin_end = _shift_bins_so_t_is_not_past_end(self.bins, self.last_bin_end, self.desc.bin_width, t)
         assert last_bin_end - self.desc.bin_width <= t <= last_bin_end
         
         def _((i, bin)):
@@ -134,3 +135,19 @@ class HistoryDatabase(object):
     def to_obj(self):
         return dict((ds_name, dict((dv_name, dict(last_bin_end=dv.last_bin_end, bin_width=dv.desc.bin_width, bins=dv.bins))
             for dv_name, dv in ds.dataviews.iteritems())) for ds_name, ds in self.datastreams.iteritems())
+
+
+def make_multivalue_migrator(multivalue_keys, post_func=lambda bins: bins):
+    def _(ds_name, ds_desc, dv_name, dv_desc, obj):
+        if not obj:
+            last_bin_end = 0
+            bins = dv_desc.bin_count*[{}]
+        else:
+            inputs = dict((k, obj.get(v, {dv_name: dict(bins=[{}]*dv_desc.bin_count, last_bin_end=0)})[dv_name]) for k, v in multivalue_keys.iteritems())
+            last_bin_end = max(inp['last_bin_end'] for inp in inputs.itervalues()) if inputs else 0
+            assert all(len(inp['bins']) == dv_desc.bin_count for inp in inputs.itervalues())
+            inputs = dict((k, dict(zip(['bins', 'last_bin_end'], _shift_bins_so_t_is_not_past_end(v['bins'], v['last_bin_end'], dv_desc.bin_width, last_bin_end)))) for k, v in inputs.iteritems())
+            assert len(set(inp['last_bin_end'] for inp in inputs.itervalues())) <= 1
+            bins = post_func([dict((k, v['bins'][i]['null']) for k, v in inputs.iteritems() if 'null' in v['bins'][i]) for i in xrange(dv_desc.bin_count)])
+        return DataView(dv_desc, ds_desc, last_bin_end, bins)
+    return _