Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions temporalio/bridge/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,21 +293,21 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:

async def decode_activation(
activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
codec: temporalio.converter.PayloadCodec,
data_converter: temporalio.converter.DataConverter,
decode_headers: bool,
) -> None:
"""Decode all payloads in the activation."""
await CommandAwarePayloadVisitor(
skip_search_attributes=True, skip_headers=not decode_headers
).visit(_Visitor(codec.decode), activation)
).visit(_Visitor(data_converter._decode_payload_sequence), activation)


async def encode_completion(
completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
codec: temporalio.converter.PayloadCodec,
data_converter: temporalio.converter.DataConverter,
encode_headers: bool,
) -> None:
"""Encode all payloads in the completion."""
await CommandAwarePayloadVisitor(
skip_search_attributes=True, skip_headers=not encode_headers
).visit(_Visitor(codec.encode), completion)
).visit(_Visitor(data_converter._encode_payload_sequence), completion)
94 changes: 24 additions & 70 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2977,10 +2977,7 @@ async def memo(self) -> Mapping[str, Any]:
Returns:
Mapping of all memo keys and they values without type hints.
"""
return {
k: (await self.data_converter.decode([v]))[0]
for k, v in self.raw_info.memo.fields.items()
}
return await self.data_converter._convert_from_memo(self.raw_info.memo)

@overload
async def memo_value(
Expand Down Expand Up @@ -3019,16 +3016,9 @@ async def memo_value(
Raises:
KeyError: Key not present and default not set.
"""
payload = self.raw_info.memo.fields.get(key)
if not payload:
if default is temporalio.common._arg_unset:
raise KeyError(f"Memo does not have a value for key {key}")
return default
return (
await self.data_converter.decode(
[payload], [type_hint] if type_hint else None
)
)[0]
return await self.data_converter._convert_from_memo_field(
self.raw_info.memo, key, default, type_hint
)


@dataclass
Expand Down Expand Up @@ -4209,18 +4199,9 @@ async def _to_proto(
workflow_run_timeout=run_timeout,
workflow_task_timeout=task_timeout,
retry_policy=retry_policy,
memo=(
temporalio.api.common.v1.Memo(
fields={
k: v
if isinstance(v, temporalio.api.common.v1.Payload)
else (await data_converter.encode([v]))[0]
for k, v in self.memo.items()
},
)
if self.memo
else None
),
memo=await data_converter._convert_to_memo(self.memo)
if self.memo
else None,
user_metadata=await _encode_user_metadata(
data_converter, self.static_summary, self.static_details
),
Expand Down Expand Up @@ -4249,7 +4230,7 @@ async def _to_proto(
client.config(active_config=True)["header_codec_behavior"]
== HeaderCodecBehavior.CODEC
and not self._from_raw,
client.data_converter.payload_codec,
client.data_converter,
)
return action

Expand Down Expand Up @@ -4521,10 +4502,7 @@ async def memo(self) -> Mapping[str, Any]:
Returns:
Mapping of all memo keys and they values without type hints.
"""
return {
k: (await self.data_converter.decode([v]))[0]
for k, v in self.raw_description.memo.fields.items()
}
return await self.data_converter._convert_from_memo(self.raw_description.memo)

@overload
async def memo_value(
Expand Down Expand Up @@ -4563,16 +4541,9 @@ async def memo_value(
Raises:
KeyError: Key not present and default not set.
"""
payload = self.raw_description.memo.fields.get(key)
if not payload:
if default is temporalio.common._arg_unset:
raise KeyError(f"Memo does not have a value for key {key}")
return default
return (
await self.data_converter.decode(
[payload], [type_hint] if type_hint else None
)
)[0]
return await self.data_converter._convert_from_memo_field(
self.raw_description.memo, key, default, type_hint
)


@dataclass
Expand Down Expand Up @@ -4770,10 +4741,7 @@ async def memo(self) -> Mapping[str, Any]:
Returns:
Mapping of all memo keys and they values without type hints.
"""
return {
k: (await self.data_converter.decode([v]))[0]
for k, v in self.raw_entry.memo.fields.items()
}
return await self.data_converter._convert_from_memo(self.raw_entry.memo)

@overload
async def memo_value(
Expand Down Expand Up @@ -4812,16 +4780,9 @@ async def memo_value(
Raises:
KeyError: Key not present and default not set.
"""
payload = self.raw_entry.memo.fields.get(key)
if not payload:
if default is temporalio.common._arg_unset:
raise KeyError(f"Memo does not have a value for key {key}")
return default
return (
await self.data_converter.decode(
[payload], [type_hint] if type_hint else None
)
)[0]
return await self.data_converter._convert_from_memo_field(
self.raw_entry.memo, key, default, type_hint
)


@dataclass
Expand Down Expand Up @@ -6014,8 +5975,7 @@ async def _populate_start_workflow_execution_request(
input.retry_policy.apply_to_proto(req.retry_policy)
req.cron_schedule = input.cron_schedule
if input.memo is not None:
for k, v in input.memo.items():
req.memo.fields[k].CopyFrom((await data_converter.encode([v]))[0])
await data_converter._convert_into_memo(input.memo, req.memo)
if input.search_attributes is not None:
temporalio.converter.encode_search_attributes(
input.search_attributes, req.search_attributes
Expand Down Expand Up @@ -6641,14 +6601,9 @@ async def create_schedule(self, input: CreateScheduleInput) -> ScheduleHandle:
initial_patch=initial_patch,
identity=self._client.identity,
request_id=str(uuid.uuid4()),
memo=None
if not input.memo
else temporalio.api.common.v1.Memo(
fields={
k: (await self._client.data_converter.encode([v]))[0]
for k, v in input.memo.items()
},
),
memo=await self._client.data_converter._convert_to_memo(input.memo)
if input.memo
else None,
)
if input.search_attributes:
temporalio.converter.encode_search_attributes(
Expand Down Expand Up @@ -6870,22 +6825,21 @@ async def _apply_headers(
dest,
self._client.config(active_config=True)["header_codec_behavior"]
== HeaderCodecBehavior.CODEC,
self._client.data_converter.payload_codec,
self._client.data_converter,
)


async def _apply_headers(
source: Mapping[str, temporalio.api.common.v1.Payload] | None,
dest: MessageMap[str, temporalio.api.common.v1.Payload],
encode_headers: bool,
codec: temporalio.converter.PayloadCodec | None,
data_converter: DataConverter,
) -> None:
if source is None:
return
if encode_headers and codec is not None:
if encode_headers:
for payload in source.values():
new_payload = (await codec.encode([payload]))[0]
payload.CopyFrom(new_payload)
payload.CopyFrom(await data_converter._encode_payload(payload))
temporalio.common._apply_headers(source, dest)


Expand Down
Loading
Loading