10 changes: 8 additions & 2 deletions sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -879,7 +879,8 @@ def collect(
    options=None,
    force_compute=False,
    force_tuple=False,
    raw_records=False):
    raw_records=False,
    wait_for_inputs=True):
  """Materializes the elements from a PCollection into a Dataframe.

  This reads each element from file and reads only the amount that it needs
@@ -903,6 +904,10 @@ def collect(
      the bare results if only one PCollection is computed
    raw_records: (optional) if True, return a list of collected records
      without converting to a DataFrame. Default False.
    wait_for_inputs: (optional) Whether to wait until asynchronous
      dependencies have been computed. Setting this to False schedules the
      computation immediately, but may result in running the same pipeline
      stages multiple times.

  For example::

@@ -980,7 +985,8 @@ def as_pcollection(pcoll_or_df):
      max_duration=duration,
      runner=runner,
      options=options,
      force_compute=force_compute)
      force_compute=force_compute,
      wait_for_inputs=wait_for_inputs)

  try:
    for pcoll in uncomputed:
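As an aside, here is a minimal usage sketch of the new parameter based on the docstring above. The pipeline, transform labels, and data are hypothetical; only ib.collect and its wait_for_inputs argument come from this PR, and the sketch assumes an interactive pipeline whose upstream PCollections may still be computing asynchronously:

import apache_beam as beam
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

p = beam.Pipeline(InteractiveRunner())
words = p | 'Create' >> beam.Create(['a', 'b', 'c'])
counts = words | 'Pair' >> beam.Map(lambda w: (w, 1))

# Default behavior: block until any asynchronously computing dependencies
# finish, then materialize the PCollection into a DataFrame.
df = ib.collect(counts)

# Opt out of waiting: schedule the computation immediately. Per the docstring,
# this may re-run pipeline stages that an in-flight async computation covers.
df_now = ib.collect(counts, wait_for_inputs=False)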
109 changes: 109 additions & 0 deletions sdks/python/apache_beam/runners/interactive/interactive_beam_test.py
@@ -387,6 +387,115 @@ def test_collect_raw_records_true_force_tuple(self):
      self.assertIsInstance(result[0], list)
      self.assertEqual(result[0], data)

  def test_collect_wait_for_inputs_true(self):
    with patch('apache_beam.runners.interactive.interactive_beam.ie.current_env'
               ) as mock_current_env:
      mock_env = MagicMock()
      mock_current_env.return_value = mock_env
      mock_rm = MagicMock()
      mock_env.get_recording_manager.return_value = mock_rm
      mock_env.computed_pcollections = set()
      mock_env.user_pipeline.side_effect = lambda x: x

      p = beam.Pipeline(ir.InteractiveRunner())
      pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
      pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)

      # Simulate pcoll1 being computed asynchronously
      mock_env.is_pcollection_computing.return_value = True
      async_res = MagicMock(spec=AsyncComputationResult)
      mock_rm._async_computations = {'id1': async_res}
      mock_rm._get_all_dependencies.return_value = {pcoll1}
      mock_rm._wait_for_dependencies.return_value = True

      # Set up return value for record
      mock_recording = MagicMock()
      mock_rm.record.return_value = mock_recording

      ib.collect(pcoll2, wait_for_inputs=True)

      # Check that record was called with wait_for_inputs=True
      mock_rm.record.assert_called_once_with({pcoll2},
                                             max_n=float('inf'),
                                             max_duration=float('inf'),
                                             runner=None,
                                             options=None,
                                             force_compute=False,
                                             wait_for_inputs=True)

  def test_collect_wait_for_inputs_false(self):
    with patch('apache_beam.runners.interactive.interactive_beam.ie.current_env'
               ) as mock_current_env:
      mock_env = MagicMock()
      mock_current_env.return_value = mock_env
      mock_rm = MagicMock()
      mock_env.get_recording_manager.return_value = mock_rm
      mock_env.computed_pcollections = set()
      mock_env.user_pipeline.side_effect = lambda x: x

      p = beam.Pipeline(ir.InteractiveRunner())
      pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
      pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)

      # Simulate pcoll1 being computed asynchronously
      mock_env.is_pcollection_computing.return_value = True
      async_res = MagicMock(spec=AsyncComputationResult)
      mock_rm._async_computations = {'id1': async_res}
      mock_rm._get_all_dependencies.return_value = {pcoll1}

      # Set up return value for record
      mock_recording = MagicMock()
      mock_rm.record.return_value = mock_recording

      ib.collect(pcoll2, wait_for_inputs=False)

      # Check that wait_for_dependencies was NOT called
      mock_rm._wait_for_dependencies.assert_not_called()
      # Check that record was called with wait_for_inputs=False
      mock_rm.record.assert_called_once_with({pcoll2},
                                             max_n=float('inf'),
                                             max_duration=float('inf'),
                                             runner=None,
                                             options=None,
                                             force_compute=False,
                                             wait_for_inputs=False)

  def test_collect_wait_for_inputs_default(self):
    with patch('apache_beam.runners.interactive.interactive_beam.ie.current_env'
               ) as mock_current_env:
      mock_env = MagicMock()
      mock_current_env.return_value = mock_env
      mock_rm = MagicMock()
      mock_env.get_recording_manager.return_value = mock_rm
      mock_env.computed_pcollections = set()
      mock_env.user_pipeline.side_effect = lambda x: x

      p = beam.Pipeline(ir.InteractiveRunner())
      pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
      pcoll2 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)

      # Simulate pcoll1 being computed asynchronously
      mock_env.is_pcollection_computing.return_value = True
      async_res = MagicMock(spec=AsyncComputationResult)
      mock_rm._async_computations = {'id1': async_res}
      mock_rm._get_all_dependencies.return_value = {pcoll1}
      mock_rm._wait_for_dependencies.return_value = True

      # Set up return value for record
      mock_recording = MagicMock()
      mock_rm.record.return_value = mock_recording

      ib.collect(pcoll2)  # wait_for_inputs defaults to True

      # Check that record was called with wait_for_inputs=True
      mock_rm.record.assert_called_once_with({pcoll2},
                                             max_n=float('inf'),
                                             max_duration=float('inf'),
                                             runner=None,
                                             options=None,
                                             force_compute=False,
                                             wait_for_inputs=True)


@unittest.skipIf(
    not ie.current_env().is_interactive_ready,
12 changes: 7 additions & 5 deletions sdks/python/apache_beam/runners/interactive/recording_manager.py
@@ -849,7 +849,8 @@ def record(
      max_duration: Union[int, str],
      runner: runner.PipelineRunner = None,
      options: pipeline_options.PipelineOptions = None,
      force_compute: bool = False) -> Recording:
      force_compute: bool = False,
      wait_for_inputs: bool = True) -> Recording:
    # noqa: F821

    """Records the given PCollections."""
@@ -886,10 +887,11 @@ def record(
    # Start a pipeline fragment to start computing the PCollections.
    uncomputed_pcolls = set(pcolls).difference(computed_pcolls)
    if uncomputed_pcolls:
      if not self._wait_for_dependencies(uncomputed_pcolls):
        raise RuntimeError(
            'Cannot record because a dependency failed to compute'
            ' asynchronously.')
      if wait_for_inputs:
        if not self._wait_for_dependencies(uncomputed_pcolls):
          raise RuntimeError(
              'Cannot record because a dependency failed to compute'
              ' asynchronously.')
Contributor:

Could you help me understand what use case you're trying to solve for here? If I'm reading the PR right, you're not changing the default behavior, but you are letting the recording manager get into a potentially bad state if the user intentionally sets wait_for_inputs to False. It's not clear to me why this is desirable.

Contributor (Author):

Hi Danny, my thinking here is:

1. Previously, ib.collect() would not automatically wait for the background caching job to finish. If a user ran collect() on a PCollection whose dependencies were still computing, they could get empty or partial results without warning.
2. By defaulting wait_for_inputs=True, we ensure the standard user experience is consistent: we always wait for upstream dependencies to finish before collecting. Also, adding the wait_for_inputs option aligns with the current implementation of ib.compute().
3. Back to the exact code block that you are quoting: the 'bad state' happens only when the user sets wait_for_inputs=False, explicitly requesting to bypass the safety checks and synchronization. My thinking is that this delegates the failure handling to the actual pipeline execution, which seems acceptable.
Contributor:

> By defaulting wait_for_inputs=True, we ensure the standard user experience is consistent: we always wait for upstream dependencies to finish before collecting. Also, adding the wait_for_inputs option aligns with the current implementation of ib.compute().

We are not changing the default behavior in this PR, so I don't think this is doing what you think it is doing.

> Back to the exact code block that you are quoting: the 'bad state' happens only when the user sets wait_for_inputs=False, explicitly requesting to bypass the safety checks and synchronization. My thinking is that this delegates the failure handling to the actual pipeline execution, which seems acceptable.

When is this desirable? It seems like it is always a bad outcome.


      self._clear()

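To summarize the control flow added to record() in the recording_manager.py diff above, here is a simplified, self-contained sketch. RecordingManagerSketch and its attributes are illustrative stand-ins, not the real apache_beam.runners.interactive.recording_manager.RecordingManager; only the wait_for_inputs gating and the _wait_for_dependencies call mirror the diff:

from typing import Set


class RecordingManagerSketch:
  def __init__(self, computed: Set[str]):
    self._computed = computed  # PCollections already materialized in cache.

  def _wait_for_dependencies(self, pcolls: Set[str]) -> bool:
    # Placeholder: the real manager blocks on in-flight async computations
    # and reports whether they all succeeded.
    return True

  def record(self, pcolls: Set[str], wait_for_inputs: bool = True) -> None:
    uncomputed = set(pcolls).difference(self._computed)
    if uncomputed:
      if wait_for_inputs:
        # Default path: block until async dependencies finish, and surface
        # a failure before the recording pipeline fragment starts.
        if not self._wait_for_dependencies(uncomputed):
          raise RuntimeError(
              'Cannot record because a dependency failed to compute'
              ' asynchronously.')
      # With wait_for_inputs=False the check is skipped entirely, so the
      # fragment may recompute stages an async job is already running.
    # ... start the pipeline fragment for `uncomputed` here ...

With wait_for_inputs=False the dependency check is skipped, so any failure surfaces later, during the fragment's own execution; this is the trade-off discussed in the review thread above.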