diff --git a/privacy/analysis/privacy_ledger.py b/privacy/analysis/privacy_ledger.py
index 1dc995f..f6394d0 100644
--- a/privacy/analysis/privacy_ledger.py
+++ b/privacy/analysis/privacy_ledger.py
@@ -234,13 +234,22 @@ class QueryWithLedger(dp_query.DPQuery):
     """See base class."""
     return self._query.derive_sample_params(global_state)
 
-  def initial_sample_state(self, global_state, tensors):
+  def initial_sample_state(self, global_state, template):
     """See base class."""
-    return self._query.initial_sample_state(global_state, tensors)
+    return self._query.initial_sample_state(global_state, template)
 
-  def accumulate_record(self, params, sample_state, record):
+  def preprocess_record(self, params, record):
     """See base class."""
-    return self._query.accumulate_record(params, sample_state, record)
+    return self._query.preprocess_record(params, record)
+
+  def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
+    """See base class."""
+    return self._query.accumulate_preprocessed_record(
+        sample_state, preprocessed_record)
+
+  def merge_sample_states(self, sample_state_1, sample_state_2):
+    """See base class."""
+    return self._query.merge_sample_states(sample_state_1, sample_state_2)
 
   def get_noised_result(self, sample_state, global_state):
     """Ensures sample is recorded to the ledger and returns noised result."""
diff --git a/privacy/dp_query/BUILD b/privacy/dp_query/BUILD
index 90dae9b..e91dce4 100644
--- a/privacy/dp_query/BUILD
+++ b/privacy/dp_query/BUILD
@@ -5,6 +5,10 @@ licenses(["notice"])  # Apache 2.0
 py_library(
     name = "dp_query",
     srcs = ["dp_query.py"],
+    deps = [
+        "//third_party/py/distutils",
+        "//third_party/py/tensorflow",
+    ],
 )
 
 py_library(
diff --git a/privacy/dp_query/dp_query.py b/privacy/dp_query/dp_query.py
index 5123520..116b8be 100644
--- a/privacy/dp_query/dp_query.py
+++ b/privacy/dp_query/dp_query.py
@@ -1,4 +1,4 @@
-# Copyright 2018, The TensorFlow Authors.
+# Copyright 2019, The TensorFlow Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -47,6 +47,13 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
+from distutils.version import LooseVersion
+
+import tensorflow as tf
+if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
+  nest = tf.contrib.framework.nest
+else:
+  nest = tf.nest
 
 
 class DPQuery(object):
@@ -54,12 +61,10 @@ class DPQuery(object):
 
   __metaclass__ = abc.ABCMeta
 
-  @abc.abstractmethod
   def initial_global_state(self):
     """Returns the initial global state for the DPQuery."""
-    pass
+    return ()
 
-  @abc.abstractmethod
   def derive_sample_params(self, global_state):
     """Given the global state, derives parameters to use for the next sample.
 
@@ -69,25 +74,74 @@ class DPQuery(object):
     Returns:
       Parameters to use to process records in the next sample.
     """
-    pass
+    del global_state  # unused.
+    return ()
 
   @abc.abstractmethod
-  def initial_sample_state(self, global_state, tensors):
+  def initial_sample_state(self, global_state, template):
     """Returns an initial state to use for the next sample.
 
     Args:
       global_state: The current global state.
-      tensors: A structure of tensors used as a template to create the initial
-        sample state.
+      template: A nested structure of tensors, TensorSpecs, or numpy arrays used
+        as a template to create the initial sample state. It is assumed that the
+        leaves of the structure are python scalars or some type that has
+        properties `shape` and `dtype`.
 
     Returns: An initial sample state.
     """
     pass
 
+  def preprocess_record(self, params, record):
+    """Preprocesses a single record.
+
+    This preprocessing is applied to one client's record, e.g. selecting vectors
+    and clipping them to a fixed L2 norm. This method can be executed in a
+    separate TF session, or even on a different machine, so it should not depend
+    on any TF inputs other than those provided as input arguments. In
+    particular, implementations should avoid accessing any TF tensors or
+    variables that are stored in self.
+
+    Args:
+      params: The parameters for the sample. In standard DP-SGD training,
+        the clipping norm for the sample's microbatch gradients (i.e.,
+        a maximum norm magnitude to which each gradient is clipped)
+      record: The record to be processed. In standard DP-SGD training,
+        the gradient computed for the examples in one microbatch, which
+        may be the gradient for just one example (for size 1 microbatches).
+
+    Returns:
+      A structure of tensors to be aggregated.
+    """
+    del params  # unused.
+    return record
+
   @abc.abstractmethod
+  def accumulate_preprocessed_record(
+      self, sample_state, preprocessed_record):
+    """Accumulates a single preprocessed record into the sample state.
+
+    This method is intended to only do simple aggregation, typically just a sum.
+    In the future, we might remove this method and replace it with a way to
+    declaratively specify the type of aggregation required.
+
+    Args:
+      sample_state: The current sample state. In standard DP-SGD training,
+        the accumulated sum of previous clipped microbatch gradients.
+      preprocessed_record: The preprocessed record to accumulate.
+
+    Returns:
+      The updated sample state.
+    """
+    pass
+
   def accumulate_record(self, params, sample_state, record):
     """Accumulates a single record into the sample state.
 
+    This is a helper method that simply delegates to `preprocess_record` and
+    `accumulate_preprocessed_record` for the common case when both of those
+    functions run on a single device.
+
     Args:
       params: The parameters for the sample. In standard DP-SGD training,
         the clipping norm for the sample's microbatch gradients (i.e.,
@@ -102,6 +156,21 @@ class DPQuery(object):
       The updated sample state. In standard DP-SGD training, the set of
       previous mcrobatch gradients with the addition of the record argument.
     """
+    preprocessed_record = self.preprocess_record(params, record)
+    return self.accumulate_preprocessed_record(
+        sample_state, preprocessed_record)
+
+  @abc.abstractmethod
+  def merge_sample_states(self, sample_state_1, sample_state_2):
+    """Merges two sample states into a single state.
+
+    Args:
+      sample_state_1: The first sample state to merge.
+      sample_state_2: The second sample state to merge.
+
+    Returns:
+      The merged sample state.
+    """
     pass
 
   @abc.abstractmethod
@@ -123,3 +192,26 @@ class DPQuery(object):
       averaging performed in a manner that guarantees differential privacy.
     """
     pass
+
+
+def zeros_like(arg):
+  """A `zeros_like` function that also works for `tf.TensorSpec`s."""
+  try:
+    arg = tf.convert_to_tensor(arg)
+  except TypeError:
+    pass
+  return tf.zeros(arg.shape, arg.dtype)
+
+
+class SumAggregationDPQuery(DPQuery):
+  """Base class for DPQueries that aggregate via sum."""
+
+  def initial_sample_state(self, global_state, template):
+    del global_state  # unused.
+    return nest.map_structure(zeros_like, template)
+
+  def accumulate_preprocessed_record(self, sample_state, preprocessed_record):
+    return nest.map_structure(tf.add, sample_state, preprocessed_record)
+
+  def merge_sample_states(self, sample_state_1, sample_state_2):
+    return nest.map_structure(tf.add, sample_state_1, sample_state_2)
diff --git a/privacy/dp_query/gaussian_query.py b/privacy/dp_query/gaussian_query.py
index bb2f6f8..07817f0 100644
--- a/privacy/dp_query/gaussian_query.py
+++ b/privacy/dp_query/gaussian_query.py
@@ -31,7 +31,7 @@ else:
   nest = tf.nest
 
 
-class GaussianSumQuery(dp_query.DPQuery):
+class GaussianSumQuery(dp_query.SumAggregationDPQuery):
   """Implements DPQuery interface for Gaussian sum queries.
 
   Accumulates clipped vectors, then adds Gaussian noise to the sum.
@@ -50,31 +50,10 @@ class GaussianSumQuery(dp_query.DPQuery):
     self._stddev = tf.cast(stddev, tf.float32)
     self._ledger = ledger
 
-  def initial_global_state(self):
-    """Returns the initial global state for the GaussianSumQuery."""
-    return None
-
   def derive_sample_params(self, global_state):
-    """Given the global state, derives parameters to use for the next sample.
-
-    Args:
-      global_state: The current global state.
-
-    Returns:
-      Parameters to use to process records in the next sample.
-    """
     return self._l2_norm_clip
 
-  def initial_sample_state(self, global_state, tensors):
-    """Returns an initial state to use for the next sample.
-
-    Args:
-      global_state: The current global state.
-      tensors: A structure of tensors used as a template to create the initial
-        sample state.
-
-    Returns: An initial sample state.
-    """
+  def initial_sample_state(self, global_state, template):
     if self._ledger:
       dependencies = [
           self._ledger.record_sum_query(self._l2_norm_clip, self._stddev)
@@ -82,51 +61,32 @@ class GaussianSumQuery(dp_query.DPQuery):
     else:
       dependencies = []
     with tf.control_dependencies(dependencies):
-      return nest.map_structure(tf.zeros_like, tensors)
+      return nest.map_structure(
+          dp_query.zeros_like, template)
 
-  def accumulate_record_impl(self, params, sample_state, record):
-    """Accumulates a single record into the sample state.
+  def preprocess_record_impl(self, params, record):
+    """Clips the l2 norm, returning the clipped record and the l2 norm.
 
     Args:
       params: The parameters for the sample.
-      sample_state: The current sample state.
-      record: The record to accumulate.
+      record: The record to be processed.
 
     Returns:
-      A tuple containing the updated sample state and the global norm.
+      A tuple (preprocessed_records, l2_norm) where `preprocessed_records` is
+        the structure of preprocessed tensors, and l2_norm is the total l2 norm
+        before clipping.
     """
     l2_norm_clip = params
     record_as_list = nest.flatten(record)
     clipped_as_list, norm = tf.clip_by_global_norm(record_as_list, l2_norm_clip)
-    clipped = nest.pack_sequence_as(record, clipped_as_list)
-    return nest.map_structure(tf.add, sample_state, clipped), norm
+    return nest.pack_sequence_as(record, clipped_as_list), norm
 
-  def accumulate_record(self, params, sample_state, record):
-    """Accumulates a single record into the sample state.
-
-    Args:
-      params: The parameters for the sample.
-      sample_state: The current sample state.
-      record: The record to accumulate.
-
-    Returns:
-      The updated sample state.
-    """
-    new_sample_state, _ = self.accumulate_record_impl(
-        params, sample_state, record)
-    return new_sample_state
+  def preprocess_record(self, params, record):
+    preprocessed_record, _ = self.preprocess_record_impl(params, record)
+    return preprocessed_record
 
   def get_noised_result(self, sample_state, global_state):
-    """Gets noised sum after all records of sample have been accumulated.
-
-    Args:
-      sample_state: The sample state after all records have been accumulated.
-      global_state: The global state.
-
-    Returns:
-      A tuple (estimate, new_global_state) where "estimate" is the estimated
-      sum of the records and "new_global_state" is the updated global state.
-    """
+    """See base class."""
     if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
       def add_noise(v):
         return v + tf.random_normal(tf.shape(v), stddev=self._stddev)
diff --git a/privacy/dp_query/gaussian_query_test.py b/privacy/dp_query/gaussian_query_test.py
index 08107e6..e2a1db0 100644
--- a/privacy/dp_query/gaussian_query_test.py
+++ b/privacy/dp_query/gaussian_query_test.py
@@ -91,6 +91,32 @@ class GaussianQueryTest(tf.test.TestCase, parameterized.TestCase):
       result_stddev = np.std(noised_sums)
       self.assertNear(result_stddev, stddev, 0.1)
 
+  def test_gaussian_sum_merge(self):
+    records1 = [tf.constant([2.0, 0.0]), tf.constant([-1.0, 1.0])]
+    records2 = [tf.constant([3.0, 5.0]), tf.constant([-1.0, 4.0])]
+
+    def get_sample_state(records):
+      query = gaussian_query.GaussianSumQuery(l2_norm_clip=10.0, stddev=1.0)
+      global_state = query.initial_global_state()
+      params = query.derive_sample_params(global_state)
+      sample_state = query.initial_sample_state(global_state, records[0])
+      for record in records:
+        sample_state = query.accumulate_record(params, sample_state, record)
+      return sample_state
+
+    sample_state_1 = get_sample_state(records1)
+    sample_state_2 = get_sample_state(records2)
+
+    merged = gaussian_query.GaussianSumQuery(10.0, 1.0).merge_sample_states(
+        sample_state_1,
+        sample_state_2)
+
+    with self.cached_session() as sess:
+      result = sess.run(merged)
+
+    expected = [3.0, 10.0]
+    self.assertAllClose(result, expected)
+
   def test_gaussian_average_no_noise(self):
     with self.cached_session() as sess:
       record1 = tf.constant([5.0, 0.0])   # Clipped to [3.0, 0.0].
diff --git a/privacy/dp_query/nested_query.py b/privacy/dp_query/nested_query.py
index 8e10c09..de5aa08 100644
--- a/privacy/dp_query/nested_query.py
+++ b/privacy/dp_query/nested_query.py
@@ -56,52 +56,39 @@ class NestedQuery(dp_query.DPQuery):
     """
     self._queries = queries
 
-  def _map_to_queries(self, fn, *inputs):
+  def _map_to_queries(self, fn, *inputs, **kwargs):
     def caller(query, *args):
-      return getattr(query, fn)(*args)
+      return getattr(query, fn)(*args, **kwargs)
     return nest.map_structure_up_to(
         self._queries, caller, self._queries, *inputs)
 
   def initial_global_state(self):
-    """Returns the initial global state for the NestedQuery."""
+    """See base class."""
     return self._map_to_queries('initial_global_state')
 
   def derive_sample_params(self, global_state):
-    """Given the global state, derives parameters to use for the next sample.
-
-    Args:
-      global_state: The current global state.
-
-    Returns:
-      Parameters to use to process records in the next sample.
-    """
+    """See base class."""
     return self._map_to_queries('derive_sample_params', global_state)
 
-  def initial_sample_state(self, global_state, tensors):
-    """Returns an initial state to use for the next sample.
+  def initial_sample_state(self, global_state, template):
+    """See base class."""
+    return self._map_to_queries('initial_sample_state', global_state, template)
 
-    Args:
-      global_state: The current global state.
-      tensors: A structure of tensors used as a template to create the initial
-        sample state.
+  def preprocess_record(self, params, record):
+    """See base class."""
+    return self._map_to_queries('preprocess_record', params, record)
 
-    Returns: An initial sample state.
-    """
-    return self._map_to_queries('initial_sample_state', global_state, tensors)
-
-  def accumulate_record(self, params, sample_state, record):
-    """Accumulates a single record into the sample state.
-
-    Args:
-      params: The parameters for the sample.
-      sample_state: The current sample state.
-      record: The record to accumulate.
-
-    Returns:
-      The updated sample state.
-    """
+  def accumulate_preprocessed_record(
+      self, sample_state, preprocessed_record):
+    """See base class."""
     return self._map_to_queries(
-        'accumulate_record', params, sample_state, record)
+        'accumulate_preprocessed_record',
+        sample_state,
+        preprocessed_record)
+
+  def merge_sample_states(self, sample_state_1, sample_state_2):
+    return self._map_to_queries(
+        'merge_sample_states', sample_state_1, sample_state_2)
 
   def get_noised_result(self, sample_state, global_state):
     """Gets query result after all records of sample have been accumulated.
diff --git a/privacy/dp_query/no_privacy_query.py b/privacy/dp_query/no_privacy_query.py
index a40e2c4..449b970 100644
--- a/privacy/dp_query/no_privacy_query.py
+++ b/privacy/dp_query/no_privacy_query.py
@@ -28,78 +28,44 @@ else:
   nest = tf.nest
 
 
-class NoPrivacySumQuery(dp_query.DPQuery):
+class NoPrivacySumQuery(dp_query.SumAggregationDPQuery):
   """Implements DPQuery interface for a sum query with no privacy.
 
   Accumulates vectors without clipping or adding noise.
   """
 
-  def initial_global_state(self):
-    """Returns the initial global state for the NoPrivacySumQuery."""
-    return None
-
-  def derive_sample_params(self, global_state):
-    """See base class."""
-    del global_state  # unused.
-    return None
-
-  def initial_sample_state(self, global_state, tensors):
-    """See base class."""
-    del global_state  # unused.
-    return nest.map_structure(tf.zeros_like, tensors)
-
-  def accumulate_record(self, params, sample_state, record, weight=1):
-    """See base class. Optional argument for weighted sum queries."""
-    del params  # unused.
-
-    def add_weighted(state_tensor, record_tensor):
-      return tf.add(state_tensor, weight * record_tensor)
-
-    return nest.map_structure(add_weighted, sample_state, record)
-
   def get_noised_result(self, sample_state, global_state):
     """See base class."""
     return sample_state, global_state
 
 
-class NoPrivacyAverageQuery(dp_query.DPQuery):
+class NoPrivacyAverageQuery(dp_query.SumAggregationDPQuery):
   """Implements DPQuery interface for an average query with no privacy.
 
   Accumulates vectors and normalizes by the total number of accumulated vectors.
   """
 
-  def __init__(self):
-    """Initializes the NoPrivacyAverageQuery."""
-    self._numerator = NoPrivacySumQuery()
-
-  def initial_global_state(self):
-    """Returns the initial global state for the NoPrivacyAverageQuery."""
-    return self._numerator.initial_global_state()
-
-  def derive_sample_params(self, global_state):
+  def initial_sample_state(self, global_state, template):
     """See base class."""
-    del global_state  # unused.
-    return None
+    return (
+        super(NoPrivacyAverageQuery, self).initial_sample_state(
+            global_state, template),
+        tf.constant(0.0))
 
-  def initial_sample_state(self, global_state, tensors):
-    """See base class."""
-    return self._numerator.initial_sample_state(global_state, tensors), 0.0
+  def preprocess_record(self, params, record, weight=1):
+    """Multiplies record by weight."""
+    weighted_record = nest.map_structure(lambda t: weight * t, record)
+    return (weighted_record, weight)
 
   def accumulate_record(self, params, sample_state, record, weight=1):
-    """See base class. Optional argument for weighted average queries."""
-    sum_sample_state, denominator = sample_state
-    return (
-        self._numerator.accumulate_record(
-            params, sum_sample_state, record, weight),
-        tf.add(denominator, weight))
+    """Accumulates record, multiplying by weight."""
+    weighted_record = nest.map_structure(lambda t: weight * t, record)
+    return self.accumulate_preprocessed_record(
+        sample_state, (weighted_record, weight))
 
   def get_noised_result(self, sample_state, global_state):
     """See base class."""
-    sum_sample_state, denominator = sample_state
-    exact_sum, new_global_state = self._numerator.get_noised_result(
-        sum_sample_state, global_state)
+    sum_state, denominator = sample_state
 
-    def normalize(v):
-      return tf.truediv(v, denominator)
-
-    return nest.map_structure(normalize, exact_sum), new_global_state
+    return nest.map_structure(
+        lambda t: tf.truediv(t, denominator), sum_state), ()
diff --git a/privacy/dp_query/no_privacy_query_test.py b/privacy/dp_query/no_privacy_query_test.py
index f408dc5..ed32e60 100644
--- a/privacy/dp_query/no_privacy_query_test.py
+++ b/privacy/dp_query/no_privacy_query_test.py
@@ -1,4 +1,4 @@
-# Copyright 2018, The TensorFlow Authors.
+# Copyright 2019, The TensorFlow Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,7 +27,7 @@ from privacy.dp_query import test_utils
 
 class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
 
-  def test_no_privacy_sum(self):
+  def test_sum(self):
     with self.cached_session() as sess:
       record1 = tf.constant([2.0, 0.0])
       record2 = tf.constant([-1.0, 1.0])
@@ -38,20 +38,6 @@ class NoPrivacyQueryTest(tf.test.TestCase, parameterized.TestCase):
       expected = [1.0, 1.0]
       self.assertAllClose(result, expected)
 
-  def test_no_privacy_weighted_sum(self):
-    with self.cached_session() as sess:
-      record1 = tf.constant([2.0, 0.0])
-      record2 = tf.constant([-1.0, 1.0])
-
-      weights = [1, 2]
-
-      query = no_privacy_query.NoPrivacySumQuery()
-      query_result, _ = test_utils.run_query(
-          query, [record1, record2], weights=weights)
-      result = sess.run(query_result)
-      expected = [0.0, 2.0]
-      self.assertAllClose(result, expected)
-
   def test_no_privacy_average(self):
     with self.cached_session() as sess:
       record1 = tf.constant([5.0, 0.0])
diff --git a/privacy/dp_query/normalized_query.py b/privacy/dp_query/normalized_query.py
index 8d2df58..0cc73c4 100644
--- a/privacy/dp_query/normalized_query.py
+++ b/privacy/dp_query/normalized_query.py
@@ -38,65 +38,39 @@ class NormalizedQuery(dp_query.DPQuery):
 
     Args:
       numerator_query: A DPQuery for the numerator.
-      denominator: A value for the denominator.
+      denominator: A value for the denominator. May be None if it will be
+        supplied via the set_denominator function before get_noised_result is
+        called.
     """
     self._numerator = numerator_query
-    self._denominator = tf.cast(denominator,
-                                tf.float32) if denominator is not None else None
+    self._denominator = (
+        tf.cast(denominator, tf.float32) if denominator is not None else None)
 
   def initial_global_state(self):
-    """Returns the initial global state for the NormalizedQuery."""
+    """See base class."""
     # NormalizedQuery has no global state beyond the numerator state.
     return self._numerator.initial_global_state()
 
   def derive_sample_params(self, global_state):
-    """Given the global state, derives parameters to use for the next sample.
-
-    Args:
-      global_state: The current global state.
-
-    Returns:
-      Parameters to use to process records in the next sample.
-    """
+    """See base class."""
     return self._numerator.derive_sample_params(global_state)
 
-  def initial_sample_state(self, global_state, tensors):
-    """Returns an initial state to use for the next sample.
-
-    Args:
-      global_state: The current global state.
-      tensors: A structure of tensors used as a template to create the initial
-        sample state.
-
-    Returns: An initial sample state.
-    """
+  def initial_sample_state(self, global_state, template):
+    """See base class."""
     # NormalizedQuery has no sample state beyond the numerator state.
-    return self._numerator.initial_sample_state(global_state, tensors)
+    return self._numerator.initial_sample_state(global_state, template)
 
-  def accumulate_record(self, params, sample_state, record):
-    """Accumulates a single record into the sample state.
+  def preprocess_record(self, params, record):
+    return self._numerator.preprocess_record(params, record)
 
-    Args:
-      params: The parameters for the sample.
-      sample_state: The current sample state.
-      record: The record to accumulate.
-
-    Returns:
-      The updated sample state.
-    """
-    return self._numerator.accumulate_record(params, sample_state, record)
+  def accumulate_preprocessed_record(
+      self, sample_state, preprocessed_record):
+    """See base class."""
+    return self._numerator.accumulate_preprocessed_record(
+        sample_state, preprocessed_record)
 
   def get_noised_result(self, sample_state, global_state):
-    """Gets noised average after all records of sample have been accumulated.
-
-    Args:
-      sample_state: The sample state after all records have been accumulated.
-      global_state: The global state.
-
-    Returns:
-      A tuple (estimate, new_global_state) where "estimate" is the estimated
-      average of the records and "new_global_state" is the updated global state.
-    """
+    """See base class."""
     noised_sum, new_sum_global_state = self._numerator.get_noised_result(
         sample_state, global_state)
     def normalize(v):
@@ -104,5 +78,10 @@ class NormalizedQuery(dp_query.DPQuery):
 
     return nest.map_structure(normalize, noised_sum), new_sum_global_state
 
+  def merge_sample_states(self, sample_state_1, sample_state_2):
+    """See base class."""
+    return self._numerator.merge_sample_states(sample_state_1, sample_state_2)
+
   def set_denominator(self, denominator):
+    """Sets the denominator for the NormalizedQuery."""
     self._denominator = tf.cast(denominator, tf.float32)